Tuesday, January 10, 2023

BERT with Deep Java Library & Spark

This post describes the architecture of a fast, distributed inference pipeline for Bidirectional Encoder Representations from Transformers (BERT) models using Apache Spark, Kafka and Deep Java Library (DJL).

Target audience: Advanced
Estimated reading time: 25'

The best of both worlds

Python environments are widely used to create and train deep learning models with frameworks such as TensorFlow, PyTorch or MXNet. As an interpreted language, Python gives data scientists flexible notebooks to interactively develop, evaluate and refine neural models.

Python has an extensive collection of libraries for natural language processing, machine learning, statistics and data set management.

This dynamic development environment has two major drawbacks for run-time inference:

  1. Python has a limited ability to parallelize tasks, whether by executing concurrent threads or by distributing tasks over a network.
  2. Commercial applications typically rely on web services, execute on the Java virtual machine (JVM), and leverage the vast array of Apache open source libraries.

Is it possible to use a Python environment to define, train and evaluate deep learning models, and a JVM-based language for real-time inference?

The answer lies in the fact that the core of deep learning frameworks such as PyTorch or TensorFlow is a native binary library written in C++. This binary implementation is accessible from both Python and Java through their respective interfaces.


Inference with Spark and Deep Java Library

Apache Spark and the Deep Java Library address the two key limitations of deploying machine learning models in production using Python:

  1. Apache Spark allows fast, concurrent processing of large data sets on multiple distributed nodes.
  2. Deep Java Library is a Java library that gives access to the most common deep learning frameworks through a Java native interface.

The most common scenario consists of developing models in a Python environment (Jupyter, an IDE, Anaconda) and saving the model parameters. DJL then loads the model parameters and initializes the inference model to serve run-time requests.


Distributed inference pipeline

The goal is to scale predictions by executing them in parallel. The key elements of the distributed inference pipeline are:

  • Apache Spark partitions run-time prediction requests into batches that are executed concurrently on remote worker nodes.
  • Apache Kafka is an asynchronous messaging queue that decouples the client application from the inference pipeline.
  • DJL interfaces with the native binaries of the deep learning frameworks.
  • Kubernetes orchestrates the containerized instances of the inference pipeline to support scalable, automated deployment. Spark 3.2 and later versions support direct integration with Kubernetes.
  • Deep learning frameworks include TensorFlow, MXNet and PyTorch.

The two main benefits of such a pipeline are simplicity (all tasks and processes run on the JVM) and low latency.
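
The idea of splitting requests into batches that are scored concurrently can be illustrated without Spark. The sketch below is only a local stand-in: it uses Scala futures instead of Spark partitions, and the Request and Prediction types, predictOne and predictAll are hypothetical names introduced for this illustration.

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object BatchedPrediction {
  // Hypothetical request and prediction types, for illustration only
  final case class Request(id: Int, text: String)
  final case class Prediction(id: Int, score: Double)

  // Stand-in for a model call; a real pipeline would invoke DJL here
  def predictOne(request: Request): Prediction =
    Prediction(request.id, request.text.length.toDouble)

  // Split the requests into batches and score the batches concurrently
  def predictAll(requests: Seq[Request], batchSize: Int)
                (implicit ec: ExecutionContext): Seq[Prediction] = {
    val batches = requests.grouped(batchSize).toSeq
    val futures = batches.map(batch => Future(batch.map(predictOne)))
    // Future.sequence preserves the order of the batches
    Await.result(Future.sequence(futures), 30.seconds).flatten
  }
}
```

In the actual pipeline, Spark plays the role of the futures and spreads the batches over remote workers rather than over local threads.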

Note: Spark and DJL can also be used in the training phase to distribute the training of a mini-batch.

Apache Kafka

Apache Kafka is an open-source distributed event streaming platform for high-volume data pipelines, streaming analytics, data integration, and mission-critical applications. Event streaming ensures a continuous flow of data through a pipeline or sequence of transformations such as Extract, Transform and Load.

First, we construct the handler class, KafkaPrediction, which

  1. consumes requests from the Kafka topic consumeTopic
  2. invokes the prediction model and transformation, predictionPipeline
  3. produces predictions into the Kafka topic produceTopic

The actual request is wrapped in the consumed message, RequestMessage. Likewise, the prediction is wrapped in a ResponseMessage produced back to the Kafka queue.

class KafkaPrediction(
  consumeTopic: String,
  produceTopic: String,
  predictionPipeline: Seq[Request] => Seq[Prediction]) {

  // 1 - Construct the transform of Kafka messages for prediction
  val transform = (requestMsg: Seq[RequestMessage]) => {
    // 2 - Invoke the execution of the pipeline
    val predictions = predictionPipeline(requestMsg.map(_.requestPayload))
    // Conversion of the predictions into ResponseMessage is elided in the original
    ...
  }

  // 3 - Build the Kafka consumer for prediction requests (arguments elided in the original)
  val consumer = new KafkaConsumer[RequestMessage](...)
  // 4 - Build the Kafka producer for prediction responses (arguments elided in the original)
  val producer = new KafkaProducer[ResponseMessage](...)
}

  1. We first create a wrapper function, transform, to generate predictions. The function converts request messages of type RequestMessage into predictions wrapped in messages of type ResponseMessage.
  2. The wrapper transform invokes the prediction pipeline predictionPipeline after converting the messages of type RequestMessage consumed from Kafka into actual requests (Request). The predictions are then converted into messages of type ResponseMessage produced to Kafka.
  3. The consumer is fully defined by the de-serialization of the data consumed from Kafka and its associated topic.
  4. The producer serializes the responses back to the Kafka service.

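The unwrap, predict and re-wrap logic of the transform can be sketched on its own, independently of Kafka. The fields of Request, Prediction, RequestMessage and ResponseMessage are not shown in this post, so the ones below (including timestamp) are assumptions made for illustration; only requestPayload and payload are names taken from the snippets.

```scala
object MessageTransform {
  // Hypothetical payload and envelope types; the post does not list their fields
  final case class Request(id: String, text: String)
  final case class Prediction(id: String, label: String)
  final case class RequestMessage(timestamp: Long, requestPayload: Request)
  final case class ResponseMessage(timestamp: Long, payload: Prediction)

  // Wrap a prediction pipeline so that it consumes and produces message envelopes
  def transform(
    predictionPipeline: Seq[Request] => Seq[Prediction]
  ): Seq[RequestMessage] => Seq[ResponseMessage] =
    (requestMsgs: Seq[RequestMessage]) => {
      // Unwrap the requests, run the pipeline, then wrap each prediction
      val predictions = predictionPipeline(requestMsgs.map(_.requestPayload))
      val now = System.currentTimeMillis()
      predictions.map(ResponseMessage(now, _))
    }
}
```

Keeping the pipeline as a plain function of type Seq[Request] => Seq[Prediction] means that it can be tested without a Kafka broker.
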
def executeBatch(
  consumeTopic: String,
  produceTopic: String,
  maxNumResponses: Int): Unit = {
  // 1 - Initialize the prediction pipeline (constructor arguments elided in the original)
  val kafkaHandler = new KafkaPrediction(...)

  while (running) {
    // 2 - Poll the request topic (has its own specific Kafka exception handler)
    val consumerRecords = kafkaHandler.consumer.receive
    if (consumerRecords.nonEmpty) {
      // 3 - Generate and apply transform to the batch
      val input: Seq[RequestMessage] = consumerRecords.map(_._2)
      val responses = kafkaHandler.predict(input)
      if (responses.nonEmpty) {
        // 4 - Key each response by its identifier for the output topic
        val respMessages = responses.map(
          response => (response.payload.id, response)
        )
        // 5 - Produce the batch of response messages to Kafka
        // 6 - Get confirmation that Kafka has indeed processed the responses
        ...
      }
      else
        logger.error("No response is produced to Kafka")
    }
  }
}

  1. First, we instantiate the Kafka message handler class, KafkaPrediction, that we created earlier.
  2. At regular intervals, we poll Kafka for a batch of new requests.
  3. If the batch is not empty, we invoke the handler method, predict, to run the prediction models.
  4. Once done, we encapsulate the predictions into ResponseMessage instances.
  5. The messages are produced into the producer topic in the Kafka queue.
  6. Finally, Kafka asynchronously acknowledges the correct reception of the responses.

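A single polling cycle can be exercised locally by replacing the two Kafka topics with in-memory queues. This is a deliberately simplified sketch, not the Kafka client API: pollOnce, maxBatchSize and the flattened message types are hypothetical, and acknowledgement (step 6) is not modeled.

```scala
import scala.collection.mutable

object InMemoryBatchLoop {
  // Hypothetical message types keyed by a request identifier
  final case class RequestMessage(id: String, text: String)
  final case class ResponseMessage(id: String, label: String)

  // Drain up to maxBatchSize requests, predict, and publish keyed responses.
  // Returns the number of responses produced by this polling cycle.
  def pollOnce(
    requestTopic: mutable.Queue[RequestMessage],
    responseTopic: mutable.Queue[(String, ResponseMessage)],
    maxBatchSize: Int
  )(predict: Seq[RequestMessage] => Seq[ResponseMessage]): Int = {
    val batchSize = math.min(maxBatchSize, requestTopic.size)
    val batch = Seq.fill(batchSize)(requestTopic.dequeue())
    if (batch.isEmpty) 0
    else {
      val responses = predict(batch)
      // Key each response by its identifier before producing it
      responses.foreach(response => responseTopic.enqueue((response.id, response)))
      responses.size
    }
  }
}
```

Calling pollOnce repeatedly until it returns 0 mimics the while loop of executeBatch on a finite set of requests.
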
Next, we leverage Spark to distribute the batch of requests across multiple computation nodes (workers).

Apache Spark

Apache Spark is an open-source, distributed processing system for large-scale data sets. It uses in-memory caching and optimized query execution to support real-time analytics.

In our scenario, we use Spark to distribute a batch of requests consumed from Kafka across a network so that multiple BERT models execute concurrently. This design avoids a single point of failure (fault tolerance) and allows the use of generic, cost-effective hardware.

Leveraging Spark data sets and partitioning is surprisingly simple.

def predict(
  requests: Seq[Request]
)(implicit sparkSession: SparkSession): Seq[Prediction] = {
  import sparkSession.implicits._

  // 1 - Convert the requests into a Spark data set
  val requestDataset = requests.toDS()
  // 2 - Execute the prediction by invoking the DJL model on each request
  //     (the inner predict is the single-request method of the classifier)
  val responseDataset: Dataset[Prediction] = requestDataset.map(predict(_))
  // 3 - Collect the Spark data set of responses back to the driver
  responseDataset.collect().toSeq
}

  1. Once the Spark session (context) is initiated, the batch of requests is converted into a data set, requestDataset.
  2. Spark applies the prediction model (DJL) to each request in the partitioned data.
  3. Finally, the predictions are collected from the Spark worker nodes before being returned to the Kafka handler.

Note: The Spark context is assumed to be created and passed as an implicit parameter to the prediction method.
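
Because a deep learning model is expensive to load, one possible refinement is to create the classifier once per partition instead of once per request. The sketch below assumes the spark-sql library is on the classpath; StubClassifier is a hypothetical stand-in for the DJL-backed classifier, and the fields of Request and Prediction are invented for the example.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Hypothetical request/prediction types; the post does not show their fields
case class Request(id: String, text: String)
case class Prediction(id: String, label: String)

// Stand-in for the DJL-backed classifier: costly to create, cheap to call
class StubClassifier extends AutoCloseable {
  def predict(request: Request): Prediction =
    Prediction(request.id, if (request.text.length % 2 == 0) "even" else "odd")
  override def close(): Unit = ()
}

object PartitionedInference {
  def predict(requests: Seq[Request])(implicit spark: SparkSession): Seq[Prediction] = {
    import spark.implicits._
    val predictions: Dataset[Prediction] =
      requests.toDS().mapPartitions { (partition: Iterator[Request]) =>
        // Created on the worker, once per partition
        val classifier = new StubClassifier
        // Materialize the results before closing, since the iterator is lazy
        try partition.map(classifier.predict).toList.iterator
        finally classifier.close()
      }
    // Bring the predictions back to the driver
    predictions.collect().toSeq
  }
}
```

For a local trial, the implicit session can be created with SparkSession.builder().master("local[2]").appName("inference").getOrCreate() before calling PartitionedInference.predict.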


Deep Java Library

This is the last component, which links the incoming and outgoing data to the deep learning models. Deep Java Library (DJL) is an open-source Java framework that supports the most common deep learning frameworks: MXNet, PyTorch and TensorFlow.

DJL's ability to leverage any hardware configuration (CPU, GPU) and to integrate with big data frameworks makes it a strong candidate for a high-performance distributed inference engine. DJL can optionally be used for training too.

The input tensors are processed by the deep learning models on the GPU. The data therefore has to be allocated in the native memory address space, outside the JVM heap and out of reach of the garbage collector. DJL supports native tensors of type NDArray, lists of tensors of type NDList, and a simple memory manager, NDManager.

The classifier executes on the Spark worker. The following code snippet is an oversimplification of the steps to invoke a BERT-based classifier using DJL.

class BERTClassifier(
  minTermFrequency: Int,
  path: Path)(implicit sparkSession: SparkSession) {

  // 1 - Manage tensor allocation as NDArray in native memory
  val ndManager = NDManager.newBaseManager()
  // 2 - Define the configuration of the classifier
  //     (further options, such as the model location and block, are elided in the original)
  val classifyCriteria: Criteria[NDList, NDList] = Criteria.builder()
    .setTypes(classOf[NDList], classOf[NDList])
    .optProgress(new ProgressBar())
    .build()
  // 3 - Load the model from a local file
  val thisModel = classifyCriteria.loadModel()
  // 4 - Instantiate a new predictor
  val predictor = thisModel.newPredictor()

  // 5 - Execute this request on this worker node
  //     (simplified call: the conversion between Request and NDList is not shown)
  def predict(requests: Request): Prediction = {
    predictor.predict(ndManager, requests)
  }

  // 6 - Close resources (body elided in the original)
  def close(): Unit = {
    ...
  }
}

  1. Create the manager for tensors in native memory
  2. Configure the classifier with its related neural block (classificationBlock)
  3. Load the model (MXNet, PyTorch or TensorFlow) from a local file
  4. Instantiate a predictor from the model
  5. Submit the request to the deep learning model and return a prediction
  6. Close all the resources allocated in native memory at the end of the run
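
Since native memory is not reclaimed by the garbage collector, the release in step 6 has to be explicit. DJL's NDManager, model and predictor all implement AutoCloseable, so one way to guarantee their release is scala.util.Using (Scala 2.13 or later). The sketch below uses a hypothetical NativeResource class in place of the DJL objects and only records the order in which the resources are closed.

```scala
import scala.collection.mutable.ListBuffer
import scala.util.{Try, Using}

object NativeResourceScope {
  // Records the order in which resources are released
  val closed: ListBuffer[String] = ListBuffer.empty

  // Stand-in for a DJL resource backed by native memory
  final class NativeResource(val name: String) extends AutoCloseable {
    override def close(): Unit = closed += name
  }

  // Acquire manager, model and predictor, run one prediction, release everything
  def classifyOnce(text: String): Try[Int] =
    Using.Manager { use =>
      val manager = use(new NativeResource("ndManager"))
      val model = use(new NativeResource("model"))
      val predictor = use(new NativeResource("predictor"))
      text.length // placeholder for the actual call to the predictor
    }
}
```

Using.Manager closes the resources in the reverse order of their acquisition, even when the body throws. On a Spark worker, the scope would normally span a whole partition or the lifetime of the classifier rather than a single request.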


