Friday, October 27, 2017

Reinforcement learning in Scala

You may wonder how robots, autonomous systems, or software game players learn. The answer lies in a field of AI known as reinforcement learning. For example, a robot navigating a maze plans its next move according to its current location and previous moves. Teaching a robot every possible move for every location in the maze is not realistic, making supervised learning techniques inadequate. This article describes a very common reinforcement learning methodology, Q-learning, and its implementation in Scala.

Overview
There are many different reinforcement learning techniques. One of the most commonly used methods is searching the value function space using the temporal difference method.
All reinforcement learning methods share the same objective: solving sequential decision tasks optimally. In a sequential decision task, an agent interacts with a dynamic system by selecting actions that affect the transitions between states, in order to optimize a given reward function.

At any given step i, the agent selects an action a(i) on the current state s(i). The dynamic system responds by rewarding the agent for its optimal selection of the next state:\[s_{i+1}=V(s_{i})\]
The learning agent infers the policy that maps the set of states {s} to the set of available actions {a}, using a value function \[V(s_{i})\] The policy is defined as \[\pi :\,\{s_{i}\} \mapsto \{a_{i}\} \left \{ s_{i}|s_{i+1}=V(s_{i}) \right \}\]


Temporal Difference
The most common approach to learning a value function V is the Temporal Difference method (TD). The method uses observations of prediction differences from consecutive states, s(i) and s(i+1). If we denote by r the reward for selecting the action that transitions from state s(i) to s(i+1), and by η the learning rate, then the value V is updated as \[V(s_{i})\leftarrow V(s_{i})+\eta .(V(s_{i+1}) -V(s_{i}) + r_{i})\]
The goal of the temporal difference method is therefore to learn the value function for the optimal policy. The 'action-value' function represents the expected value of action a on state s and is defined as \[Q(s_{i},a_{i}) = r(s_{i}) + V(s_{i})\] where r is the reward value for the state.

On-Policy vs. Off-Policy
The Temporal Difference method relies on the estimate of the final reward being computed for each state. There are two flavors of the Temporal Difference algorithm: On-Policy and Off-Policy:
  - The On-Policy method learns the value of the policy used to make the decisions. The value function is derived from the execution of actions using the same policy, but based on history.
  - The Off-Policy method learns potentially different policies. Therefore the estimate is computed using actions that have not been executed yet.
The most common formula for the temporal difference approach is the Q-learning formula. It introduces a discount rate to reduce the impact of the first few states on the optimization of the policy, and it does not need a model of its environment. The exploitation of the action-value approach consists of selecting the next state by computing the action with the maximum reward, while the exploration approach focuses on the total anticipated reward. The update equation for Q-learning is \[Q(s_{i},a_{i}) \leftarrow Q(s_{i},a_{i}) + \eta .(r_{i+1} +\alpha .max_{a_{i+1}}Q(s_{i+1},a_{i+1}) - Q(s_{i},a_{i}))\] \[Q(s_{i},a_{i}): \mathrm{expected\,value\,of\,action\,a\,on\,state\,s}\,\,\eta : \mathrm{learning\,rate}\,\,\alpha : \mathrm{discount\,rate}\]
One of the most commonly used On-Policy methods is Sarsa, which does not necessarily select the action that offers the most value. Its update equation is defined as \[Q(s_{i},a_{i}) \leftarrow Q(s_{i},a_{i}) + \eta .(r_{i+1} +\alpha .Q(s_{i+1},a_{i+1}) - Q(s_{i},a_{i}))\]
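As an illustration, both update rules can be sketched as small Scala functions operating on an action-value matrix indexed by (state, next state), which is how Q-values are stored later in this post. The names q, eta and gamma below are placeholders for this sketch, not part of the implementation that follows.

def qLearningUpdate(
    q: Array[Array[Double]],   // action-value matrix Q(state, next state)
    from: Int,                  // current state
    to: Int,                    // next state (identifies the action)
    reward: Double,
    eta: Double,                // learning rate
    gamma: Double): Unit = {    // discount rate
  // Off-policy: bootstrap on the best action available from the next state
  val maxNext = q(to).max
  q(from)(to) += eta*(reward + gamma*maxNext - q(from)(to))
}

def sarsaUpdate(
    q: Array[Array[Double]],
    from: Int,
    to: Int,
    next: Int,                  // state actually selected from 'to' by the same policy
    reward: Double,
    eta: Double,
    gamma: Double): Unit =
  // On-policy: bootstrap on the action actually executed from the next state
  q(from)(to) += eta*(reward + gamma*q(to)(next) - q(from)(to))
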
States and Actions
Functional languages are particularly well suited for this kind of iterative computation. We use Scala to implement the temporal difference algorithm and allow the user to specify any variant of the learning formula, using local functions or closures.
First, we define a state class, QLState (line 1), that contains the list of actions of type QLAction (line 3) that can be executed from this state. The only purpose of this class is to connect a list of actions to a source state. The parameterized class argument property (line 4) is used to "attach" extra characteristics to this state.

class QLState[T](
  val id: Int, 
  val actions: List[QLAction[T]] = List.empty, 
  property: T) {
    
  @inline
  def isGoal: Boolean = actions.isEmpty  // a goal state has no outgoing action
}

As described in the introduction, an action of class QLAction has a source state from and a destination state to (the state reached by following the action). A state, except for a goal state, has one or more actions, but an action has only one destination (or resulting) state.

case class QLAction[T](from: Int, to: Int)

The states and actions are loaded, generated and managed by a directed graph, or search space, of type QLSpace. The search space contains the list of all the possible states available to the agent.
One or more of these states can be selected as goals. The algorithm does not restrict the agent to a single goal state. The process ends when any one of the goal states is reached (OR logic). The algorithm does not support combined goals (AND logic).

Let's implement the basic components of the search space QLSpace. The class lists all available states (line 2) and one or more final or goal states goalIds (line 3). Although you might expect the search space to contain a single final or goal state, it is not uncommon for online training to use more than one goal state.

class QLSpace[T](
   states: Array[QLState[T]], 
   goalIds: Array[Int]) {

    // Indexed map of states 
  val statesMap: immutable.Map[Int, QLState[T]] = 
    states.map(st => (st.id, st)).toMap
    // Set of one or more goals  
  val goalStates = new immutable.HashSet[Int]() ++ goalIds
 
    // Compute the maximum Q value for a given state and policy
  def maxQ(st: QLState[T], policy: QLPolicy[T]): Double = { 
    val best = states.filter( _ != st)
       .maxBy(_st => policy.EQ(st.id, _st.id))
    policy.EQ(st.id, best.id)
  }
 
    // Retrieve the list of destination states of the state, st
  def nextStates(st: QLState[T]): List[QLState[T]] =
     st.actions.map(ac => statesMap.get(ac.to).get)
 
  def init(r: Random): QLState[T] = 
    states(r.nextInt(states.size-1))

    // Test whether a state is one of the goal states
  def isGoal(st: QLState[T]): Boolean = goalStates.contains(st.id)
}

A hash map statesMap maintains a dictionary of all the possible states, with the state id as unique key (lines 6, 7). The class QLSpace has three important methods:
  • init initializes the search with a random state for each training epoch (lines 22, 23)
  • nextStates returns the list of destination states associated with the state st (lines 19, 20)
  • maxQ returns the maximum Q-value for the state st given the current policy policy (lines 12-15). The method filters the state itself out of the search for the next best action, then computes the maximum reward, or Q(state, action) value, according to the given policy

The next step is to define a policy.

Learning Policy
A policy is defined by three components:
  • A reward, collected after transitioning from one state to another (line 2). The reward is provided by the user
  • A Q(State, Action) value associated with a transition state and an action (line 4)
  • A probability (with a default value of 1.0) that defines the obstacles or hindrance to migrating from one state to another (line 3)
The estimate combines the Q-value (incentive to move to the best next state) and the probability (hindrance to moving to any particular state) (line 7).

class QLData {
  var reward: Double = 1.0
  var probability: Double = 1.0
  var value: Double = 0.0
  
  @inline
  final def estimate: Double = value*probability
}

The policy of type QLPolicy is a container for the state transition attributes: rewards, Q-values and probabilities.

class QLPolicy[T](numStates: Int, input: Array[QLInput]) {
 
  val qlData = {
    val data = Array.tabulate(numStates)(
      _ => Array.fill(numStates)(new QLData)
    )
 
    input.foreach(in => {  
      data(in.from)(in.to).reward = in.reward
      data(in.from)(in.to).probability = in.prob
    })
    data
  }
  
  def setQ(from: Int, to: Int, value: Double): Unit =
     qlData(from)(to).value = value
 
  def Q(from: Int, to: Int): Double = qlData(from)(to).value

    // Reward for the transition from -> to, used to select the next state
  def R(from: Int, to: Int): Double = qlData(from)(to).reward

    // Estimate (Q-value weighted by the transition probability) used by QLSpace.maxQ
  def EQ(from: Int, to: Int): Double = qlData(from)(to).estimate
}

The constructor for QLPolicy takes two arguments:
  • The number of states numStates (line 1)
  • A sequence of inputs of type QLInput to the policy
The constructor creates a numStates x numStates matrix of transitions of type QLData (lines 3 - 12) from the input.

The type QLInput wraps the input data (index of the input state from, index of the output state to, and the reward and probability associated with the state transition) into a single convenient class.

case class QLInput(
   from: Int, 
   to: Int, 
   reward: Double = 1.0, 
   prob: Double = 1.0)
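
As a hypothetical example (the state indices, rewards and probabilities below are arbitrary, not taken from the post), a policy for a small three-state problem could be initialized as:

// Arbitrary rewards and probabilities for a three-state example
val input = Array[QLInput](
  QLInput(0, 1, reward = 0.8),
  QLInput(0, 2, reward = 0.2),
  QLInput(1, 2, reward = 1.0, prob = 0.5),
  QLInput(2, 1, reward = 0.4)
)
val policy = new QLPolicy[Int](3, input)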


Model and Training
The first step is to define a model for the reinforcement learning. A model is created during training and is composed of:
  • The best policy to transition from any initial state to a goal state
  • The coverage ratio, defined as the percentage of training cycles that reach the (or one of the) goal(s)

class QLModel[T](val bestPolicy: QLPolicy[T], val coverage: Double)

The QLearning class takes 3 arguments
  • A set of configuration parameters config
  • The search/states space qlSpace
  • The initial policy associated with the states (reward and probabilities) qlPolicy

class QLearning[T](
   config: QLConfig, 
   qlSpace: QLSpace[T], 
   qlPolicy: QLPolicy[T]) {

    //model generated by the Q-learning algorithm
  val model: Option[QLModel[T]] = train.toOption
    
    // Generate a model through multi-epoch training
  def train: Try[QLModel[T]] = { /* ... */ }
  private def train(r: Random): Boolean = { /* ... */ }

   // Predict a state as a destination of this current 
   // state, given a model
  def predict: PartialFunction[QLState[T], QLState[T]] = { /* ... */ }

  // Select next state and action index
  def nextState(st: (QLState[T], Int)): (QLState[T], Int) = { /* ... */ } 
}

The model of type Option[QLModel] (line 7) is created by the method train (line 10). Its value is None if training failed.

The training method train consists of executing config.numEpisodes cycles, or episodes, each being a sequence of state transitions (line 5). The random generator r is used in the initialization of the search space.

def train: Try[QLModel[T]] = {
  val r = new Random(System.currentTimeMillis)

  Try {
    val completions = Range(0, config.numEpisodes).filter(_ => train(r))

    val coverage = completions.size.toDouble/config.numEpisodes
    if(coverage > config.minCoverage) 
       new QLModel[T](qlPolicy, coverage)
    else 
       throw new IllegalStateException(s"Coverage $coverage below ${config.minCoverage}")
  }
}

The training process exits with a model if the minimum coverage minCoverage (the ratio of episodes for which a goal state is reached) is met (line 8).

The method train(r: scala.util.Random) uses tail recursion to transition from the initial random state to one of the goal states. The tail recursion is implemented by the search method (line 4), which applies the temporal difference (Q-learning) update formula (lines 14-18).
For each new state transition, the state whose action generates the highest reward R under the policy qlPolicy is selected (line 10). The Q-value of the current policy is then updated with qlPolicy.setQ before the process is repeated for the next state through recursion (line 21).

def train(r: Random): Boolean = {
   
  @scala.annotation.tailrec
  def search(st: (QLState[T], Int)): (QLState[T], Int) = {
    val states = qlSpace.nextStates(st._1)
    if( states.isEmpty || st._2 >= config.episodeLength ) 
        (st._1, -1)
    
    else {
      val state = states.maxBy(s => qlPolicy.R(st._1.id, s.id))
      if( qlSpace.isGoal(state) )
          (state, st._2)
      else {
        val r = qlPolicy.R(st._1.id, state.id)   
        val q = qlPolicy.Q(st._1.id, state.id)
        // Q-Learning formula
        val deltaQ = r + config.gamma*qlSpace.maxQ(state, qlPolicy) -q
        val nq = q + config.alpha*deltaQ
        
        qlPolicy.setQ(st._1.id, state.id,  nq)
        search((state, st._2+1))
       }
     }
  } 
   
  r.setSeed(System.currentTimeMillis*Random.nextInt)

  val finalState = search((qlSpace.init(r), 0))
  if( finalState._2 != -1) 
    qlSpace.isGoal(finalState._1) 
  else 
    false
}

Note: There is no guarantee that one of the goal states can be reached from any randomly chosen initial state. It is expected that some of the training epochs fail. This is the reason why monitoring the coverage is critical. Obviously, you may choose a deterministic approach to the initialization of each training epoch by picking any state besides the goal state(s) as a starting state.

Prediction
Once trained, the model is used to predict the next state with the highest value (or probability) given an existing state. The prediction is implemented as a partial function.

def predict: PartialFunction[QLState[T], QLState[T]] = {
  case state: QLState[T] if model.isDefined => 
    if( state.isGoal) state else nextState((state, 0))._1
}

The method nextState does the heavy lifting. It retrieves the list of states associated with the current state st through its action set (line 2). The next, most rewarding state qState is computed using the reward matrix R of the best policy of the QLearning model (lines 6 - 8).

def nextState(st: (QLState[T], Int)): (QLState[T], Int) =  {
  val states = qlSpace.nextStates(st._1)
  if( states.isEmpty || st._2 >= config.episodeLength) 
    st
  else {
    val qState = states.maxBy(
     s => model.map(_.bestPolicy.R(st._1.id, s.id))
           .getOrElse(-1.0)
    )

    nextState( (qState, st._2+1))
  }
}
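
The configuration class QLConfig is not listed in this post. Assuming it simply carries the parameters referenced above (alpha, gamma, episodeLength, numEpisodes, minCoverage), a complete, minimal usage sketch with an arbitrary four-state search space and a single goal state could look as follows; the states, actions and rewards are illustrative only.

// Assumed configuration container; the field names match their usage in QLearning
case class QLConfig(
   alpha: Double,        // learning rate
   gamma: Double,        // discount rate
   episodeLength: Int,   // maximum number of transitions per episode
   numEpisodes: Int,     // number of training episodes
   minCoverage: Double)  // minimum ratio of episodes reaching a goal

// Hypothetical 4-state space with state 3 as the only goal (no outgoing action)
val actions = Array(
  List(QLAction[Int](0, 1), QLAction[Int](0, 2)),
  List(QLAction[Int](1, 3)),
  List(QLAction[Int](2, 3)),
  List.empty[QLAction[Int]]
)
val states = Array.tabulate(4)(n => new QLState[Int](n, actions(n), n))
val qlSpace = new QLSpace[Int](states, Array(3))

val input = Array(QLInput(0, 1, 0.5), QLInput(0, 2, 0.2),
  QLInput(1, 3, 1.0), QLInput(2, 3, 1.0))
val qlPolicy = new QLPolicy[Int](4, input)

val config = QLConfig(alpha = 0.6, gamma = 0.3, episodeLength = 35,
  numEpisodes = 80, minCoverage = 0.1)
val qLearning = new QLearning[Int](config, qlSpace, qlPolicy)

// predict is a partial function: lift it to get an Option back
val nextFromStart = qLearning.predict.lift(states(0))
println(nextFromStart.map(_.id))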

Conclusion
An article or blog post cannot realistically describe all the elements and strategies of reinforcement learning, from K-armed bandits to deep learning. However, this post should provide you with a road map for implementing a simple reinforcement learning algorithm in Scala.

Sunday, June 11, 2017

Scala Immutability & Covariance

This post illustrates the concept of immutable and covariant containers/collections in Scala in the case of the stack data structure.

Overview
There is a relation between immutability and covariance which may not be apparent to a novice Scala programmer. Let's consider the case of a mutable and an immutable implementation of a stack. The mutable stack is a container of elements with methods to push an element onto, and pop the last element from, the stack.


import scala.collection.mutable.ListBuffer

class MutableStack[T]  {
  private[this] val _stack = new ListBuffer[T]
  
  final def pop: Option[T] = 
    if(_stack.isEmpty) 
      None 
    else 
      Some(_stack.remove(_stack.size-1))
  
   def push(t: T): Unit = _stack.append(t)
}

The internal container is defined as a ListBuffer instance. The elements are appended to the list buffer (push), and the method pop removes and returns the last element pushed onto the stack.
This implementation has a major inconvenience: the stack is not covariant, because ListBuffer is an invariant collection, so a stack of a given type cannot be used where a stack of one of its supertypes is expected. Let's consider an immutable stack instead.
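
A quick sketch of the consequence, assuming both stack classes defined in this post are in scope: the invariant MutableStack[String] cannot be assigned to a MutableStack[Any], while the covariant ImmutableStack introduced below can.

val strings = new MutableStack[String]
// Does not compile: MutableStack is invariant in its type parameter T
// val anyMutable: MutableStack[Any] = strings

// Compiles: ImmutableStack is declared covariant (+T)
val anyImmutable: ImmutableStack[Any] = new ImmutableStack[String]("Scala")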

Immutability, covariance and tail recursion
A covariant immutable stack cannot rely on a mutable internal container; its elements have to be carried by the stack itself. This is accomplished by breaking the stack down recursively into the last element pushed onto the stack and the previous state of the stack.


class ImmutableStack[+T](
   val t: T, 
   val stack: Option[ImmutableStack[T]]) {

  def this(t: T) = this(t, Some(new EmptyImmutableStack(t)))
  ...
}

In this recursive approach, the immutable stack is initialized with a single element of type T and an option of the previous immutable stack. The stack can be covariant because its elements are managed by the stack itself.
The next step is to define the initial state of the stack. We could have chosen a singleton empty stack with no elements. Instead, we define the first state of the immutable stack as:


class EmptyImmutableStack[+T](t: T) 
   extends ImmutableStack[T](t, None)

Next, let's define the pop and push operators for ImmutableStack. The pop method returns the previous state of the immutable stack, that is, the stack prior to the last element pushed. The push method uses a lower type bound, as it accepts an element of any supertype U of T. The existing state of this stack is passed as the previous state (second argument).


final def pop: Option[ImmutableStack[T]] = 
  stack.map(sk => new ImmutableStack[T](sk.t, sk.stack))

def push[U >: T](u: U): ImmutableStack[U] = 
  new ImmutableStack[U](u, Some(this))

The next step is to traverse the entire stack and return the list of all its elements. This is accomplished through tail recursion on the state of the stack.


def popAll[U >: T]: List[U] = pop(this,List[U]())
 
@scala.annotation.tailrec
private def pop[U >: T](
  _stck: ImmutableStack[U], 
  xs: List[U]
): List[U] = _stck match { 
  
  case st: EmptyImmutableStack[T] => xs.reverse
  case st: ImmutableStack[T] => {
    val newStack = _stck.stack.getOrElse(
      new EmptyImmutableStack[T](_stck.t)
    )
    pop(newStack, _stck.t :: xs)
  }
}

The recursive call pop (line 4) updates the list xs (line 6) and exits when the stack is empty, i.e. of type EmptyImmutableStack (line 9). The list has to be reversed so that its elements are ordered from the last pushed to the first pushed (line 9). As long as the stack is not empty (of type ImmutableStack), the method recurses (line 14).

It is time to test drive this immutable stack.


val intStack = new ImmutableStack[Int](4)
val newStack = intStack.push(56).push(14).push(77)
 
println(newStack.popAll.mkString(", "))

The values in the stack are: 77, 14, 56, 4.
This example illustrates the concept of an immutable, covariant stack: the instance of the stack is its state (the current list of elements it contains).



Monday, May 22, 2017

Normalized Discounted Cumulative Gain in Scala

This post illustrates the Normalized Discounted Cumulative Gain (NDCG) and its implementation in Scala.
Numerous real-life applications of machine learning require predicting the most relevant ranking of items to optimize an outcome. For instance:

  • Evaluate and prioritize counter-measures to cyber-attacks
  • Rank symptoms in a clinical trial
  • Extract the documents relevant to a given topic from a corpus
The Discounted Cumulative Gain (DCG) and its normalized counterpart, the Normalized Discounted Cumulative Gain (NDCG), are metrics originally applied to textual information retrieval and later extended to other domains.

Discounted Cumulative Gain
Let's dive into the mathematical formalism for the Discounted Cumulative Gain. 

For a list of indexed target (relevance) values tj, the discounted cumulative gain is computed as
\[DCG=\sum_{j=1}^{n}\frac{2^{t_{j}}-1}{log_{2}(j+1)}\] The objective is to compare any given list of ranked/sorted items with a benchmark that represents the optimum ranking (the ranking with the highest DCG value). The normalized score of a ranking is
\[p(ranking|IDCG)=\frac{log(2)}{IDCG}\sum_{j=1}^{n}\frac{2^{t_{j}}-1}{log(j+1)}\]
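
For example, a list of five items whose target values, sorted in decreasing order, are t = (5, 4, 3, 2, 1) yields \[DCG=\frac{2^{5}-1}{log_{2}2}+\frac{2^{4}-1}{log_{2}3}+\frac{2^{3}-1}{log_{2}4}+\frac{2^{2}-1}{log_{2}5}+\frac{2^{1}-1}{log_{2}6}\simeq 45.64\] which is also the ideal DCG of the sample used in the Scala example below.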

Scala implementation
The implementation of the NDCG computation in Scala is quite simple, indeed. Given a ranked list of items, the three steps are:
  • Compute the IDCG (the normalization factor) from the list
  • Compute the DCG for the list
  • Compute the NDCG = DCG/IDCG
First, let's consider a list of items of type T to rank. The ranking of an item of type T is provided as an implicit function T => Int. The constructor for NDCG has a single argument: the initial sample to rank:
class NDCG[T](
   initialSample: Seq[T])
   (implicit ranking: T => Int) {
  import NDCG._

  val iDCG: Double = normalize

  def score: Double = score(initialSample)

  private def dcg(samples: Seq[T]): Double =
    samples.zipWithIndex.aggregate(0.0)(
      (s, samplej) => s + compute(samplej._2 + 1, ranking(samplej._1))
      , _ + _)


  private def normalize: Double = {
    val sorted = initialSample.zipWithIndex.sortBy{
      case (sample, n) => -ranking(sample)
    }.map( _._1)
    dcg(sorted)
  }
}

The Ideal Discounted Cumulative Gain, iDCG, is computed through the normalize method (line 6). iDCG (the normalization factor) is computed by first sorting the items of type T by their ranking in decreasing order (line 16), then scoring this re-sorted list using the dcg method (line 17).
The computation of the Discounted Cumulative Gain by the method dcg (line 10) is a direct application of the formula described in the previous section.
Note: The logarithm function uses a base 2. It is computed as natural log(x)/natural log (2)
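
The helper compute, imported from the companion object NDCG (line 4), is not listed in this post. A minimal sketch, consistent with the DCG formula and the note above, could be:

object NDCG {
  private val log2 = Math.log(2.0)

  // j-th term of the DCG sum: (2^rank - 1)/log2(j+1)
  def compute(j: Int, rank: Int): Double =
    (Math.pow(2.0, rank) - 1.0)*log2/Math.log(j + 1.0)
}
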
Let's now consider a list of items of type Item defined as follows:
case class Item(id: String, x: Double, rank: Int)

The list of items, itemsList is implicitly ranked through the last attribute, rank.

val itemsList = Seq[Item](
  Item("1", 3.4, 4), Item("2", 1.4, 3),
  Item("3", -0.7, 5), Item("4", 5.2, 2), 
  Item("5", 1.4, 1))

implicit val ranking = (item: Item) => item.rank

It is time to compute the NDCG coefficient for the list of items, by invoking the score method.

val nDCG = new NDCG[Item](itemsList)

println(s"IDCG = ${nDCG.iDCG}")    //45.64
println(s"Score = ${nDCG.score}")  // 0.801

The ideal discounted cumulative gain, iDCG, is 45.64: it is the DCG of the optimum ranking for this list of items. The initial sample scores 0.801.
Note
The DCG of subsequent samples can be computed using the same iDCG value from the same instance of NDCG.

def score(samples: Seq[T]): Double =
  if( samples.size != initialSample.size) 0.0 
  else dcg(samples)/iDCG
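
For instance, any permutation of the same five items can be scored against the iDCG already computed by the instance created above:

// Score another ordering of the same items against the same iDCG
val reordered = itemsList.reverse
println(s"Score = ${nDCG.score(reordered)}")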


Sunday, April 2, 2017

Recursive Minimum Spanning Tree in Scala

Description and implementation of the computation of the minimum spanning tree using Prim's algorithm and Scala tail recursion.

Overview
Finding the optimum arrangement to connect nodes is a common problem in network design, transportation projects and electrical wiring. Each connection is assigned a weight (cost, length, time...). The purpose is to compute the layout that connects all the nodes while minimizing the total weight. This problem, known as the minimum spanning tree (MST), applies to nodes connected through an undirected graph.
Several algorithms have been developed over the last 70 years to extract the MST from a graph. This post focuses on the implementation of Prim's algorithm in Scala.

Prim's algorithm
There are many excellent tutorials on graph algorithms and, more specifically, on Prim's algorithm. I recommend Lecture 7: Minimum Spanning Trees and Prim's Algorithm by Dekai Wu, Department of Computer Science and Engineering, The Hong Kong University of Science & Technology.
Let PQ be a priority queue and G(V, E) a graph with n vertices V and edges E with weights w(u,v). A vertex v is defined by:

  • An identifier
  • A load factor, load(v)
  • A parent tree(v)
  • The adjacent vertices adj(v)
Prim's algorithm can be expressed as a simple iterative process. It maintains a priority queue of all the vertices in the graph and updates their load to select the next node to add to the spanning tree. Each node is popped (and removed) from the priority queue before being inserted in the tree.
PQ <- V(G)
foreach u in PQ
  load(u) <- INFINITY
load(root) <- 0
 
while PQ nonEmpty
  u <- pop(PQ)
  foreach v in adj(u)
    if v in PQ && w(u,v) < load(v)
    then
      tree(v) <- u
      load(v) <- w(u,v)
The Scala implementation relies on tail recursion to transfer vertices from the priority queue to the spanning tree.

Scala implementation: graph definition
The first step is to define a graph structure with edges and vertices. The graph takes two arguments:
  • numVertices: the number of vertices
  • start: the index of the root of the minimum spanning tree
The Vertex class has three attributes:
  • id: an identifier (arbitrarily an integer)
  • load: the dynamic load (or key) on the vertex
  • tree: a reference to the minimum spanning tree
final class Graph(numVertices: Int, start: Int = 0) {
 
  class Vertex(val id: Int, 
     var load: Int = Int.MaxValue, 
     var tree: Int = -1) 

  val vertices = List.tabulate(numVertices)(new Vertex(_))
  vertices(start).load = 0
  val edges = new HashMap[Vertex, HashMap[Vertex, Int]]

  def += (from: Int, to: Int, weight: Int): Unit = {
    val fromV = vertices(from)
    val toV = vertices(to)
    connect(fromV, toV, weight)
    connect(toV, fromV, weight)
  }
  def connect(from: Vertex, to: Vertex, weight: Int): Unit = {
    if( !edges.contains(from))
      edges.put(from, new HashMap[Vertex, Int])    
    edges.get(from).get.put(to, weight)
  }   
  // ...
}

The vertices are initialized with a unique identifier id, a default load Int.MaxValue and a default tree index of -1, meaning not yet in the tree (lines 3-5). The Vertex class resides within the scope of the outer class Graph to avoid naming conflicts. The vertices are managed through a linked list (line 7), while the edges are defined as a hash map with a map of destination vertices and weights as value (line 9). The operator += adds a new edge between two existing vertices with a specified weight (line 11).
In most cases, the identifier is a character string or a data structure. As described in the pseudo-code, the load of the root of the spanning tree is set to 0.

The load is defined as an integer for performance's sake. It is recommended to quantize floating point weights to integers when processing very large graphs, then convert the loads back to the original format on the resulting minimum spanning tree.
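
A minimal sketch of such a quantization, assuming a fixed precision of three decimals for the original weights:

// Convert a floating point weight to an integer load and back (3-decimal precision)
def quantize(weight: Double, scale: Double = 1e+3): Int = (weight*scale).round.toInt
def dequantize(load: Int, scale: Double = 1e+3): Double = load/scale
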
The edges are defined as a hash table with the source vertex as key and, as value, a hash table of destination vertices and edge weights.

The graph is undirected, therefore the connections initialized in the method += are bidirectional.

Scala implementation: priority queue
The priority queue is used to reorder the vertices and select the next vertex to be added to the spanning tree.
Note: There are many different implementations of priority queues in Scala and Java. Keep in mind that Prim's algorithm requires the queue to be reordered after a load is updated (see the pseudo-code). The PriorityQueue classes in the Scala and Java libraries do not allow elements to be explicitly re-ordered in place. An alternative is to use a binary tree or a red-black tree, from which elements can be removed and the tree re-balanced.
The implementation of the priority queue has an impact on the time complexity of the algorithm. The following implementation of the priority queue is provided only to illustrate Prim's algorithm.

class PQueue(vertices: List[Vertex]) {
   var queue = vertices./:(new PriorityQueue[Vertex])((pq, v) => pq += v)
    
   def += (vertex: Vertex): Unit = queue += vertex
   def pop: Vertex = queue.dequeue
   def sort: Unit = { val vs = queue.toList; queue = new PriorityQueue[Vertex]() ++= vs }  // rebuild the heap with updated loads
   def push(vertex: Vertex): Unit = queue.enqueue(vertex)
   def nonEmpty: Boolean = queue.nonEmpty
}
  
type MST = ListBuffer[Int]
implicit def orderingByLoad[T <: Vertex]: Ordering[T] = Ordering.by( - _.load)  

The Scala PriorityQueue class requires an implicit ordering of the vertices, here by their load (line 2). This is accomplished by defining an implicit conversion of a type T with upper bound Vertex to Ordering[T] (line 12).
Note: The type T has to be a sub-class of Vertex. A direct conversion from the Vertex type to Ordering[Vertex] is not allowed in Scala.
We use the mutable PriorityQueue from the Scala library as it provides more flexibility than the Scala TreeSet.

Scala implementation: Prim
This implementation is a direct translation of the pseudo-code presented above. It relies on Scala's efficient tail recursion (line 5).


def prim: List[Int] = {
  val queue = new PQueue(vertices)
   
  @scala.annotation.tailrec
  def prim(parents: MST): Unit = {
    if( queue.nonEmpty ) {
      val head = queue.pop
      val candidates = edges.get(head).get
          .filter{ 
            case(vt,w) => vt.tree == -1 && w <= vt.load
          }
 
      if( candidates.nonEmpty ) {
        candidates.foreach {case (vt, w) => vt.load = w }
        queue.sort
      }
      parents.append(head.id)
      head.tree = 1
      prim(parents)
    }
  }
  val parents = new MST
  prim(parents)
  parents.toList
}

As long as the priority queue is not empty (line 6), the next element is retrieved from the priority queue (line 7) and the candidate vertices for the next transition are selected (lines 8-11). The load of each candidate is updated (line 14) and the priority queue is re-sorted (line 15).
Note: Although a tree set is more efficient for managing the vertices waiting to be weighted, it does not allow the existing collection to be re-sorted in place (immutability).
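
As a usage sketch, assuming the PQueue class and the prim method are defined inside Graph (as suggested by the elided members of the class listing), a small graph with arbitrary weights can be built and its minimum spanning tree extracted as follows:

val graph = new Graph(5)
graph += (0, 1, 4)
graph += (0, 2, 3)
graph += (1, 2, 5)
graph += (1, 3, 9)
graph += (2, 4, 7)
graph += (3, 4, 2)

// Ordered list of vertex ids added to the minimum spanning tree
println(graph.prim.mkString(" -> "))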

Time complexity
As mentioned earlier, the time complexity depends on the implementation of the priority queue. If E is the number of edges and V the number of vertices:
Minimum spanning tree with a linear queue: O(V²)
Minimum spanning tree with a binary heap: O((E + V)·log V)
Minimum spanning tree with a Fibonacci heap: O(E + V·log V)
Note: See Summary of time complexity of algorithms for details.

References
  • Introduction to Algorithms Chapter 24 Minimum Spanning Trees - T. Cormen, C. Leiserson, R. Rivest - MIT Press 1989
  • Lecture 7: Minimum Spanning Trees and Prim’s Algorithm Dekai Wu, Department of Computer Science and Engineering - The Hong Kong University of Science & Technology
  • Graph Theory Chapter 4 Optimization Involving Tree - V.K. Balakrishnan - Schaum's Outlines Series, McGraw Hill, 1997