Saturday, October 29, 2022

Deep Java Library memory management

This post introduces some techniques to monitor memory usage and leaks in machine learning applications using the Deep Java Library (DJL) [1]. This bag of tricks is far from exhaustive.

DJL is an open source framework to support distributed inference in Java for deep learning models built with frameworks such as MXNet, TensorFlow or PyTorch.
The training of deep learning models may require a very significant amount of floating-point computations, which are best handled by GPUs. However, the JVM memory model is incompatible with the column-based resident memory required by the GPU.

Vectorization libraries such as BLAS are implemented in C/C++ and support fast execution of linear algebra operations. The ubiquitous Python numerical library numpy [2], commonly used in data science, is a wrapper around these low-level math functions. The ND interface used in DJL provides Java developers with similar functionality.

Note: The code snippets in this post are written in Scala but can easily be reworked in Java.

The basics

Memory types

DJL supports monitoring 3 memory components:

  • Resident Set Size (RSS) is the portion of the memory used by a process that is held in RAM, as opposed to swapped out. 
  • Heap is the section of memory used by dynamically allocated objects
  • Non-heap is the section encompassing static memory and stack allocation

Tensor representation

Deep learning frameworks operate on tensors. Those tensors are implemented as NDArray objects, created dynamically from arrays of values (integer, float, ...). NDManager is the memory collector/manager native to the underlying C++ implementation of the various deep learning frameworks. Its purpose is to create and delete (close) NDArray instances. NDManager has a hierarchical (single-root tree) structure: child managers can be spawned from a parent [3].
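The hierarchy can be illustrated with a toy model. The sketch below is purely hypothetical and is NOT the DJL API. ToyManager and ToyResource only mimic the create/newSubManager/close pattern of NDManager. They assume that closing a manager releases its own resources and, recursively, those of its children.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical toy model of a hierarchical resource manager (NOT the DJL API).
// Each manager owns the resources it allocates and the child managers it spawns.
final class ToyManager(val name: String) extends AutoCloseable {
  private val resources = ListBuffer.empty[ToyResource]
  private val children = ListBuffer.empty[ToyManager]
  private var closed = false

  def isClosed: Boolean = closed

  // Allocate a resource attached to this manager
  def create(size: Int): ToyResource = {
    require(!closed, s"Manager $name is closed")
    val r = new ToyResource(size)
    resources += r
    r
  }

  // Spawn a child manager attached to this manager
  def newSubManager(): ToyManager = {
    require(!closed, s"Manager $name is closed")
    val child = new ToyManager(s"$name/${children.size}")
    children += child
    child
  }

  // Closing a manager closes its children first, then releases its own resources
  override def close(): Unit = if (!closed) {
    children.foreach(_.close())
    resources.foreach(_.close())
    closed = true
  }
}

// Stand-in for a natively allocated array
final class ToyResource(val size: Int) extends AutoCloseable {
  private var released = false
  def isReleased: Boolean = released
  override def close(): Unit = released = true
}
```

Closing a child releases only the resources of its own branch of the tree, and the root stays usable. Closing the root releases everything.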

Let's consider the following simple example: the computation of the mean of a sequence of floating point values.
import ai.djl.ndarray.NDManager
import scala.util.Random

// Set up the root memory manager
val ndManager = NDManager.newBaseManager()
val input = Array.fill(1000)(Random.nextFloat())
// Allocate resources outside the JVM
val ndInput = ndManager.create(input)
val ndMean = ndInput.mean()
val mean = ndMean.toFloatArray.head

// Release ND resources
ndManager.close()

The steps implemented in the code snippet are:
  1. instantiate the root resource manager, ndManager
  2. create an array of 1000 random floating point values
  3. convert it into an ND array, ndInput
  4. compute the mean, ndMean
  5. convert the result back to Java data types
  6. and finally close the root manager.

The root NDManager can be broken down into child managers to allow a finer granularity of allocation and release of resources. The following method, computeMean, instantiates a child manager, subNDManager, to compute the mean value. The child manager has to be explicitly closed (releasing associated resources) before the function returns.
The memory associated with the local ND variables ndInput and ndMean is released when the child manager that created them is closed.

import ai.djl.ndarray.NDManager

def computeMean(input: Array[Float], ndManager: NDManager): Float =
  if (input.nonEmpty) {
    // Child manager owning all the NDArrays created in this scope
    val subNDManager = ndManager.newSubManager()
    try {
      val ndInput = subNDManager.create(input)
      val ndMean = ndInput.mean()
      ndMean.toFloatArray.head
    } finally {
      // Closing the child manager releases ndInput and ndMean
      subNDManager.close()
    }
  } else Float.NaN // assumption: the mean of an empty input is undefined

JMX to the rescue

The JVM provides developers with the ability to access operating system metrics such as CPU or heap consumption through the Java Management Extensions (JMX) interface [4].
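The following sketch illustrates the JMX interface itself, independently of DJL. It reads the heap and non-heap memory in use through the standard java.lang.management API. The object name JmxMemory and the map keys are assumptions made for this example.

```scala
import java.lang.management.ManagementFactory

// Minimal JMX snapshot of the JVM heap and non-heap memory in use, in bytes.
// The keys "Heap" and "NonHeap" are chosen here for illustration only.
object JmxMemory {
  def snapshot(): Map[String, Long] = {
    val memBean = ManagementFactory.getMemoryMXBean
    Map(
      "Heap" -> memBean.getHeapMemoryUsage.getUsed,
      "NonHeap" -> memBean.getNonHeapMemoryUsage.getUsed
    )
  }

  // Difference in KB between two snapshots for a given key
  def deltaKB(before: Map[String, Long], after: Map[String, Long], key: String): Double =
    (after(key) - before(key)) / 1024.0
}
```

A delta between two snapshots can be negative, because a garbage collection may run between them.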

The DJL class MemoryTrainingListener leverages the JMX monitoring capability. It provides developers with a simple method, collectMemoryInfo, to collect metrics.

First, we need to instruct DJL to enable the collection of memory stats through a Java system property:

System.setProperty("collect-memory", "true")

Similarly to the VisualVM heap memory snapshot, described in the next section, we can collect memory metrics (RSS, Heap and NonHeap) before and after each new NDArray object is created or released. 

import ai.djl.metric.Metrics
import ai.djl.ndarray.NDManager
import ai.djl.training.listener.MemoryTrainingListener

def computeMean(
  input: Array[Float],
  ndManager: NDManager,
  metricName: String): Float = {
  val manager = ndManager.newSubManager()
  // Initialize a new set of metrics
  val metrics = new Metrics()

  // Initialize the collection of memory related metrics (baseline)
  MemoryTrainingListener.collectMemoryInfo(metrics)
  val initVal = metrics.latestMetric(metricName).getValue.longValue
  val ndInput = manager.create(input)
  val ndMean = ndInput.mean()

  collectMetric(metrics, initVal, metricName)
  val mean = ndMean.toFloatArray.head

  // Close the output array and collect metrics
  ndMean.close()
  collectMetric(metrics, initVal, metricName)
  // Close the input array and collect metrics
  ndInput.close()
  collectMetric(metrics, initVal, metricName)
  // Close the sub manager and collect metrics
  manager.close()
  collectMetric(metrics, initVal, metricName)
  mean
}

First, we instantiate a Metrics object that is passed along to all the snapshots. Given the metrics and the current NDManager, we create a baseline for the heap memory size, initVal. We then collect the value of the metric after each creation and release of NDArray instances (collectMetric) in our mean computation example.

Here is a simple snapshot method that computes the increase or decrease in memory from the baseline:
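import ai.djl.metric.Metrics
import ai.djl.training.listener.MemoryTrainingListener

def collectMetric(
  metrics: Metrics,
  initVal: Long,
  metricName: String): Unit = {
  // Take a new snapshot of the memory metrics (requires collect-memory=true)
  MemoryTrainingListener.collectMemoryInfo(metrics)
  val newVal = metrics.latestMetric(metricName).getValue.longValue
  println(s"$metricName: ${(newVal - initVal)/1024.0} KB")
}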

Memory leaks detection

I have been using a combination of several investigative techniques to estimate the source of a memory leak.


The following method dumps basic memory and CPU stats for a given set of metrics into a local file:

MemoryTrainingListener.debugDump(metrics, outputFile)


It is not uncommon for NDArray objects associated with a sub manager not to be properly closed. One simple solution is to prevent the allocation of new objects in the parent manager.

import ai.djl.ndarray.NDManager

// Protect the parent/root manager from
// accidental allocation of NDArray objects
// Set up the memory manager
val ndManager = NDManager.newBaseManager()

// Results in an error
val ndInput = ndManager.create(input)
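The snippet above does not show how the root manager rejects the allocation. One hypothetical way to enforce this policy is a thin guard around the manager. The sketch below uses a toy GuardedManager class that is NOT part of DJL, and plain arrays stand in for NDArray objects.

```scala
// Hypothetical guard (NOT the DJL API): the root manager refuses any direct
// allocation, so every array has to be created in a child manager.
final class GuardedManager private (val isRoot: Boolean) {
  private var allocated = 0

  // Number of arrays allocated through this manager
  def numAllocated: Int = allocated

  // Plain Array[Float] stands in for an NDArray in this sketch
  def create(values: Array[Float]): Array[Float] = {
    if (isRoot)
      throw new IllegalStateException("Allocation in the root manager is not allowed")
    allocated += 1
    values.clone()
  }

  def newSubManager(): GuardedManager = new GuardedManager(isRoot = false)
}

object GuardedManager {
  // Only the root can be created from outside the class
  def newBaseManager(): GuardedManager = new GuardedManager(isRoot = true)
}
```

Any accidental allocation in the root then fails immediately, instead of silently accumulating until the root is closed.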


For reference, DJL introduces a set of experimental profilers to support the investigation of memory consumption bottlenecks [5].


We selected VisualVM [6] among the various JVM profiling solutions to highlight some key statistics when investigating a memory leak. VisualVM has to be downloaded separately; it is not bundled with the JDK.
A simple way to identify excessive memory consumption is to take regular snapshots, or dumps, of the objects allocated on the heap, as illustrated below.

VisualVM has an intuitive UI to drill down into sequences or composite objects. Besides quantifying memory consumption during inference, the following detailed view illustrates the hierarchical nature of the ND manager.

Friday, April 8, 2022

Normalized Discounted Cumulative Gain in Scala

Target audience: Advanced
Estimated reading time: 10'

This post illustrates the Normalized Discounted Cumulative Gain (NDCG) and its implementation in Scala.
Numerous real-life applications of machine learning require predicting the most relevant ranking of items to optimize an outcome. For instance:
  • Evaluate and prioritize counter-measures to cyber-attacks
  • Rank symptoms in a clinical trial
  • Extract documents relevant to a given topic from a corpus
The Discounted Cumulative Gain (DCG) and its normalized counterpart, the Normalized Discounted Cumulative Gain (NDCG), are metrics originally applied in textual information retrieval and later extended to other domains.
This post uses Scala 2.11.8.

Discounted Cumulative Gain

Let's dive into the mathematical formalism for the Discounted Cumulative Gain. 

For a list of n items with indexed target values t_j, as illustrated in the diagram above, the discounted cumulative gain is computed as
\[DCG=\sum_{j=1}^{n}\frac{2^{t_{j}}-1}{\log_{2}(j+1)}\] The objective is to compare any given list of ranked/sorted items with a benchmark that represents the optimum ranking (the ranking with the highest DCG value), the ideal DCG or IDCG. The normalized score is then \[NDCG=\frac{DCG}{IDCG}\]
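The formula can be checked with a minimal standalone sketch that computes the DCG directly from a list of target values. The object DcgSketch and its methods are hypothetical names for this illustration. The ideal DCG is obtained by sorting the targets in decreasing order.

```scala
// Minimal standalone sketch of DCG = sum_{j=1..n} (2^t_j - 1) / log2(j + 1)
object DcgSketch {
  private val Log2 = math.log(2.0)

  def log2(x: Double): Double = math.log(x) / Log2

  // Discounted cumulative gain of the target values t_1 .. t_n
  def dcg(targets: Seq[Int]): Double =
    targets.zipWithIndex.map { case (t, idx) =>
      // j = idx + 1, hence log2(j + 1) = log2(idx + 2)
      (math.pow(2.0, t) - 1.0) / log2(idx + 2.0)
    }.sum

  // Ideal DCG: DCG of the same targets sorted in decreasing order
  def idcg(targets: Seq[Int]): Double = dcg(targets.sortBy(-_))

  def ndcg(targets: Seq[Int]): Double = dcg(targets) / idcg(targets)
}
```

With the targets 4, 3, 5, 2, 1, dcg returns about 36.60 and idcg about 45.64. ndcg is undefined (NaN) when all targets are 0, because the ideal DCG is then 0.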

Scala implementation

The implementation of the computation of NDCG in Scala is quite simple, indeed. Given a ranked list of items, the three steps are:
  • Compute the IDCG (or normalization factor) from the list
  • Compute the DCG for the list
  • Compute the NDCG = DCG/IDCG
First, let's consider a list of items of type T to rank. The ranking function, which extracts the rank of an item, is provided as an implicit function. The constructor of NDCG has a single argument: the initial sample of ranked items.

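class NDCG[T](
   initialSample: Seq[T])
   (implicit ranking: T => Int) {
  import NDCG._

  // Ideal DCG, used as normalization factor
  val iDCG: Double = normalize

  def score: Double = score(initialSample)

  def score(samples: Seq[T]): Double =
    if (samples.size != initialSample.size) 0.0
    else dcg(samples) / iDCG

  // Sum of (2^t_j - 1)/log2(j + 1) with j = index + 1
  private def dcg(samples: Seq[T]): Double =
    samples.zipWithIndex.aggregate(0.0)(
      (s, samplej) => s + compute(samplej._2 + 1, ranking(samplej._1)),
      _ + _)

  // DCG of the samples sorted by decreasing rank
  private def normalize: Double = {
    val sorted = initialSample.zipWithIndex.sortBy {
      case (sample, n) => -ranking(sample)
    }.map(_._1)
    dcg(sorted)
  }
}

object NDCG {
  // Gain of the item at position j (starting at 1) with rank t
  def compute(j: Int, t: Int): Double =
    (Math.pow(2.0, t) - 1.0) * Math.log(2.0) / Math.log(j + 1.0)
}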

The Ideal Discounted Cumulative Gain, iDCG, is computed through the normalize method. iDCG (normalization factor) is computed by first sorting the items of type T by their rank in decreasing order, then scoring this re-sorted list using the dcg method. 
The computation of the Discounted Cumulative Gain by the method dcg is a direct application of the formula described in the previous section.
Note: The logarithm uses base 2. It is computed as ln(x)/ln(2).

Let's now consider a list of items of type Item defined as follows: 

case class Item(id: String, x: Double, rank: Int)

The list of items, itemsList, is implicitly ranked through its last attribute, rank.

val itemsList = Seq[Item](
  Item("1", 3.4, 4), Item("2", 1.4, 3),
  Item("3", -0.7, 5), Item("4", 5.2, 2), 
  Item("5", 1.4, 1))

implicit val ranking = (item: Item) => item.rank

It is time to compute the NDCG coefficient for the list of items, by invoking the score method.

val nDCG = new NDCG[Item](itemsList)

println(s"IDCG = ${nDCG.iDCG}")    //45.64
println(s"Score = ${nDCG.score}")  // 0.801

The ideal discounted cumulative gain, iDCG, is 45.64: it is the DCG of the optimum ranking of this list of items. The first sample scores an NDCG of 0.80, on a scale from 0 to 1.
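Both values can be worked out by hand from the ranks of itemsList. The DCG uses the original order (4, 3, 5, 2, 1) and the IDCG uses the decreasing order (5, 4, 3, 2, 1):

\[IDCG=\frac{2^5-1}{\log_2 2}+\frac{2^4-1}{\log_2 3}+\frac{2^3-1}{\log_2 4}+\frac{2^2-1}{\log_2 5}+\frac{2^1-1}{\log_2 6}\approx 31+9.464+3.5+1.292+0.387\approx 45.64\]
\[DCG=\frac{2^4-1}{\log_2 2}+\frac{2^3-1}{\log_2 3}+\frac{2^5-1}{\log_2 4}+\frac{2^2-1}{\log_2 5}+\frac{2^1-1}{\log_2 6}\approx 15+4.417+15.5+1.292+0.387\approx 36.60\]
\[NDCG=\frac{DCG}{IDCG}\approx\frac{36.60}{45.64}\approx 0.80\]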

Note: The NDCG of subsequent samples can be computed using the same iDCG value from the same instance of NDCG.

def score(samples: Seq[T]): Double =
  if( samples.size != initialSample.size) 0.0 
  else dcg(samples)/iDCG

Monday, January 24, 2022

Bloom Filter in Scala

Target audience: Intermediate
Estimated reading time: 20'

A brief introduction to the Bloom filter and its implementation in Scala using a cryptographic digest.


The Bloom filter became a popular probabilistic data structure to enable membership queries (object x belonging to set or category Y) a couple of years ago. The main benefit of a Bloom filter is to reduce the requirement of large memory allocation by not storing the objects themselves, as a HashSet or HashMap would. The compact representation comes with a trade-off: although the filter does not allow false negatives, it does not rule out false positives. 
In other words, a query returns either:
  • that an object very likely belongs to the set, or
  • that an object does not belong to the set.
A Bloom filter is quite often used as a front end to a deterministic algorithm.

Note: For the sake of readability of the implementation of algorithms, all non-essential code such as error checking, comments, exception, validation of class and method arguments, scoping qualifiers or import is omitted.


Let's consider a set A = {a0, ..., an-1} of n elements for which a membership query is executed. The data structure consists of a bit vector V of m bits and h completely independent hash functions, each associated with a position in the bit vector. The assignment (or mapping) of hash functions to bits has to follow a uniform distribution. 
The diagram below illustrates the basic mechanism behind the Bloom filter. The set A is defined by the pair a1 and a2. The hash functions h1 and h2 map the elements to bit positions (bits set to 1) in the bit vector. The element b has one of its positions set to 0 and therefore does not belong to the set. The element c belongs to the set because all its associated positions have bits set to 1.

However, the algorithm does not prevent false positives. For instance, all the bits associated with an element may have been set to 1 during the insertion of previous elements, so the query erroneously reports that the element belongs to the set.
The insertion of an element depends on the h hash functions; therefore, the time needed to add a new element is proportional to h (the number of hash functions) and independent of the size of the bit vector: asymptotic insertion time = O(h). However, each element sets up to h bits, so the filter is less effective than a traditional bit array for small sets.
The probability of false positives decreases as the number n of inserted elements decreases and as the size m of the bit vector increases. The number of hash functions that minimizes the probability of false positives is \[h=\frac{m}{n}\ln 2\]
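The following sketch puts numbers on these trade-offs. Besides the optimal number of hash functions given above, it uses the standard approximation of the false positive probability, p ≈ (1 - e^(-hn/m))^h, which is not stated in the text. The object name BloomSizing is hypothetical.

```scala
// Standard Bloom filter sizing formulas for a bit vector of m bits,
// n inserted elements and h hash functions.
object BloomSizing {
  // Number of hash functions minimizing false positives: h = (m / n) ln 2, at least 1
  def optimalNumHashes(m: Int, n: Int): Int =
    math.max(1, math.round(m.toDouble / n * math.log(2.0)).toInt)

  // Approximate false positive probability: (1 - e^(-h n / m))^h
  def falsePositiveRate(m: Int, n: Int, h: Int): Double =
    math.pow(1.0 - math.exp(-h.toDouble * n / m), h)
}
```

For example, with m = 1000 bits and n = 100 elements, the optimal h is 7 and the false positive rate is roughly 0.8%. Doubling m lowers the rate further.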

Implementation in Scala

The implementation relies on the MessageDigest Java library class to generate the hash values. Ancillary methods and conditions on method arguments are omitted for the sake of clarity.
The first step is to define the BloomFilter class and its attributes:
  • length: number of entries in the filter
  • numHashs: number of hash functions
  • algorithm: hashing algorithm, with SHA1 as default
  • set: array of bytes for the entries in the Bloom filter
  • digest: digest used to generate hash values
import java.security.MessageDigest
import scala.util.Try

class BloomFilter(
  length: Int,
  numHashs: Int,
  algorithm: String = "SHA1") {
  val set = new Array[Byte](length)
  val digest = Try(MessageDigest.getInstance(algorithm))

  def add(elements: Array[Any]): Int = ???
  def add(any: Any): Boolean = ???
  final def contains(any: Any): Boolean = ???

  private def hash(value: Int): Int = ???
  private def getSet(any: Any): Array[Int] = ???
}

The digest attribute wraps an instance of the Java library class MessageDigest in a Try, since the requested algorithm may not be available.
The next step consists of defining the methods to add a single generic element, add(any: Any), and an array of elements, add(elements: Array[Any]).

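// add an array of elements to the filter;
// returns the number of elements added (assumption)
def add(elements: Array[Any]): Int = {
  elements.foreach(el => getSet(el).foreach(i => set(i) = 1))
  elements.length
}
// add a single element to the filter
def add(any: Any): Boolean = this.add(Array[Any](any)) > 0
// the element is very likely in the filter only if all its bits are set
final def contains(any: Any): Boolean = !getSet(any).exists(set(_) != 1)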

The method contains evaluates whether an element is contained in the filter. The method returns:
  • true if the filter very likely contains the element
  • false if the filter DOES NOT contain this element
The contains method relies on the recursive getSet method to compute the entries associated with an element.

import scala.annotation.tailrec

def getSet(any: Any): Array[Int] = {
  val newSet = new Array[Int](numHashs)
  newSet.update(0, hash(any.hashCode))
  getSet(newSet, 1)
  newSet
}

@tailrec
private def getSet(values: Array[Int], index: Int): Unit =
  if (index < values.length) {
    values.update(index, hash(values(index - 1)))
    getSet(values, index + 1) // tail recursion
  }

Similarly to the add method, the getSet method has two implementations:
  • Generate the array of indices for any new element
  • A recursive call that fills the remaining indices, each index being the hash of the previous one.
The hash method is the core of the Bloom filter: it computes the index of an entry in the bit vector.

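def hash(value: Int): Int = digest.map { d =>
  d.update(int2Bytes(value)) // digest() below also resets d
  Math.floorMod(new BigInteger(1, d.digest()).intValue, set.length)
}.get // throws if the algorithm is not supported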

The instance of the MessageDigest class, digest, generates a hash value using either the MD5 or the SHA-1 algorithm. Tail recursion is used as an alternative to an iterative loop to generate the set of indices.

The next code snippet implements a very simple implicit conversion from Int to Array[Byte].

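import scala.language.implicitConversions

object BloomFilter {
  val NUM_BYTES = 4
  val LAST_BYTE = NUM_BYTES - 1
  // Big-endian conversion: the most significant byte comes first
  implicit def int2Bytes(value: Int): Array[Byte] =
    Array.tabulate(NUM_BYTES)(n => {
      val offset = (LAST_BYTE - n) << 3   // (LAST_BYTE - n) * 8 bits
      ((value >>> offset) & 0xFF).toByte
    })
}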

The conversion shifts and masks the bits of a 32-bit integer to extract its 4 bytes. Alternatively, you may consider a conversion from a long value to an 8-byte array.
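The 64-bit alternative could look like the following sketch. It mirrors int2Bytes for a Long value, with the most significant byte first. The object LongBytes and the method long2Bytes are hypothetical names.

```scala
import scala.language.implicitConversions

// Big-endian conversion of a 64-bit Long into 8 bytes
object LongBytes {
  val NUM_BYTES = 8
  val LAST_BYTE = NUM_BYTES - 1

  implicit def long2Bytes(value: Long): Array[Byte] =
    Array.tabulate(NUM_BYTES) { n =>
      val offset = (LAST_BYTE - n) * 8       // shift in bits: 56, 48, ..., 0
      ((value >>> offset) & 0xFFL).toByte
    }
}
```

The output matches java.nio.ByteBuffer.putLong, which also writes in big-endian order by default.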


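This simple test checks whether a couple of values are indeed contained in the set. The filter accepts 23, which was added, and should reject 22 unless 22 happens to be a false positive. A rejection is definitive, while an acceptance only means that the element very likely belongs to the set. If the objective is to confirm that 23 belongs to the set, then a full-fledged hash table would have to be used.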

val filter = new BloomFilter(100, 100, "SHA")
val newValues = Array[Any](57, 97, 91, 23, 67, 33)
filter.add(newValues)

println( filter.contains(22) )  // false, unless 22 is a false positive
println( filter.contains(23) )  // true: 23 was added
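The snippets above can be assembled into a single runnable sketch. This is an assembly under a few assumptions rather than the author's exact code:
- add returns the number of elements added.
- The index is computed with Math.floorMod, so it always falls within the bit vector.
- The default algorithm uses the standard JCA name "SHA-1".
- int2Bytes is called explicitly instead of as an implicit conversion.

```scala
import java.math.BigInteger
import java.security.MessageDigest
import scala.annotation.tailrec
import scala.util.Try

// Assembled sketch of the Bloom filter built from the snippets above
class BloomFilter(length: Int, numHashs: Int, algorithm: String = "SHA-1") {
  import BloomFilter._

  val set = new Array[Byte](length)
  val digest = Try(MessageDigest.getInstance(algorithm))

  // Add an array of elements; returns the number of elements added
  def add(elements: Array[Any]): Int = {
    elements.foreach(el => getSet(el).foreach(i => set(i) = 1))
    elements.length
  }

  def add(any: Any): Boolean = add(Array[Any](any)) > 0

  // false: definitely not in the filter; true: very likely in the filter
  final def contains(any: Any): Boolean = !getSet(any).exists(set(_) != 1)

  // Index of an entry derived from the message digest of the value;
  // digest() resets the MessageDigest after each call
  private def hash(value: Int): Int = {
    val d = digest.get
    d.update(int2Bytes(value))
    Math.floorMod(new BigInteger(1, d.digest()).intValue, set.length)
  }

  // numHashs indices: hash of the element, then hash of the previous index
  private def getSet(any: Any): Array[Int] = {
    val newSet = new Array[Int](numHashs)
    newSet.update(0, hash(any.hashCode))
    getSet(newSet, 1)
    newSet
  }

  @tailrec
  private def getSet(values: Array[Int], index: Int): Unit =
    if (index < values.length) {
      values.update(index, hash(values(index - 1)))
      getSet(values, index + 1)
    }
}

object BloomFilter {
  val NUM_BYTES = 4
  val LAST_BYTE = NUM_BYTES - 1

  // Big-endian conversion of a 32-bit integer into 4 bytes
  def int2Bytes(value: Int): Array[Byte] =
    Array.tabulate(NUM_BYTES) { n =>
      ((value >>> ((LAST_BYTE - n) * 8)) & 0xFF).toByte
    }
}
```

Each index is derived from the previous one, so two elements whose first index collides end up sharing all their indices. Double hashing (index_i = (h1 + i * h2) mod m) is a common alternative that avoids this. MessageDigest instances are also not thread-safe, so this filter should not be shared across threads without synchronization.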

Performance evaluation
Let's look at the behavior of the Bloom filter under load. The test consists of adding 100,000,000 new random values, then testing whether the filter contains a value 10,000 times. The test is run 10 times after a warm-up of the JVM.

import scala.util.Random

val size = 100000000  // 100,000,000 random values
  // Measure average time to add a new data set
filter.add(Array.tabulate[Any](size)(n => Random.nextInt(n + 1)))
  // Measure average time to test for a value.
filter.contains(Random.nextInt(size))
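The timing harness itself is not shown above. A minimal way to average the time per call after a JVM warm-up is sketched below, assuming a plain System.nanoTime harness. Timing.averageNanos is a hypothetical helper, and a dedicated tool such as JMH would give more reliable numbers.

```scala
// Minimal timing sketch: average wall-clock time per call, in nanoseconds,
// measured after a few warm-up runs. Not a rigorous benchmark (no GC control).
object Timing {
  def averageNanos(runs: Int, warmUp: Int)(body: => Unit): Double = {
    (0 until warmUp).foreach(_ => body)   // let the JIT compile the hot path
    val start = System.nanoTime()
    (0 until runs).foreach(_ => body)
    (System.nanoTime() - start).toDouble / runs
  }
}
```

For instance, Timing.averageNanos(runs = 10000, warmUp = 1000) { filter.contains(Random.nextInt()) } would estimate the average query time. The Boolean result is simply discarded.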

The first performance test evaluates the average time required to insert a new element into a Bloom filter whose size ranges from 100 million to 1 billion entries.
The second test evaluates the average search/query time for Bloom filters of the same range of sizes.

As expected, the average time to load a new set of values and to check whether the filter contains a specific value is fairly constant, regardless of the size of the filter.