Wednesday, October 31, 2012

K-means Clustering in Java II: Classification

Overview
The basic components of the implementation of the K-means clustering algorithm were introduced in the previous post, K-means Clustering in Java: Elements.

This second part on the implementation of K-means in Java describes the two main tasks of any unsupervised or supervised learning algorithm:
  • training: executed off-line during the analysis of historical data
  • classification: executed at run-time to classify new observations
Note: For the sake of readability, all non-essential code such as error checking, comments, exception handling, validation of class and method arguments, scoping qualifiers and import statements is omitted.
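
For reference, here is a minimal sketch of the member variables assumed by the methods in this post; it is inferred from the code below rather than taken from the original source. The KmeansCluster and Observation classes are described in the previous post.

public final class Kmeans {
  private KmeansCluster[] _clusters;    // the K clusters
  private Observation[] _obsList;       // historical (training) observations
  private Observation[] _centroids;     // centroids queried at run-time
  private int _maxIterations;           // maximum number of epochs
  private int _minNumConvergeIters;     // near-stationary epochs required to converge
  private double _convergeCriteria;     // threshold on the change of the loss function
  private float _totalDistance = 0.0F;  // current value of the loss function
  private int _convergeCounter = 0;     // consecutive near-stationary epochs

  // train, classify and the supporting methods are listed below
}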

Training
The learning method, train, implements the clustering algorithm. It iterates to minimize the sum of the distances between the data points of each cluster and its centroid.
At each iteration (or epoch), the train method:
  1. assigns the observations to the clusters
  2. computes the centroid of each cluster, computeCentroid
  3. computes the total distance of all the observations to their respective centroids, computeTotalDistance
  4. finds the closest cluster for each observation
  5. re-assigns the observation to that cluster, updateCentroids

public int train() {
  int numIters = _maxIterations, k = 0;
  boolean inProgress = true;

  initialize();
  // Initial round-robin assignment of the observations to the clusters
  while(inProgress) {
     for(KmeansCluster cluster : _clusters ) {
        cluster.attach(_obsList[k]);
        if( ++k >= _obsList.length) {
           inProgress = false;
           break;
        }
     }
  }

  // Compute the loss function before and after the centroids
  // of the initial clusters are computed
  computeTotalDistance();
  for(KmeansCluster cluster : _clusters ) {
     cluster.computeCentroid();
  }
  computeTotalDistance();

  List<Observation> obsList = null;
  KmeansCluster closestCluster = null;

  // Main iterative loop: traverses all the clusters, computes the
  // distance of each observation to the centroids and re-assigns
  // the observation to its closest cluster.
  for(int i = 0; i < _maxIterations; i++) {
    for(KmeansCluster cluster : _clusters ) {
      // Copy the observations so the cluster content can be
      // modified while the observations are traversed
      obsList = new ArrayList<Observation>();
      for( Observation point : cluster.getDataPointsList()) {
        obsList.add(point);
      }

      for( Observation point : obsList) {
        double minDistance = Double.MAX_VALUE, distance = 0.0;
        closestCluster = null;

        // Find the cluster with the closest centroid
        for(KmeansCluster cursor : _clusters ) {
          distance = point.computeDistance(cursor.getCentroid());
          if( minDistance > distance) {
            minDistance = distance;
            closestCluster = cursor;
          }
        }
        updateCentroids(point, cluster, closestCluster);
      }
    }
    // Simple convergence criterion
    if( _convergeCounter >= _minNumConvergeIters ) {
      numIters = i;
      break;
    }
  }
  return numIters;
}

Classification
The classification of a new observation is simple: it consists of evaluating the distance between the new data point and each centroid, and selecting the cluster with the shortest distance. The classify method returns the index (or label) of the cluster whose centroid is closest to the new observation.

public int classify(double[] obs) {
  double bestScore = Double.MAX_VALUE, distance = 0.0;
  int clusterId = -1;

  // Return the index of the cluster with the closest centroid
  for(int k = 0; k < _centroids.length; k++) {
     distance = _centroids[k].computeDistance(obs);
     if(distance < bestScore) {
        bestScore = distance;
        clusterId = k;
     }
  }
  return clusterId;
}
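
For illustration, a hypothetical invocation of train and classify could look as follows; the Kmeans constructor and the loadHistoricalData helper are assumptions, not part of the original code.

// Hypothetical usage sketch
Observation[] observations = loadHistoricalData();  // hypothetical loader
Kmeans kmeans = new Kmeans(4, observations);        // e.g. K = 4 clusters

int numIters = kmeans.train();     // off-line training
int clusterId = kmeans.classify(new double[] {0.7, 1.9, 6.5});  // run-time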

The code snippet below implements some of the supporting methods to:
  • compute the loss function value (total distance), computeTotalDistance
  • initialize the centroid of each cluster, initialize
  • update the values of the centroids, updateCentroids


private void computeTotalDistance() {
  float totalDistance = 0.0F;

  for(KmeansCluster cluster : _clusters ) {
     totalDistance += cluster.getSumDistances();
  }

  // Increment the convergence counter if the loss function does
  // not change significantly, reset it otherwise
  double error = Math.abs(_totalDistance - totalDistance);
  _convergeCounter = ( error < _convergeCriteria) ? _convergeCounter + 1 : 0;
  _totalDistance = totalDistance;
}

private void initialize() {
   // The first half of the parameters array is assumed to hold the
   // maximum value of each variable, the second half its minimum value
   double[] params = getParameters();
   int numVariables = params.length>>1;

   double[] range = new double[numVariables];
   for( int k = 0, j = numVariables; k < numVariables; k++, j++ ) {
      range[k] = params[k] - params[j];
   }

   int sz_1 = _clusters.length+1, m = 1;

   // Spread the initial centroids uniformly across the range of
   // each variable. A new array is allocated for each centroid so
   // the clusters do not share the same mutable state.
   for(KmeansCluster cluster : _clusters) {
      double[] x = new double[numVariables];
      for( int k = 0, j = numVariables; k < numVariables; k++, j++ ) {
         x[k] = ((range[k]/sz_1)*m) + params[j];
      }
      cluster.setCentroid(x);
      m++;
   }
}
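
Note: initialize seeds the centroids uniformly across the range of each variable. For instance, with 3 clusters and a variable ranging over [0, 8], the initial centroids for that variable are placed at 2, 4 and 6 (min + m*range/(K+1) for m = 1, 2, 3).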
 
private void updateCentroids(Observation point,
                             KmeansCluster cluster,
                             KmeansCluster bestCluster) {
  // Re-assign the observation only if a closer cluster exists
  boolean update = bestCluster != null && bestCluster != cluster;
  if( update ) {
     bestCluster.attach(point);
     cluster.detach(point);
     // Recompute the centroids and the loss function after the move
     for(KmeansCluster cursor : _clusters ) {
        cursor.computeCentroid();
     }
     computeTotalDistance();
  }
}
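
The distance metric itself is not shown in this post. A common choice is the Euclidean distance; here is a minimal sketch of Observation.computeDistance under that assumption, where the _x member holding the values of the observation is hypothetical.

// Sketch of Observation.computeDistance assuming the Euclidean
// distance; the _x member (values of the observation) is an
// assumption, not part of the original Observation class.
public double computeDistance(double[] centroid) {
  double sum = 0.0;
  for( int k = 0; k < _x.length; k++) {
    double diff = _x[k] - centroid[k];
    sum += diff*diff;
  }
  return Math.sqrt(sum);
}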



