Wednesday, April 15, 2015

Recursive Viterbi Algorithm for Hidden Markov Model

Description and implementation of the Viterbi algorithm for hidden Markov models using Scala.

Overview
Many machine learning algorithms rely on dynamic programming, and the hidden Markov model (HMM) is no exception. Scala's tail-call elimination is a good alternative to the traditional iterative implementation of the 3 canonical forms of the HMM.

Introduction to HMM
Markov processes, and more specifically HMMs, are commonly used in speech recognition, language translation, text classification, document tagging, data compression, and decoding.
An HMM algorithm uses 3 key components:

  • A set of observations
  • A sequence of hidden states
  • A model that maximizes the joint probability of the observations and hidden states, known as the Lambda model
HMMs are usually visualized as a lattice of states with transitions and observations.
There are 3 use cases (or canonical forms) of the HMM.
  • Evaluation: Evaluate the probability of a given sequence of observations, given a model
  • Training: Identify (or learn) a model given a set of observations
  • Decoding: Estimate the state sequence with the highest probability of generating a given set of observations, given a model
The last use case, Decoding, is implemented numerically using the Viterbi algorithm.

Viterbi recursive implementation
Given a sequence of states {qt} and a sequence of observations {ot}, δt(i) is the highest probability of any single path of states that accounts for the first t observations and ends in state Si. The computation relies on three variables:
  • delta: matrix of the values δt(i), the highest probability of a single path over the first t observations that ends in state Si
  • psi: matrix of the state indices that maximize these probabilities at each step; it is used to backtrack the best path
  • QStar: the optimum sequence q* of states Q0:T-1
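For reference, the usual textbook form of the Viterbi recursion can be written as below. The 0-based time index and the symbols a_{ij}, b_j(o_t), and \pi_i (entries of the transition matrix, the emission matrix, and the initial state probabilities of the model) are notational assumptions chosen to line up with the code in this post.

```latex
\begin{aligned}
\delta_0(i) &= \pi_i \, b_i(o_0) \\
\delta_t(j) &= \max_{0 \le i < N} \big[ \delta_{t-1}(i) \, a_{ij} \big] \, b_j(o_t), \qquad 1 \le t \le T-1 \\
\psi_t(j) &= \operatorname*{arg\,max}_{0 \le i < N} \big[ \delta_{t-1}(i) \, a_{ij} \big], \qquad 1 \le t \le T-1 \\
q^*_{T-1} &= \operatorname*{arg\,max}_{0 \le i < N} \delta_{T-1}(i) \\
q^*_t &= \psi_{t+1}\big(q^*_{t+1}\big), \qquad t = T-2, \dots, 0
\end{aligned}
```

The first line is the initialization, the next two lines are the recursion over the observations, and the last two lines are the termination and the backtracking that produce q*.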
The state of the Viterbi/HMM computation for a given Lambda model (state transition and emission matrices) is defined in the HMMState class:

final class HMMState(
   val lambda: HMMLambda, 
   val maxIters: Int) {
  
  // Matrix of elements (t, i) that defines the highest 
  // probability of a single path of t 
  // observations reaching state S(i)
  val delta = Matrix[Double](lambda.getT, lambda.getN)
 
  // Auxiliary matrix of indices that maximize the probability
  // of a given sequence of states
  val psi = Matrix[Int](lambda.getT, lambda.getN)

  // Singleton to compute the sequence Q* of states with 
  // the highest probability given a sequence 
  // of observations.
  object QStar {
    private val qStar = Array.fill(lambda.getT)(0)

    // Update Q*, the optimum sequence of states, using backtracking
    def update(t: Int, index: Int): Unit
       ...
  }
}

The class HMMLambda contains the three components of the Markov process:
  • State transition matrix A
  • Observation emission matrix B
  • Initial state probabilities, pi

The algorithm is conveniently illustrated by the following diagram.

First, let's create the key members and methods of the Lambda model for the HMM. The model is defined as a tuple of the transition probability matrix A, the emission probability matrix B, and the initial state probabilities π.
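A minimal sketch of such a Lambda model could look like the following. The accessors getN and getT match the calls used elsewhere in this post; the field names a, b, pi, and numObs, the accessor getM, and the isValid check are assumptions made for illustration, not the original implementation.

```scala
// Hypothetical sketch of the Lambda model (A, B, pi).
// Field names and isValid are assumptions, not the original code.
final case class HMMLambda(
    a: Array[Array[Double]],   // N x N state transition probabilities
    b: Array[Array[Double]],   // N x M emission probabilities
    pi: Array[Double],         // N initial state probabilities
    numObs: Int) {             // T, number of observations in the sequence

  def getN: Int = pi.length      // number of hidden states
  def getM: Int = b.head.length  // number of observation symbols
  def getT: Int = numObs         // length of the observation sequence

  // Each row of A and B, and the vector pi, should sum to 1.
  def isValid(eps: Double = 1e-9): Boolean = {
    def sumsToOne(xs: Array[Double]): Boolean = math.abs(xs.sum - 1.0) < eps
    a.length == getN && b.length == getN &&
      a.forall(row => row.length == getN && sumsToOne(row)) &&
      b.forall(sumsToOne) && sumsToOne(pi)
  }
}
```

Checking that the rows are proper probability distributions up front tends to catch most model construction mistakes before the recursion runs.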

Then let's define the basic components for implementing the Viterbi algorithm. The class ViterbiPath encapsulates the member variables critical to the recursion.
The Viterbi algorithm is fully defined by:

  • lambda: Lambda model as described in the previous section (line 2)
  • state: State of the computation (line 3)
  • obsIndx: Index of observed states (line 4)
class ViterbiPath(
   lambda: HMMLambda, 
   state: HMMState, 
   obsIndx: Array[Int]) {

  val maxDelta = recurse(0)
   // ...
}

The recursive method, recurse, implements the steps defined earlier. The method relies on tail recursion: the Scala compiler rewrites a tail-recursive call as a loop, which avoids the creation of a new stack frame for each invocation and improves the performance of the entire recursion. The @tailrec annotation makes the compiler report an error if the call cannot be optimized.
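As a minimal illustration of tail recursion in Scala, independent of the HMM code, the sketch below sums the integers from 1 to n with an accumulator. The names TailRecDemo and sumTo are hypothetical.

```scala
import scala.annotation.tailrec

object TailRecDemo {
  // The recursive call is the last operation of the method, so the
  // compiler can turn it into a loop; @tailrec verifies that it does.
  @tailrec
  def sumTo(n: Long, acc: Long = 0L): Long =
    if (n == 0L) acc
    else sumTo(n - 1, acc + n)
}
```

Because no stack frame is added per call, a depth of one million (TailRecDemo.sumTo(1000000L)) runs without a stack overflow, which a non-tail-recursive version would risk.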

@scala.annotation.tailrec
private def recurse(t: Int): Double = {

  // Initialization of the delta values
  // for the first observation
  if( t == 0)
    initial

  // then for the subsequent observations ...
  else {
    // Update the maximum delta value and its state index
    // for the observation t
    Range(0, lambda.getN).foreach( updateMaxDelta(t, _) )
  }
  // If we reached the last observation, exit by
  // backtracking the optimum sequence of states
  if( t == obsIndx.size-1) {
    val idxMaxDelta = Range(0, lambda.getN).map(i =>
             (i, state.delta(t, i))).maxBy(_._2)

    // Update the Q* value with the index that maximizes
    // the delta value
    state.QStar.update(t+1, idxMaxDelta._1)
    idxMaxDelta._2
  }
  else
    recurse(t+1)
}

After the initialization at t = 0 (line 7), the maximum value of delta for each subsequent observation is computed by the method updateMaxDelta (line 13). Once the step t reaches the last index in the sequence of observations (line 17), the pair idxMaxDelta is computed (lines 18, 19): it holds the index of the state with the highest value of delta and that value. The last step is to update the optimum sequence of states q*, state.QStar, by backtracking (line 23).
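To make these steps concrete, here is a self-contained sketch of the same tail-recursive scheme using plain arrays. It is not the original ViterbiPath class: the object ViterbiSketch, the decode signature, and the backtrack helper are assumptions, and it multiplies raw probabilities rather than summing log probabilities, so it may underflow on long sequences.

```scala
import scala.annotation.tailrec

// Hypothetical, simplified stand-in for ViterbiPath/HMMState.
object ViterbiSketch {

  // a: N x N transitions, b: N x M emissions, pi: initial probabilities,
  // obs: indices of the observed symbols.
  // Returns (probability of the best path, best state sequence).
  def decode(
      a: Array[Array[Double]],
      b: Array[Array[Double]],
      pi: Array[Double],
      obs: Array[Int]): (Double, Vector[Int]) = {
    require(obs.nonEmpty, "at least one observation is required")
    val n = pi.length
    val delta = Array.ofDim[Double](obs.length, n)
    val psi = Array.ofDim[Int](obs.length, n)

    @tailrec
    def recurse(t: Int): Double = {
      if (t == 0) {
        // Initialization: delta(0)(i) = pi(i) * b(i)(obs(0))
        for (i <- 0 until n) delta(0)(i) = pi(i) * b(i)(obs(0))
      } else {
        // Recursion: best predecessor i of each state j at step t
        for (j <- 0 until n) {
          val (iMax, pMax) =
            (0 until n).map(i => (i, delta(t - 1)(i) * a(i)(j))).maxBy(_._2)
          psi(t)(j) = iMax
          delta(t)(j) = pMax * b(j)(obs(t))
        }
      }
      // Termination on the last observation, otherwise move to t + 1
      if (t == obs.length - 1) delta(t).max
      else recurse(t + 1)
    }

    // Walk psi backward: path.head is the state at step t,
    // psi(t)(path.head) is the best state at step t - 1.
    @tailrec
    def backtrack(t: Int, path: List[Int]): Vector[Int] =
      if (t == 0) path.toVector
      else backtrack(t - 1, psi(t)(path.head) :: path)

    val maxDelta = recurse(0)
    val last = delta(obs.length - 1).indexOf(maxDelta)
    (maxDelta, backtrack(obs.length - 1, List(last)))
  }
}
```

For example, with pi = (0.6, 0.4), A = ((0.7, 0.3), (0.4, 0.6)), B = ((0.5, 0.4, 0.1), (0.1, 0.3, 0.6)), and the observations (0, 1, 2), decode returns the state sequence (0, 0, 1) with a probability of about 0.01512.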

