Target audience: Intermediate
Estimated reading time: 20'
The basic components of the implementation of the K-means clustering algorithm were introduced in the previous post, K-means clustering in Java: Elements.
This second part on the implementation of K-means in Java describes the two main tasks of any unsupervised or supervised learning algorithm:
- training: executed off-line during analysis of historical data
- classification: executed at run-time to classify new observations
Training
The learning method, train, implements the clustering algorithm. It iterates to minimize the sum of the distances between each data point and the centroid of the cluster it belongs to.
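As a rough illustration of the quantity being minimized, the sketch below sums the distance between each observation and the centroid of the cluster it is assigned to. The class name TotalDistanceSketch, the plain double[] representation, the assignment array and the choice of the Euclidean distance are all assumptions made for this example; they are not part of the implementation described in this post.

```java
final class TotalDistanceSketch {
    // Euclidean distance between two vectors of the same length
    // (the metric is an assumption of this sketch).
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    // Sum of the distances between each observation and its assigned centroid.
    // assignment[i] is the index of the cluster that observation i belongs to.
    static double totalDistance(double[][] obs, int[] assignment, double[][] centroids) {
        double total = 0.0;
        for (int i = 0; i < obs.length; i++) {
            total += distance(obs[i], centroids[assignment[i]]);
        }
        return total;
    }

    public static void main(String[] args) {
        double[][] obs = { {0.0, 0.0}, {0.0, 2.0}, {10.0, 0.0} };
        double[][] centroids = { {0.0, 1.0}, {10.0, 0.0} };
        int[] assignment = { 0, 0, 1 };
        System.out.println(totalDistance(obs, assignment, centroids)); // prints 2.0
    }
}
```

In this toy data set the first two observations each sit at distance 1 from the first centroid and the third coincides with the second centroid, hence the total of 2.0. Moving an observation to a farther cluster can only increase this value, which is why re-assigning each point to its closest centroid drives the sum down.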
For each iteration (or epoch), the train method:
- assigns observations to each cluster
- computes the centroid for each cluster, computeCentroid
- computes the total distance of all the observations to their respective centroids, computeTotalDistance
- estimates the closest cluster for each observation
- re-assigns the observation to its closest cluster, updateCentroids
public int train() {
    int numIters = _maxIterations, k = 0;
    boolean inProgress = true;

    initialize();

    // Initial assignment: distribute the observations
    // round-robin across the clusters
    while (inProgress) {
        for (KmeansCluster cluster : _clusters) {
            cluster.attach(_obsList[k]);
            if (++k >= _obsList.length) {
                inProgress = false;
                break;
            }
        }
    }
    computeTotalDistance();

    for (KmeansCluster cluster : _clusters) {
        cluster.computeCentroid();
    }
    computeTotalDistance();

    List<Observation> obsList = null;
    KmeansCluster closestCluster = null;

    // Main iterative method that traverses all the clusters,
    // computes the distance of observations relative to their centroid
    // and re-assigns the observations.
    for (int i = 0; i < _maxIterations; i++) {
        for (KmeansCluster cluster : _clusters) {
            // Work on a copy: updateCentroids may detach points
            // from the cluster while we iterate
            obsList = new ArrayList<Observation>();
            for (Observation point : cluster.getDataPointsList()) {
                obsList.add(point);
            }

            for (Observation point : obsList) {
                double minDistance = Double.MAX_VALUE, distance = 0.0;
                closestCluster = null;

                // Find the cluster whose centroid is closest to the point
                for (KmeansCluster cursor : _clusters) {
                    distance = point.computeDistance(cursor.getCentroid());
                    if (minDistance > distance) {
                        minDistance = distance;
                        closestCluster = cursor;
                    }
                }
                updateCentroids(point, cluster, closestCluster);
            }
        }

        // Simple convergence criteria
        if (_convergeCounter >= _minNumConvergeIters) {
            numIters = i;
            break;
        }
    }
    return numIters;
}
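For readers who want to run the loop in isolation, here is a compact sketch of the same assign-and-recompute cycle on plain arrays. It is not the code of this post: the class KmeansLoopSketch, its method names, the batch update that recomputes the centroids once per epoch (rather than after each re-assignment), and the stopping rule (no observation changed cluster) are assumptions chosen to keep the example self-contained.

```java
import java.util.Arrays;

final class KmeansLoopSketch {
    // Squared Euclidean distance; sufficient for comparing distances.
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    // Index of the centroid closest to the observation.
    static int closest(double[] obs, double[][] centroids) {
        int best = -1;
        double minDistance = Double.MAX_VALUE;
        for (int c = 0; c < centroids.length; c++) {
            double distance = squaredDistance(obs, centroids[c]);
            if (distance < minDistance) {
                minDistance = distance;
                best = c;
            }
        }
        return best;
    }

    // Runs at most maxIters epochs, updates the centroids in place
    // and returns the final cluster index of each observation.
    static int[] train(double[][] obs, double[][] centroids, int maxIters) {
        int[] assignment = new int[obs.length];
        Arrays.fill(assignment, -1);
        int dim = centroids[0].length;

        for (int iter = 0; iter < maxIters; iter++) {
            // Assignment step: attach each observation to its closest centroid
            boolean moved = false;
            for (int i = 0; i < obs.length; i++) {
                int c = closest(obs[i], centroids);
                if (c != assignment[i]) {
                    assignment[i] = c;
                    moved = true;
                }
            }
            // Assumed convergence rule: stop when nothing was re-assigned
            if (!moved) {
                break;
            }
            // Update step: each centroid becomes the mean of its observations
            double[][] sums = new double[centroids.length][dim];
            int[] counts = new int[centroids.length];
            for (int i = 0; i < obs.length; i++) {
                counts[assignment[i]]++;
                for (int d = 0; d < dim; d++) {
                    sums[assignment[i]][d] += obs[i][d];
                }
            }
            for (int c = 0; c < centroids.length; c++) {
                if (counts[c] == 0) {
                    continue; // empty cluster: keep its previous centroid
                }
                for (int d = 0; d < dim; d++) {
                    centroids[c][d] = sums[c][d] / counts[c];
                }
            }
        }
        return assignment;
    }

    public static void main(String[] args) {
        double[][] obs = { {1.0, 1.0}, {1.0, 2.0}, {9.0, 9.0}, {10.0, 9.0} };
        double[][] centroids = { {0.0, 0.0}, {10.0, 10.0} };
        int[] assignment = train(obs, centroids, 20);
        System.out.println(Arrays.toString(assignment));    // prints [0, 0, 1, 1]
        System.out.println(Arrays.deepToString(centroids)); // prints [[1.0, 1.5], [9.5, 9.0]]
    }
}
```

The train method above differs from this sketch in one notable way: it recomputes all centroids immediately after each re-assignment, which makes every move visible to the next observation at the cost of more centroid computations per epoch.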
Classification
The classification of a new observation is simple: it consists of evaluating the distance between the new data point and each centroid, then selecting the cluster with the shortest distance. The classify method extracts the index or label of the cluster that is the most suitable (closest in distance) to the new observation.
public int classify(double[] obs) {
    double bestScore = Double.MAX_VALUE, distance = 0.0;
    int clusterId = -1;

    // Select the centroid with the shortest distance to the observation
    for (int k = 0; k < _centroids.length; k++) {
        distance = _centroids[k].computeDistance(obs);
        if (distance < bestScore) {
            bestScore = distance;
            clusterId = k;
        }
    }
    return clusterId;
}
The code snippet below implements some of the supporting methods to
- compute the loss function value (total distance)
- initialize the centroid for each cluster
- update the values of centroids.
private void computeTotalDistance() {
    float totalDistance = 0.0F;

    for (KmeansCluster cluster : _clusters) {
        totalDistance += cluster.getSumDistances();
    }

    // Count the consecutive iterations for which the total distance
    // changes by less than the convergence criteria
    double error = Math.abs(_totalDistance - totalDistance);
    _convergeCounter = (error < _convergeCriteria) ? _convergeCounter + 1 : 0;
    _totalDistance = totalDistance;
}

private void initialize() {
    double[] params = getParameters();
    int numVariables = params.length >> 1;

    // Range of each variable: first half of params minus second half
    double[] range = new double[numVariables];
    for (int k = 0, j = numVariables; k < numVariables; k++, j++) {
        range[k] = params[k] - params[j];
    }

    int sz_1 = _clusters.length + 1, m = 1;
    for (KmeansCluster cluster : _clusters) {
        // Allocate a new array per cluster so that the centroids
        // do not share the same underlying values
        double[] x = new double[numVariables];
        for (int k = 0, j = numVariables; k < numVariables; k++, j++) {
            x[k] = ((range[k] / sz_1) * m) + params[j];
        }
        cluster.setCentroid(x);
        m++;
    }
}

private void updateCentroids(Observation point,
                             KmeansCluster cluster,
                             KmeansCluster bestCluster) {
    boolean update = bestCluster != null && bestCluster != cluster;

    if (update) {
        // Move the point to its closest cluster, then refresh
        // all the centroids and the total distance
        bestCluster.attach(point);
        cluster.detach(point);
        for (KmeansCluster cursor : _clusters) {
            cursor.computeCentroid();
        }
        computeTotalDistance();
    }
}
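The arithmetic in initialize is easier to follow on a small worked example. The sketch below spreads the initial centroids at evenly spaced positions between a lower and an upper bound for each variable, using the same formula, (range / (numClusters + 1)) * m + min. It assumes that the first half of the array returned by getParameters holds the upper bounds and the second half the lower bounds, which is what the subtraction in initialize suggests but is not stated in the post. The class CentroidSeedingSketch and its seed method are hypothetical names introduced for illustration.

```java
import java.util.Arrays;

final class CentroidSeedingSketch {
    // Places numClusters centroids at evenly spaced positions strictly
    // between min[k] and max[k] for each variable k.
    // Assumption: min and max are the per-variable bounds of the data.
    static double[][] seed(double[] min, double[] max, int numClusters) {
        int numVariables = min.length;
        double[][] centroids = new double[numClusters][numVariables];

        for (int m = 1; m <= numClusters; m++) {
            for (int k = 0; k < numVariables; k++) {
                double range = max[k] - min[k];
                centroids[m - 1][k] = (range / (numClusters + 1)) * m + min[k];
            }
        }
        return centroids;
    }

    public static void main(String[] args) {
        double[] min = { 0.0, 10.0 };
        double[] max = { 10.0, 20.0 };
        double[][] centroids = seed(min, max, 4);
        // prints [[2.0, 12.0], [4.0, 14.0], [6.0, 16.0], [8.0, 18.0]]
        System.out.println(Arrays.deepToString(centroids));
    }
}
```

With 4 clusters the range of each variable is divided into 5 equal intervals, so no centroid starts on a boundary of the data. Each centroid gets its own array in this sketch, which avoids several clusters sharing one set of coordinates.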