Wednesday, October 31, 2018

K-means Clustering in Java II: Classification

Target audience: Intermediate
Estimated reading time: 20'

The basic components of the K-means clustering implementation were introduced in the previous post, K-means clustering in Java: Elements.

This second part on the implementation of K-means in Java describes the two main tasks of any unsupervised or supervised learning algorithm:
  • training: executed off-line during the analysis of historical data
  • classification: executed at run-time to classify new observations
Note: For the sake of readability, all non-essential code such as error checking, most comments, exception handling, validation of class and method arguments, scoping qualifiers and import statements is omitted.

The learning method, train, implements the clustering algorithm. It iterates to minimize the sum of the distances between the observations of each cluster and its centroid.
For each iteration (or epoch), the train method:
  1. assigns the observations to the clusters
  2. computes the centroid of each cluster, computeCentroid
  3. computes the total distance between all the observations and their respective centroids, computeTotalDistance
  4. finds the closest cluster for each observation
  5. re-assigns the observation and updates the centroids, updateCentroids
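Written with notation introduced here for illustration (K clusters C_k, centroids μ_k and a distance function d between an observation and a centroid), the quantity that train tries to reduce is:

```latex
J = \sum_{k=1}^{K} \sum_{\mathbf{x} \in C_k} d(\mathbf{x}, \boldsymbol{\mu}_k),
\qquad
\boldsymbol{\mu}_k = \frac{1}{|C_k|} \sum_{\mathbf{x} \in C_k} \mathbf{x}
```

When d is the squared Euclidean distance, this is the classic K-means objective, and the mean of the observations assigned to a cluster is exactly the centroid that minimizes that cluster's inner sum.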
public int train() {
  int numIters = _maxIterations, k = 0;
  boolean inProgress = true;

   // 1. Initial assignment: distribute the observations across the clusters
   //    in round-robin order (attach is assumed to be a method of KmeansCluster)
  while(inProgress) {
     for(KmeansCluster cluster : _clusters ) {
        cluster.attach(_obsList[k]);
        if( ++k >= _obsList.length) {
           inProgress = false;
           break;
        }
     }
  }
   // 2. Compute the initial centroid of each cluster
  for(KmeansCluster cluster : _clusters ) {
     cluster.computeCentroid();
  }
  List<Observation> obsList = null;
  KmeansCluster closestCluster = null;

   // main iterative method, that traverses all the clusters
   // computes the distance of observations relative to their centroid
   // and re-assigns the observations.
  for(int i = 0; i < _maxIterations; i++) {
    for(KmeansCluster cluster : _clusters ) {
       // Iterate over a copy, as updateCentroids may move points to another cluster
      obsList = new ArrayList<Observation>(cluster.getDataPointsList());
      for( Observation point : obsList) {
        double minDistance = Double.MAX_VALUE, distance = 0.0;
        closestCluster = null;
        for(KmeansCluster cursor : _clusters ) {
          distance = point.computeDistance(cursor.getCentroid());
          if( minDistance > distance) {
            minDistance = distance;
            closestCluster = cursor;
          }
        }
        updateCentroids(point, cluster, closestCluster);
      }
    }
     // Update the loss (total distance) and the convergence counter
    computeTotalDistance();
     // Simple convergence criteria: the loss has been stable for enough epochs
    if( _convergeCounter >= _minNumConvergeIters ) {
      numIters = i;
      break;
    }
  }
  return numIters;
}
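To run the training loop end to end, here is a minimal, self-contained sketch on plain double[] arrays. It is not the author's implementation: the class KMeansTrainSketch and its methods are hypothetical, it uses the squared Euclidean distance, and it swaps in the batch (Lloyd) update, where all observations are reassigned before the centroids are recomputed, instead of moving observations one at a time. It also stops when no assignment changes, rather than relying on the loss-based convergence counter.

```java
import java.util.Arrays;

/** Hypothetical sketch of a batch (Lloyd) K-means training loop. */
class KMeansTrainSketch {

    /** Squared Euclidean distance between two points of equal dimension. */
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /**
     * Runs at most maxIters epochs and stops early once no assignment changes.
     * Returns the cluster index of each observation; centroids are updated in place.
     */
    static int[] train(double[][] obs, double[][] centroids, int maxIters) {
        int[] assignment = new int[obs.length];
        Arrays.fill(assignment, -1);
        for (int iter = 0; iter < maxIters; iter++) {
            boolean changed = false;
            // Assignment step: attach each observation to its closest centroid
            for (int n = 0; n < obs.length; n++) {
                int best = 0;
                double bestDist = Double.MAX_VALUE;
                for (int k = 0; k < centroids.length; k++) {
                    double d = distance(obs[n], centroids[k]);
                    if (d < bestDist) { bestDist = d; best = k; }
                }
                if (assignment[n] != best) { assignment[n] = best; changed = true; }
            }
            if (!changed) break; // converged: the partition is stable
            // Update step: each centroid becomes the mean of its observations
            int dim = obs[0].length;
            double[][] sums = new double[centroids.length][dim];
            int[] counts = new int[centroids.length];
            for (int n = 0; n < obs.length; n++) {
                counts[assignment[n]]++;
                for (int i = 0; i < dim; i++) sums[assignment[n]][i] += obs[n][i];
            }
            for (int k = 0; k < centroids.length; k++) {
                if (counts[k] == 0) continue; // keep the centroid of an empty cluster unchanged
                for (int i = 0; i < dim; i++) centroids[k][i] = sums[k][i] / counts[k];
            }
        }
        return assignment;
    }

    public static void main(String[] args) {
        double[][] obs = {{1.0, 1.0}, {1.5, 2.0}, {8.0, 8.0}, {9.0, 8.5}};
        double[][] centroids = {{0.0, 0.0}, {10.0, 10.0}};
        System.out.println(Arrays.toString(train(obs, centroids, 20))); // [0, 0, 1, 1]
        System.out.println(Arrays.deepToString(centroids));            // [[1.25, 1.5], [8.5, 8.25]]
    }
}
```

The batch update is the textbook form of K-means; the incremental variant described in this post can reach a stable partition in fewer epochs at the cost of recomputing centroids after every move.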

The classification of a new observation is simple: it evaluates the distance between the new data point and each centroid and selects the cluster with the shortest distance. The classify method returns the index (or label) of the cluster that is closest to the new observation.

public int classify(double[] obs) {
  double bestScore = Double.MAX_VALUE, distance = 0.0;
  int clusterId = -1;
   // Select the centroid closest to the new observation
  for(int k = 0; k < _centroids.length; k++) {
     distance = _centroids[k].computeDistance(obs);
     if(distance < bestScore) {
        bestScore = distance;
        clusterId = k;
     }
  }
  return clusterId;
}
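The same nearest-centroid rule can be tried in isolation with the minimal sketch below. It assumes Euclidean distance and centroids stored as plain double[] arrays; the class KMeansClassifySketch is hypothetical and stands in for the Centroid.computeDistance call used above.

```java
/** Hypothetical sketch: nearest-centroid classification with Euclidean distance. */
class KMeansClassifySketch {
    private final double[][] centroids;

    KMeansClassifySketch(double[][] centroids) {
        this.centroids = centroids;
    }

    /** Euclidean distance; assumes both arrays have the same length. */
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Returns the index of the closest centroid, or -1 if there are no centroids. */
    int classify(double[] obs) {
        double bestScore = Double.MAX_VALUE;
        int clusterId = -1;
        for (int k = 0; k < centroids.length; k++) {
            double distance = euclidean(centroids[k], obs);
            if (distance < bestScore) { // strict comparison: ties go to the lowest index
                bestScore = distance;
                clusterId = k;
            }
        }
        return clusterId;
    }

    public static void main(String[] args) {
        KMeansClassifySketch model =
            new KMeansClassifySketch(new double[][]{{1.25, 1.5}, {8.5, 8.25}});
        System.out.println(model.classify(new double[]{2.0, 1.0})); // 0
        System.out.println(model.classify(new double[]{7.0, 9.0})); // 1
    }
}
```

Classification costs one distance evaluation per centroid, so it is cheap enough to run at run-time on each new observation.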

The code snippets below implement some of the supporting methods used to:
  • compute the value of the loss function (total distance)
  • initialize the centroid of each cluster
  • update the values of the centroids.

private void computeTotalDistance() {
  double totalDistance = 0.0;
  for(KmeansCluster cluster : _clusters ) {
     totalDistance += cluster.getSumDistances();
  }
   // Count the consecutive epochs for which the loss barely changes
  double error = Math.abs(_totalDistance - totalDistance);
  _convergeCounter = ( error < _convergeCriteria) ? _convergeCounter + 1 : 0;
  _totalDistance = totalDistance;
}
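The loss and the convergence counter can be exercised on their own. In the hypothetical sketch below, ConvergenceSketch computes the total Euclidean distance from a plain assignment array (instead of KmeansCluster.getSumDistances) and wraps the counter logic of computeTotalDistance in a small class: training may stop once the loss has changed by less than convergeCriteria for minNumConvergeIters consecutive epochs.

```java
/** Hypothetical sketch of the loss and of the convergence counter. */
class ConvergenceSketch {
    private final double convergeCriteria;   // tolerance on the change of the loss between epochs
    private final int minNumConvergeIters;   // consecutive stable epochs required to stop
    private double previousTotal = Double.MAX_VALUE; // loss of the previous epoch
    private int convergeCounter = 0;

    ConvergenceSketch(double convergeCriteria, int minNumConvergeIters) {
        this.convergeCriteria = convergeCriteria;
        this.minNumConvergeIters = minNumConvergeIters;
    }

    /** Sum of the Euclidean distances between each observation and its cluster's centroid. */
    static double totalDistance(double[][] obs, int[] assignment, double[][] centroids) {
        double total = 0.0;
        for (int n = 0; n < obs.length; n++) {
            double[] c = centroids[assignment[n]];
            double sum = 0.0;
            for (int i = 0; i < c.length; i++) {
                double d = obs[n][i] - c[i];
                sum += d * d;
            }
            total += Math.sqrt(sum);
        }
        return total;
    }

    /** Records the loss of the current epoch; returns true once it has been stable long enough. */
    boolean update(double newTotal) {
        double error = Math.abs(previousTotal - newTotal);
        convergeCounter = (error < convergeCriteria) ? convergeCounter + 1 : 0;
        previousTotal = newTotal;
        return convergeCounter >= minNumConvergeIters;
    }

    public static void main(String[] args) {
        ConvergenceSketch conv = new ConvergenceSketch(0.01, 2);
        for (double loss : new double[]{10.0, 5.0, 4.999, 4.9985}) {
            System.out.println(conv.update(loss)); // false, false, false, true
        }
    }
}
```

Requiring several consecutive stable epochs, rather than a single one, guards against stopping on an epoch where the loss stalls only temporarily.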

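private void initialize() {
    // params is assumed to hold the maximum of each variable followed by its minimum
   double[] params = getParameters();
   int numVariables = params.length>>1;
   double[] range = new double[numVariables];
   for( int k = 0, j = numVariables; k <numVariables; k++, j++ ) {
      range[k] = params[k] - params[j];
   }
    // Spread the initial centroids evenly along the diagonal of the data range
   int sz_1 = _clusters.length+1,  m = 1;
   for(KmeansCluster cluster : _clusters) {
      double[] x = new double[numVariables];   // one array per centroid
      for( int k = 0, j = numVariables; k <numVariables; k++, j++ ) {
         x[k] = ((range[k]/sz_1)*m) + params[j];
      }
      cluster.setCentroid(x);   // setCentroid is an assumed setter of KmeansCluster
      m++;
   }
}
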
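The initialization places the K centroids at evenly spaced points on the diagonal of the bounding box of the data: for each variable, centroid m (m = 1..K) is min + m·(max − min)/(K + 1). The sketch below computes the minima and maxima directly from the observations instead of calling getParameters, whose layout (maxima first, then minima) is inferred from the code; the class InitSketch is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch: initial centroids spread evenly along the diagonal of the data range. */
class InitSketch {

    static double[][] initCentroids(double[][] obs, int numClusters) {
        int numVariables = obs[0].length;
        double[] min = new double[numVariables];
        double[] max = new double[numVariables];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        // Bounding box of the observations, variable by variable
        for (double[] x : obs) {
            for (int k = 0; k < numVariables; k++) {
                min[k] = Math.min(min[k], x[k]);
                max[k] = Math.max(max[k], x[k]);
            }
        }
        // Centroid m sits at fraction m/(K+1) of the range, so none lies on the box corners
        double[][] centroids = new double[numClusters][numVariables];
        for (int m = 1; m <= numClusters; m++) {
            for (int k = 0; k < numVariables; k++) {
                double range = max[k] - min[k];
                centroids[m - 1][k] = (range / (numClusters + 1)) * m + min[k];
            }
        }
        return centroids;
    }

    public static void main(String[] args) {
        double[][] obs = {{0.0, 0.0}, {9.0, 3.0}, {4.0, 1.0}};
        System.out.println(Arrays.deepToString(initCentroids(obs, 2))); // [[3.0, 1.0], [6.0, 2.0]]
    }
}
```

This deterministic scheme is simple and reproducible; randomized seeding such as K-means++ is a common alternative when the clusters do not line up along the diagonal.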
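private void updateCentroids(Observation point,
                             KmeansCluster cluster,
                             KmeansCluster bestCluster) {
   // Move the observation only if a different, closer cluster was found
  boolean update = bestCluster != null && bestCluster != cluster;
  if( update ) {
     bestCluster.attach(point);   // attach/detach are assumed methods of KmeansCluster
     cluster.detach(point);
     for(KmeansCluster cursor : _clusters ) {
        cursor.computeCentroid();
     }
  }
}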
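When an observation moves, only the source and destination clusters change, so recomputing those two centroids is enough rather than looping over all the clusters. The hypothetical sketch below (class ReassignSketch, clusters stored as lists of double[]) shows that reassignment step in isolation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch: move one observation between clusters and refresh the affected centroids. */
class ReassignSketch {

    /** Mean of the points in a cluster, or null if the cluster is empty. */
    static double[] mean(List<double[]> points) {
        if (points.isEmpty()) return null;
        double[] m = new double[points.get(0).length];
        for (double[] p : points) {
            for (int i = 0; i < m.length; i++) m[i] += p[i];
        }
        for (int i = 0; i < m.length; i++) m[i] /= points.size();
        return m;
    }

    /** Moves point from cluster 'from' to cluster 'to'; does nothing if they are the same. */
    static void move(double[] point, int from, int to,
                     List<List<double[]>> members, double[][] centroids) {
        if (from == to) return;
        members.get(from).remove(point);    // arrays compare by identity: removes this exact point
        members.get(to).add(point);
        double[] c = mean(members.get(from));
        if (c != null) centroids[from] = c; // keep the old centroid if the source cluster is empty
        centroids[to] = mean(members.get(to));
    }

    public static void main(String[] args) {
        double[] p = {10.0, 0.0};
        List<double[]> c0 = new ArrayList<>();
        c0.add(new double[]{0.0, 0.0});
        c0.add(new double[]{2.0, 0.0});
        c0.add(p);
        List<double[]> c1 = new ArrayList<>();
        c1.add(new double[]{12.0, 0.0});
        List<List<double[]>> members = new ArrayList<>();
        members.add(c0);
        members.add(c1);
        double[][] centroids = {mean(c0), mean(c1)};
        move(p, 0, 1, members, centroids);
        System.out.println(Arrays.deepToString(centroids)); // [[1.0, 0.0], [11.0, 0.0]]
    }
}
```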


