Saturday, October 29, 2022

Deep Java Library memory management

This post introduces some techniques to monitor memory usage and detect memory leaks in machine learning applications using the Deep Java Library (DJL) [1]. This bag of tricks is far from exhaustive.

DJL is an open source, engine-agnostic Java framework that supports training and inference of deep learning models built with engines such as MXNet, TensorFlow or PyTorch.
The training of deep learning models may require a very significant amount of floating-point computations, which are best supported by GPUs. However, the JVM memory model is incompatible with the column-based resident memory required by GPUs.

Vectorization libraries such as BLAS are implemented in native code (C/C++ or Fortran) and support fast execution of linear algebra operations. NumPy [2], the ubiquitous Python numerical library commonly used in data science, is a wrapper around these low-level math functions. The ND interface used in DJL provides Java developers with similar functionality.

Note: The code snippets in this post are written in Scala but can easily be reworked in Java.

The basics

Memory types

DJL supports monitoring three memory components:
  • Resident Set Size (RSS) is the portion of the memory used by a process that is held in RAM, i.e., not swapped out.
  • Heap is the section of memory used by dynamically allocated objects.
  • Non-heap is the memory the JVM manages outside the heap, such as the method area (per-class structures) and the code cache.
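The three components can also be read outside DJL. Here is a minimal sketch, assuming a Linux host for RSS, which is read from the VmRSS line of /proc/self/status (None on other systems), while heap and non-heap usage come from the standard MemoryMXBean. MemorySnapshot, readRss and takeSnapshot are hypothetical names.

```scala
import java.lang.management.ManagementFactory
import java.nio.file.{Files, Paths}
import scala.jdk.CollectionConverters._

// Hypothetical container for the three memory components, in bytes
final case class MemorySnapshot(rss: Option[Long], heapUsed: Long, nonHeapUsed: Long)

// RSS is reported by the OS, not the JVM: on Linux, /proc/self/status
// contains a line such as "VmRSS:    123456 kB"
def readRss(): Option[Long] = {
  val status = Paths.get("/proc/self/status")
  if (!Files.isReadable(status)) None
  else
    Files.readAllLines(status).asScala
      .find(_.startsWith("VmRSS:"))
      .map(_.split("\\s+")(1).toLong * 1024L) // kB -> bytes
}

// Heap and non-heap usage come from the platform MemoryMXBean (JMX)
def takeSnapshot(): MemorySnapshot = {
  val memBean = ManagementFactory.getMemoryMXBean
  MemorySnapshot(
    rss = readRss(),
    heapUsed = memBean.getHeapMemoryUsage.getUsed,
    nonHeapUsed = memBean.getNonHeapMemoryUsage.getUsed
  )
}

val snapshot = takeSnapshot()
println(snapshot)
```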

Tensor representation

Deep learning frameworks operate on tensors. In DJL, those tensors are implemented as NDArray objects, created dynamically from arrays of values (integer, float, ...). NDManager manages the native memory of the underlying C++ implementation of the various deep learning frameworks: its purpose is to create and delete (close) NDArray instances. NDManager has a hierarchical (single-root tree) structure in which child managers can be spawned from a parent [3].
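To make the single-root tree concrete, here is a toy model in plain Scala. It is not DJL's implementation: ToyManager and ToyArray are hypothetical classes that only mimic the ownership rule described above, namely that each array belongs to exactly one manager, and closing a manager closes its arrays and, recursively, its child managers.

```scala
import scala.collection.mutable.ListBuffer

// Toy stand-in for a natively allocated array (hypothetical, not DJL's NDArray)
final class ToyArray(val values: Array[Float]) extends AutoCloseable {
  private var open = true
  def isOpen: Boolean = open
  // A real engine would free native memory here
  override def close(): Unit = { open = false }
}

// Toy stand-in for a hierarchical manager (hypothetical, not DJL's NDManager)
final class ToyManager(val name: String) extends AutoCloseable {
  private val children = ListBuffer.empty[ToyManager]
  private val arrays = ListBuffer.empty[ToyArray]
  private var open = true

  def isOpen: Boolean = open

  // Spawn a child manager attached to this one
  def newSubManager(childName: String): ToyManager = {
    require(open, s"manager $name is closed")
    val child = new ToyManager(childName)
    children += child
    child
  }

  // Allocate an array owned by this manager
  def create(values: Array[Float]): ToyArray = {
    require(open, s"manager $name is closed")
    val array = new ToyArray(values)
    arrays += array
    array
  }

  // Closing a manager closes its arrays, then its child managers, recursively
  override def close(): Unit =
    if (open) {
      arrays.foreach(_.close())
      children.foreach(_.close())
      open = false
    }
}

val root = new ToyManager("root")
val child = root.newSubManager("child")
val array = child.create(Array(1.0f, 2.0f, 3.0f))
root.close() // also closes child and, through it, array
```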

Let's consider the following simple example: the computation of the mean of a sequence of floating point values.
import ai.djl.ndarray.NDManager
import scala.util.Random

// Set up the root memory manager
val ndManager = NDManager.newBaseManager()
val input = Array.fill(1000)(Random.nextFloat())
// Allocate resources outside the JVM heap
val ndInput = ndManager.create(input)
val ndMean = ndInput.mean()
val mean = ndMean.toFloatArray.head

// Release ND resources (also closes ndInput and ndMean)
ndManager.close()

The steps implemented in the code snippet are:
  1. instantiate the root resource manager, ndManager
  2. create an array of 1000 random floating point values
  3. convert it into an ND array, ndInput
  4. compute the mean, ndMean
  5. convert the result back to Java data types
  6. and finally close the root manager.

The root NDManager can be broken down into child managers to allow a finer granularity of allocation and release of resources. The following method, computeMean, instantiates a child manager, subNDManager, to compute the mean value. The child manager has to be explicitly closed (releasing its associated resources) before the function returns.
The memory associated with the local ND variables, ndInput and ndMean, is released when the child manager that owns them is closed, not merely when the variables go out of scope.

import ai.djl.ndarray.NDManager

def computeMean(input: Array[Float], ndManager: NDManager): Float =
  if (input.nonEmpty) {
    val subNDManager = ndManager.newSubManager()
    try {
      // Allocate from the child manager, not from the parent
      val ndInput = subNDManager.create(input)
      val ndMean = ndInput.mean()
      ndMean.toFloatArray.head
    } finally {
      // Local resources, ndInput and ndMean, are released
      // when the child manager is closed
      subNDManager.close()
    }
  } else Float.NaN // the mean of an empty sequence is undefined
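The JVM has no destructors, so the explicit close before returning is essential. Assuming Scala 2.13 or later, scala.util.Using offers a structured alternative to try/finally: it closes every registered AutoCloseable in reverse order of acquisition, even when the body throws. DJL's NDManager and NDArray are AutoCloseable, so they fit this pattern; to run without DJL, the sketch below uses a hypothetical Tracked resource in their place.

```scala
import scala.collection.mutable.ListBuffer
import scala.util.{Try, Using}

// Records the order in which resources are released
val releaseLog = ListBuffer.empty[String]

// Hypothetical resource standing in for a sub manager or an NDArray
final class Tracked(val name: String) extends AutoCloseable {
  override def close(): Unit = releaseLog += name
}

// Resources are closed in reverse order of registration, even on failure
def computeWithCleanup(fail: Boolean): Try[Float] = Using.Manager { use =>
  use(new Tracked("subManager"))
  use(new Tracked("ndInput"))
  use(new Tracked("ndMean"))
  if (fail) throw new IllegalStateException("simulated failure")
  2.0f // placeholder for the computed mean
}

val succeeded = computeWithCleanup(fail = false)
val failed = computeWithCleanup(fail = true)
println(releaseLog.mkString(", "))
```

The failure case is the important one: without Using or try/finally, an exception thrown between allocation and close would leave the native memory allocated.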

JMX to the rescue

The JVM provides developers with the ability to access runtime and operating system metrics, such as CPU load or heap consumption, through the Java Management Extensions (JMX) interface [4].
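To see what JMX exposes beyond the totals, the standard MemoryPoolMXBean instances break heap and non-heap usage down by pool. This sketch is independent of DJL; pool names (such as Metaspace or G1 Eden Space) depend on the JVM and garbage collector, so the output varies, and memoryPools is a hypothetical helper name.

```scala
import java.lang.management.{ManagementFactory, MemoryType}
import scala.jdk.CollectionConverters._

// One row per memory pool: (pool name, HEAP or NON_HEAP, bytes in use)
def memoryPools(): Seq[(String, MemoryType, Long)] =
  ManagementFactory.getMemoryPoolMXBeans.asScala.toSeq.map { pool =>
    // getUsage returns null for a pool that is no longer valid
    val used = Option(pool.getUsage).map(_.getUsed).getOrElse(0L)
    (pool.getName, pool.getType, used)
  }

memoryPools().foreach { case (name, kind, used) =>
  println(f"${kind.name}%-9s $name%-35s ${used / 1024.0}%10.1f KB")
}
```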

The DJL class MemoryTrainingListener leverages this JMX monitoring capability. It provides developers with a simple method, collectMemoryInfo, to collect metrics.

First, we need to instruct DJL to collect memory statistics by setting a Java system property:

System.setProperty("collect-memory", "true") 

Similar to the VisualVM heap memory snapshot described in the next section, we can collect memory metrics (RSS, Heap and NonHeap) before and after each NDArray object is created or released.

import ai.djl.metric.Metrics
import ai.djl.ndarray.NDManager
import ai.djl.training.listener.MemoryTrainingListener

// metricName is the name of a collected metric, e.g. "Heap" or "NonHeap"
def computeMean(
  input: Array[Float], 
  ndManager: NDManager, 
  metricName: String): Float = {
  val manager = ndManager.newSubManager()
  // Initialize a new set of metrics
  val metrics = new Metrics()

  // Take a first memory snapshot to set the baseline
  MemoryTrainingListener.collectMemoryInfo(metrics)
  val initVal = metrics.latestMetric(metricName).getValue.longValue
  // Allocate the ND arrays from the sub manager
  val ndInput = manager.create(input)
  val ndMean = ndInput.mean()

  collectMetric(metrics, initVal, metricName)
  val mean = ndMean.toFloatArray.head

  // Close the output array and collect metrics
  ndMean.close()
  collectMetric(metrics, initVal, metricName)
  // Close the input array and collect metrics
  ndInput.close()
  collectMetric(metrics, initVal, metricName)
  // Close the sub manager and collect metrics
  manager.close()
  collectMetric(metrics, initVal, metricName)
  mean
}

First, we instantiate a Metrics object that is passed along to all the snapshots. We then record a baseline value of the selected metric (for instance the heap size), initVal. Finally, we collect the value of the metric after each creation and release of NDArray instances (collectMetric) in our mean computation example.

Here is a simple snapshot method that computes the increase or decrease in memory relative to the baseline.
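import ai.djl.metric.Metrics
import ai.djl.training.listener.MemoryTrainingListener

def collectMetric(
  metrics: Metrics, 
  initVal: Long, 
  metricName: String): Unit = {
  // Take a new snapshot (requires the "collect-memory" property)
  MemoryTrainingListener.collectMemoryInfo(metrics)
  val newVal = metrics.latestMetric(metricName).getValue.longValue
  println(s"$metricName: ${(newVal - initVal)/1024.0} KB")
}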
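When metricName is "Heap", the deltas printed by collectMetric may stay small even though NDArrays were created, because the engine typically allocates tensor data in native memory outside the JVM heap; RSS is the metric that reflects it. The JDK-only sketch below illustrates the distinction with a direct ByteBuffer standing in for native tensor memory (an analogy, not how DJL allocates): the allocation shows up in the JVM's "direct" buffer pool rather than as heap objects. Memory allocated by a native engine does not even appear in that pool, only in RSS. The helpers directBytes and heapBytes are hypothetical.

```scala
import java.lang.management.{BufferPoolMXBean, ManagementFactory}
import java.nio.ByteBuffer
import scala.jdk.CollectionConverters._

// Bytes currently held in the JVM's "direct" (off-heap) buffer pool
def directBytes(): Long =
  ManagementFactory.getPlatformMXBeans(classOf[BufferPoolMXBean]).asScala
    .find(_.getName == "direct")
    .map(_.getMemoryUsed)
    .getOrElse(0L)

// Bytes currently used by objects on the JVM heap
def heapBytes(): Long =
  ManagementFactory.getMemoryMXBean.getHeapMemoryUsage.getUsed

val capacity = 16 * 1024 * 1024 // 16 MB, stand-in for a tensor buffer
val directBefore = directBytes()
val heapBefore = heapBytes()

// Allocated outside the heap, like the native memory behind an NDArray
val buffer = ByteBuffer.allocateDirect(capacity)

val directDelta = directBytes() - directBefore
val heapDelta = heapBytes() - heapBefore
println(s"direct pool: +${directDelta / 1024} KB, heap: ${heapDelta / 1024} KB")
```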

Memory leaks detection

I have been using a combination of several investigative techniques to locate the source of a memory leak.
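A complementary heuristic (a sketch, not a DJL facility): run the same computation, such as computeMean, many times, record a memory metric like RSS after each run, and suspect a leak when the readings keep climbing instead of levelling off. The function name looksLikeLeak, its 64 KB default tolerance and the 80% threshold are hypothetical, arbitrary choices.

```scala
// Heuristic leak check on a series of memory readings (bytes), one per
// iteration of the same computation. A leak is suspected when most steps
// grow by more than `tolerance`, which absorbs allocator and GC noise.
// The 64 KB default and the 80% threshold are arbitrary choices.
def looksLikeLeak(samples: Seq[Long], tolerance: Long = 64L * 1024): Boolean =
  samples.size >= 3 && {
    val deltas = samples.zip(samples.tail).map { case (prev, next) => next - prev }
    val growingSteps = deltas.count(_ > tolerance)
    growingSteps >= math.ceil(0.8 * deltas.size).toInt
  }

val mb = 1024L * 1024
// Steady growth of 1 MB per iteration
println(looksLikeLeak(Seq(0L, 1 * mb, 2 * mb, 3 * mb, 4 * mb))) // true
// Small fluctuations around a stable level
println(looksLikeLeak(Seq(100 * mb, 100 * mb + 10240, 100 * mb - 5120, 100 * mb + 3072))) // false
```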


The following method dumps basic memory and CPU statistics for a given set of metrics into a local file:

  MemoryTrainingListener.debugDump(metrics, outputFile)


It is not uncommon for NDArray objects that should be associated with a sub manager not to be properly closed. One simple solution is to prevent the allocation of new objects in the parent manager, so that every NDArray has to be created through a sub manager.

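import ai.djl.ndarray.NDManager

// Set up the memory manager
val ndManager = NDManager.newBaseManager()
// Protect the parent/root manager from
// accidental allocation of NDArray objects
ndManager.cap()

val input = Array(1.0f, 2.0f, 3.0f)
// Results in an error: the root manager is capped
val ndInput = ndManager.create(input)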


For reference, DJL introduces a set of experimental profilers to support the investigation of memory consumption bottlenecks [5].


We select VisualVM [6] among the various JVM profiling solutions to highlight some key statistics for investigating a memory leak. VisualVM is a standalone utility that has to be downloaded separately; it is not bundled with the JDK.
A simple way to identify excessive memory consumption is to take regular snapshots, or dumps, of the objects allocated on the heap, as illustrated below.

VisualVM has an intuitive UI to drill down into sequences of objects or composite objects. Besides quantifying memory consumption during inference, the following detailed view illustrates the hierarchical nature of the ND manager.
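Heap dumps can also be triggered from code, for instance right before and after a suspicious inference call, and then loaded into VisualVM for comparison. This sketch assumes a HotSpot-based JVM (such as OpenJDK), which provides com.sun.management.HotSpotDiagnosticMXBean; the dumpHeap helper and its file naming scheme are hypothetical. Keep in mind that a heap dump only captures JVM objects: the NDArray Java wrappers appear in it, the native tensor data behind them does not.

```scala
import com.sun.management.HotSpotDiagnosticMXBean
import java.lang.management.ManagementFactory
import java.nio.file.{Files, Path}

// Writes an .hprof heap dump that VisualVM can open for inspection.
// live = true keeps only objects still reachable at dump time.
def dumpHeap(dir: Path, label: String, live: Boolean = true): Path = {
  val target = dir.resolve(s"$label-${System.currentTimeMillis()}.hprof")
  val diagnostics =
    ManagementFactory.getPlatformMXBean(classOf[HotSpotDiagnosticMXBean])
  // Throws an IOException if the target file already exists
  diagnostics.dumpHeap(target.toString, live)
  target
}

val dumpDir = Files.createTempDirectory("heap-dumps")
val beforeInference = dumpHeap(dumpDir, "before-inference")
// ... run the suspicious computation here, then take a second dump ...
val afterInference = dumpHeap(dumpDir, "after-inference")
```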
