Wednesday, October 28, 2015

Recursive Mean & Standard Deviation in Scala

Implementation of the recursive computation of the mean and standard deviation of a very large data set in Scala using tail recursion

Overview
Computing the mean and standard deviation of a very large data set by direct summation can overflow the running sums. A tail-recursive implementation in Scala is a good alternative: it updates the mean and standard deviation one data point at a time, whatever the size of the data set.

Direct computation
There are many ways to compute the standard deviation through summation. The first mathematical expression sums the squared differences between each data point and the mean.
\[\sigma =\sqrt{\frac{\sum_{i=0}^{n-1}(x_{i}-\mu )^{2}}{n}}\]
The second formula makes it possible to update the mean and standard deviation with each new data point (online computation).
\[\sigma =\sqrt{\frac{1}{n}\sum_{i=0}^{n-1}x_{i}^{2}-\mu ^{2}}\]
This second approach relies on the sum of squared values, which can overflow.

val x = Array[Double]( /* ... */ )
val mean = x.sum/x.length
val stdDev = Math.sqrt((x.map( _ - mean)
    .map(t => t*t).sum)/x.length)

In the computation of the standard deviation (the stdDev expression), the final .sum can be replaced with reduceLeft(_ + _), or the whole map{ ... }.sum chain with a single foldLeft.
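
As a rough sketch of the two direct formulas above, the first can be written with a foldLeft over the squared deviations and the second with a single pass that accumulates the sum and the sum of squares. The names DirectStats, stdDevTwoPass and stdDevOnePass are made up for this illustration and are not part of the original code.

```scala
object DirectStats {
  // First formula: the mean is computed first, then a foldLeft
  // accumulates the squared deviations (second traversal).
  // Divides by n, like the snippet above (population standard deviation).
  def stdDevTwoPass(x: Array[Double]): Double = {
    val mean = x.sum / x.length
    val sumSq = x.foldLeft(0.0)((acc, v) => acc + (v - mean) * (v - mean))
    math.sqrt(sumSq / x.length)
  }

  // Second formula: one traversal accumulating sum(x) and sum(x^2).
  // The sum of squares grows quickly, and subtracting mean^2 from it
  // can lose precision when the two terms are close.
  def stdDevOnePass(x: Array[Double]): Double = {
    val (sum, sumOfSquares) = x.foldLeft((0.0, 0.0)) {
      case ((s, sq), v) => (s + v, sq + v * v)
    }
    val mean = sum / x.length
    math.sqrt(sumOfSquares / x.length - mean * mean)
  }
}
```

Both functions should agree on small, well-behaved inputs; they differ in how they behave as the values or the number of points grow.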

Recursive computation
There is actually no need to compute the sum and the sum of squared values to compute the mean and standard deviation. The mean and standard deviation for n observations can be computed from the mean and standard deviation of n-1 observations.
The recursive formula for the mean is
\[\mu _{n}=\left ( 1-\frac{1}{n} \right )\mu _{n-1}+\frac{x_{n}}{n}\]
The recursive formula for the standard deviation is
\[\varrho _{n}=\varrho _{n-1}+(x_{n}-\mu _{n})(x_{n}-\mu _{n-1}) \qquad \sigma _{n} = \sqrt{\frac{\varrho _{n}}{n}}\]
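
For readers who want to see where the mean recursion comes from, here is a short derivation, numbering the observations from 1 to n so that x_n is the newest one (an indexing choice made for this note only):
\[\mu _{n}=\frac{1}{n}\sum_{i=1}^{n}x_{i}=\frac{1}{n}\left ( (n-1)\mu _{n-1}+x_{n} \right )=\left ( 1-\frac{1}{n} \right )\mu _{n-1}+\frac{x_{n}}{n}\]
As a small worked example with the two values 2 and 4, starting from \(\mu _{0}=0\) and \(\varrho _{0}=0\):
\[\mu _{1}=2 \qquad \varrho _{1}=0+(2-2)(2-0)=0\]
\[\mu _{2}=\frac{1}{2}\cdot 2+\frac{4}{2}=3 \qquad \varrho _{2}=0+(4-3)(4-2)=2 \qquad \sigma _{2}=\sqrt{\frac{2}{2}}=1\]
This matches the direct formula: the mean is 3, the squared deviations are 1 and 1, and their average is 1.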
Let's implement these two recursive formulas in Scala using tail recursion (the nested method annotated with @scala.annotation.tailrec).

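def meanStd(x: Array[Double]): (Double, Double) = {

  @scala.annotation.tailrec
  def meanStd(                // nested helper; the compiler checks that it is tail recursive
      x: Array[Double],
      mu: Double,             // running mean of the first `count` values
      Q: Double,              // running sum of squared deviations (rho in the formula)
      count: Int): (Double, Double) =
    if (count >= x.length) (mu, Math.sqrt(Q/x.length))  // exit: all values consumed (std is NaN for an empty array)
    else {
      val newCount = count + 1
      val newMu = x(count)/newCount + mu * (1.0 - 1.0/newCount)
      val newQ = Q + (x(count) - mu)*(x(count) - newMu)
      meanStd(x, newMu, newQ, newCount)
    }

  meanStd(x, 0.0, 0.0, 0)     // start with an empty running mean and sum
}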

This implementation updates the mean and the standard deviation simultaneously for each new data point. The recursion exits when all elements have been accounted for (the count >= x.length test).
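
The same update step can also be applied to values that are never held in an array, for instance an Iterator that produces them lazily. The following is only a sketch under that assumption; RunningStats, State, update and empty are hypothetical names introduced here, not part of the original post.

```scala
object RunningStats {
  // Accumulator: number of points seen so far, running mean,
  // and running sum of squared deviations (rho in the formula).
  final case class State(count: Long, mu: Double, q: Double) {
    def mean: Double = mu
    // Population standard deviation; undefined (NaN) before any point is seen.
    def stdDev: Double = if (count == 0) Double.NaN else math.sqrt(q / count)
  }

  val empty: State = State(0L, 0.0, 0.0)

  // One step of the recursive formulas for a new data point x.
  def update(s: State, x: Double): State = {
    val n = s.count + 1
    val newMu = x / n + s.mu * (1.0 - 1.0 / n)
    val newQ = s.q + (x - s.mu) * (x - newMu)
    State(n, newMu, newQ)
  }

  // Folds the update over the iterator; only the State is kept in memory.
  def meanStd(xs: Iterator[Double]): (Double, Double) = {
    val s = xs.foldLeft(empty)(update)
    (s.mean, s.stdDev)
  }
}
```

For example, RunningStats.meanStd(Iterator(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)) should return a mean close to 5.0 and a standard deviation close to 2.0, up to floating-point rounding. Keeping the state in a small case class also makes it possible to stop, store the State, and resume later with more data.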
