**Overview**

The basic components of the implementation of the K-means clustering algorithm were introduced in the previous post, *K-means clustering in Java: Elements*.

This second part on the implementation of K-means in Java describes the two main tasks of any supervised or unsupervised learning algorithm:

- training: executed off-line during analysis of historical data
- classification: executed at run-time to classify new observations

**Note**: For the sake of readability of the implementation of algorithms, all non-essential code such as error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers or imports is omitted.

**Training**

The learning method, *train*, implements the clustering algorithm. It iterates to minimize the sum of the distances between the data points of each cluster and the centroid of that cluster.
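Written out, the loss minimized by *train* (and computed by *computeTotalDistance*) can be sketched as follows. This assumes $K$ clusters $C_1, \dots, C_K$ with centroids $\mu_1, \dots, \mu_K$, a distance $d$ between an observation and a centroid, and that *computeCentroid* sets each centroid to the mean of the observations in its cluster:

$$
J = \sum_{k=1}^{K} \sum_{x \in C_k} d(x, \mu_k), \qquad \mu_k = \frac{1}{|C_k|} \sum_{x \in C_k} x
$$

The textbook K-means objective uses the squared Euclidean distance $d(x, \mu) = \lVert x - \mu \rVert^2$. With that distance, the mean is the centroid that minimizes each cluster's contribution to $J$.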

For each iteration (or epoch), the *train* method:

- assigns observations to each cluster
- computes the centroid of each cluster, *computeCentroid*
- computes the total distance of all the observations to their respective centroid, *computeTotalDistance*
- estimates the closest cluster for each observation
- re-assigns the observation, *updateCentroids*
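The steps above can be sketched as a self-contained batch (Lloyd-style) epoch on raw `double[]` points. This is a hypothetical illustration, not the post's implementation. It assumes Euclidean distance and leaves out the *KmeansCluster* and *Observation* classes. It also recomputes all the centroids once per epoch, whereas *train* recomputes them after every single reassignment. The class and method names (`KmeansEpochSketch`, `epoch`, `closest`, `totalDistance`) are illustrative only.

```java
import java.util.Arrays;

// Hypothetical sketch of one K-means epoch on raw points (not the post's classes).
public class KmeansEpochSketch {

    // Euclidean distance between two points of the same dimension (assumed metric).
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    // Index of the centroid closest to the point.
    static int closest(double[] point, double[][] centroids) {
        int best = -1;
        double minDistance = Double.MAX_VALUE;
        for (int k = 0; k < centroids.length; k++) {
            double d = distance(point, centroids[k]);
            if (d < minDistance) {
                minDistance = d;
                best = k;
            }
        }
        return best;
    }

    // Assigns every point to its closest centroid, then moves each centroid
    // (in place) to the mean of its points. Empty clusters keep their centroid.
    static int[] epoch(double[][] points, double[][] centroids) {
        int numVars = points[0].length;
        int[] assignment = new int[points.length];
        double[][] sums = new double[centroids.length][numVars];
        int[] counts = new int[centroids.length];
        for (int n = 0; n < points.length; n++) {
            int k = closest(points[n], centroids);
            assignment[n] = k;
            counts[k]++;
            for (int j = 0; j < numVars; j++) sums[k][j] += points[n][j];
        }
        for (int k = 0; k < centroids.length; k++) {
            if (counts[k] == 0) continue;
            for (int j = 0; j < numVars; j++) centroids[k][j] = sums[k][j] / counts[k];
        }
        return assignment;
    }

    // Loss: total distance of every point to the centroid of its cluster.
    static double totalDistance(double[][] points, double[][] centroids, int[] assignment) {
        double total = 0.0;
        for (int n = 0; n < points.length; n++) {
            total += distance(points[n], centroids[assignment[n]]);
        }
        return total;
    }

    public static void main(String[] args) {
        double[][] points = {{0.0}, {1.0}, {9.0}, {10.0}};
        double[][] centroids = {{2.0}, {8.0}};
        int[] assignment = epoch(points, centroids);
        System.out.println(Arrays.toString(assignment));                   // [0, 0, 1, 1]
        System.out.println(Arrays.deepToString(centroids));                // [[0.5], [9.5]]
        System.out.println(totalDistance(points, centroids, assignment));  // 2.0
    }
}
```

Starting from centroids 2.0 and 8.0, one epoch separates {0, 1} from {9, 10}, moves the centroids to 0.5 and 9.5, and brings the total distance down to 2.0. Repeating `epoch` until the loss stops changing gives the batch variant of the algorithm.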

```java
public int train() {
  int numIters = _maxIterations, k = 0;
  boolean inProgress = true;

  initialize();
  // Distribute the observations across the clusters in round-robin order
  while(inProgress) {
    for(KmeansCluster cluster : _clusters) {
      cluster.attach(_obsList[k]);
      if(++k >= _obsList.length) {
        inProgress = false;
        break;
      }
    }
  }
  computeTotalDistance();
  for(KmeansCluster cluster : _clusters) {
    cluster.computeCentroid();
  }
  computeTotalDistance();

  // Main iterative method, that traverses all the clusters,
  // computes the distance of observations relative to their centroid
  // and re-assigns the observations.
  for(int i = 0; i < _maxIterations; i++) {
    for(KmeansCluster cluster : _clusters) {
      // Copy the data points: updateCentroids may detach points from this cluster
      List<Observation> obsList = new ArrayList<Observation>();
      for(Observation point : cluster.getDataPointsList()) {
        obsList.add(point);
      }
      for(Observation point : obsList) {
        double minDistance = Double.MAX_VALUE;
        KmeansCluster closestCluster = null;
        for(KmeansCluster cursor : _clusters) {
          double distance = point.computeDistance(cursor.getCentroid());
          if(minDistance > distance) {
            minDistance = distance;
            closestCluster = cursor;
          }
        }
        updateCentroids(point, cluster, closestCluster);
      }
    }
    // Simple convergence criteria
    if(_convergeCounter >= _minNumConvergeIters) {
      numIters = i;
      break;
    }
  }
  return numIters;
}
```

**Classification**

Classifying a new observation is simple: it consists of evaluating the distance between the new data point and each centroid, then selecting the cluster with the shortest distance. The *classify* method extracts the index or label of the most suitable cluster, that is, the one closest in distance to the new observation.

```java
public int classify(double[] obs) {
  double bestScore = Double.MAX_VALUE, distance = 0.0;
  int clusterId = -1;

  for(int k = 0; k < _centroids.length; k++) {
    distance = _centroids[k].computeDistance(obs);
    if(distance < bestScore) {
      bestScore = distance;
      clusterId = k;
    }
  }
  return clusterId;
}
```
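The post does not show the type behind `_centroids`. The following hypothetical sketch exercises the classification logic end to end. It assumes a minimal `Centroid` class whose `computeDistance` is the Euclidean distance. Both `Centroid` and `NearestCentroidSketch` are illustrative names, not part of the original code.

```java
// Hypothetical nearest-centroid classifier; Centroid and its metric are assumptions.
public class NearestCentroidSketch {

    static final class Centroid {
        private final double[] values;

        Centroid(double... values) { this.values = values.clone(); }

        // Euclidean distance between this centroid and an observation (assumed metric).
        double computeDistance(double[] obs) {
            double sum = 0.0;
            for (int i = 0; i < values.length; i++) {
                double diff = values[i] - obs[i];
                sum += diff * diff;
            }
            return Math.sqrt(sum);
        }
    }

    private final Centroid[] _centroids;

    NearestCentroidSketch(Centroid... centroids) { _centroids = centroids; }

    // Same logic as classify: index of the closest centroid, or -1 if there are none.
    public int classify(double[] obs) {
        double bestScore = Double.MAX_VALUE;
        int clusterId = -1;
        for (int k = 0; k < _centroids.length; k++) {
            double distance = _centroids[k].computeDistance(obs);
            if (distance < bestScore) {
                bestScore = distance;
                clusterId = k;
            }
        }
        return clusterId;
    }

    public static void main(String[] args) {
        NearestCentroidSketch model = new NearestCentroidSketch(
            new Centroid(0.0, 0.0), new Centroid(10.0, 10.0));
        System.out.println(model.classify(new double[] {9.0, 8.0})); // 1
        System.out.println(model.classify(new double[] {1.0, 2.0})); // 0
    }
}
```

Ties are resolved in favor of the lower cluster index, because only a strictly shorter distance replaces the current best.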

The code snippet below implements some of the supporting methods to:

- compute the loss function value (total distance)
- initialize the centroid for each cluster
- update the values of the centroids.

```java
private void computeTotalDistance() {
  double totalDistance = 0.0;
  for(KmeansCluster cluster : _clusters) {
    totalDistance += cluster.getSumDistances();
  }
  // Count consecutive updates whose change in loss stays below the threshold
  double error = Math.abs(_totalDistance - totalDistance);
  _convergeCounter = (error < _convergeCriteria) ? _convergeCounter + 1 : 0;
  _totalDistance = totalDistance;
}

private void initialize() {
  // Assumed layout: maximum value of each variable, followed by its minimum value
  double[] params = getParameters();
  int numVariables = params.length >> 1;
  double[] range = new double[numVariables];
  for(int k = 0, j = numVariables; k < numVariables; k++, j++) {
    range[k] = params[k] - params[j];
  }
  int sz_1 = _clusters.length + 1, m = 1;
  // Space the initial centroids evenly between the minimum and maximum values
  for(KmeansCluster cluster : _clusters) {
    double[] x = new double[numVariables];  // a new array for each centroid
    for(int k = 0, j = numVariables; k < numVariables; k++, j++) {
      x[k] = ((range[k]/sz_1)*m) + params[j];
    }
    cluster.setCentroid(x);
    m++;
  }
}

private void updateCentroids(Observation point, KmeansCluster cluster, KmeansCluster bestCluster) {
  boolean update = bestCluster != null && bestCluster != cluster;
  if(update) {
    bestCluster.attach(point);
    cluster.detach(point);
    for(KmeansCluster cursor : _clusters) {
      cursor.computeCentroid();
    }
    computeTotalDistance();
  }
}
```
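To see what *initialize* produces and how the convergence counter of *computeTotalDistance* behaves, here is a hypothetical standalone sketch of both rules. It assumes, as the indexing in *initialize* suggests, that *getParameters* returns the per-variable maximum values followed by the per-variable minimum values. The names `KmeansInitSketch`, `initCentroids` and `ConvergenceCounter` are illustrative, not part of the post, and the initial loss of `Double.MAX_VALUE` is an assumption.

```java
import java.util.Arrays;

// Hypothetical sketch of the initialization rule and the convergence counter.
public class KmeansInitSketch {

    // Centroid m (1..K) sits at min + m * (max - min) / (K + 1) on every variable,
    // so the K initial centroids are evenly spaced along the bounding-box diagonal.
    static double[][] initCentroids(double[] max, double[] min, int numClusters) {
        double[][] centroids = new double[numClusters][max.length];
        for (int m = 1; m <= numClusters; m++) {
            for (int k = 0; k < max.length; k++) {
                centroids[m - 1][k] = min[k] + (max[k] - min[k]) / (numClusters + 1) * m;
            }
        }
        return centroids;
    }

    // Counts consecutive loss updates whose change stays below convergeCriteria.
    static final class ConvergenceCounter {
        private final double convergeCriteria;
        private double totalDistance = Double.MAX_VALUE; // assumed initial loss
        private int counter = 0;

        ConvergenceCounter(double convergeCriteria) { this.convergeCriteria = convergeCriteria; }

        // Records a new loss value and returns the updated counter.
        int update(double newTotalDistance) {
            double error = Math.abs(totalDistance - newTotalDistance);
            counter = (error < convergeCriteria) ? counter + 1 : 0;
            totalDistance = newTotalDistance;
            return counter;
        }
    }

    public static void main(String[] args) {
        double[][] centroids = initCentroids(new double[] {10.0, 20.0}, new double[] {0.0, 0.0}, 4);
        // Prints [[2.0, 4.0], [4.0, 8.0], [6.0, 12.0], [8.0, 16.0]]
        System.out.println(Arrays.deepToString(centroids));

        ConvergenceCounter counter = new ConvergenceCounter(0.01);
        // Prints 0 1 2 0: two small changes in a row, then a large one resets the count
        for (double loss : new double[] {100.0, 99.995, 99.99, 90.0}) {
            System.out.print(counter.update(loss) + " ");
        }
        System.out.println();
    }
}
```

With K = 4, the centroids land at 1/5, 2/5, 3/5 and 4/5 of each variable's range. The counter only reaches `_minNumConvergeIters` after that many consecutive small changes in the loss, and any large change resets it to zero.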
