Problem Statement

MXNet provides a Scala interface for training and inference. However, there are no separate APIs for inference, and the existing APIs are not user-friendly for the inference use case. There are also correctness problems, reported by an expert Scala user, Calum, in [Thread Safety issues in MXNet Scala], some of which have been fixed by the user and contributed back to MXNet. This user overcame these difficulties only through significant engineering effort, writing custom code to make MXNet work for them.

With regards to usability, the code example in Appendix 1 shows the number of API calls required to run inference on a single input. It requires the user to understand many MXNet Scala API modules such as the Module API, DataIterator, and DataBatch. Many of these APIs are designed mainly for training and take too many arguments; some are not Scala-idiomatic and have insufficient documentation.

With regards to thread safety, as the user has noted, the problem is not only that a particular model cannot be called from multiple threads in Scala: users have to make all calls to MXNet on a single thread for the entire lifetime of the process. This is a very rigid restriction; it requires users to know this detail and to build workarounds for it.

All of the above increases the friction and difficulty of using the MXNet Scala API for inference.


Goals

1. A set of new Scala Inference APIs.
2. Idiomatic and easy to use.
3. Full test coverage.
4. Thread safe.
5. Performance: at least as performant as the Python API and the old Scala API.
6. Memory management: efficiently allocates and releases memory without incurring any leaks.

Appendix 2 shows an example of how the new APIs can be used.

(I recommend implementing the Inference APIs in Scala because Scala is a typed language that is widely used on Apache Spark for distributed data processing; these APIs will add deep learning capabilities to the Spark ecosystem and help capture those users. MXNet also already provides an exhaustive set of Scala APIs that can be reused for this implementation.)


Use Cases

1. Load a pre-trained model and run single-input inference on CPU.
2. Load a pre-trained model and run inference on a batch of inputs on CPU.

API Spec

The APIs are designed based on object-oriented concepts and, at a high level, look like the following.

DataDescriptor describing the input and output of the model

DataDesc is an MXNet class defined in the MXNet package. We will use this class to define the input and output nodes.

package org.apache.mxnet
// A named data descriptor containing name, shape, type, layout and other extended attributes.
class DataDesc(name: String, shape: Shape, dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW")
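To make the relationship between a descriptor's shape and the flat, row-major input array concrete, here is a small sketch. `ShapeUtil` is a hypothetical helper, not part of MXNet; it only illustrates how a shape such as (1, 3, 224, 224) determines element count, strides and flat indices:

```scala
// Hypothetical helper (not part of MXNet): how a flat row-major Array[Float]
// relates to a DataDesc-style shape such as (1, 3, 224, 224) with layout "NCHW".
object ShapeUtil {
  // Total number of elements implied by a shape.
  def elementCount(shape: Seq[Int]): Int = shape.product

  // Row-major strides: the step in the flat array for each axis.
  def strides(shape: Seq[Int]): Seq[Int] =
    shape.scanRight(1)(_ * _).tail

  // Flat index of a multi-dimensional coordinate in row-major order.
  def flatIndex(shape: Seq[Int], coord: Seq[Int]): Int =
    strides(shape).zip(coord).map { case (s, c) => s * c }.sum
}
```

For example, `ShapeUtil.elementCount(Seq(1, 3, 224, 224))` gives the length the flat input array must have (150528 floats).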

1. The MXInferBase trait defines the routines needed for prediction and classification.


/**
 * Base trait for MXNet inference classes.
 */
private[infer] trait MXInferBase {

  /**
   * Takes input as a Traversable of one-dimensional arrays and creates NDArrays
   * based on the shape and layout of the input descriptors.
   * @param input Input as a Traversable of arrays. A Traversable is needed so that
   *              the input can cover more than one node in the graph.
   * @return Traversable of arrays of numbers, reshaped in row-major order from the
   *         output NDArrays.
   */
  def predict[T <: Number](input: Traversable[Array[T]]): Traversable[Array[T]]

NDArray objects allocated off heap (for a particular model) in this API can be maintained in an object pool and reused, since they will have the same shape across calls.

  /**
   * This method is useful when the input is a batch of data or when multiple
   * operations have to be performed on the input/output.
   * @param input Traversable of MXNet NDArrays.
   * @return Traversable of NDArrays.
   */
  def predict(input: Traversable[NDArray]): Traversable[NDArray]
}
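The pooling idea mentioned above can be sketched generically. `FixedShapePool` is a hypothetical name; the point is only that, because inference inputs for a given model always have the same shape, allocated buffers can be recycled rather than re-created per call:

```scala
import java.util.concurrent.ConcurrentLinkedQueue

// Sketch of the object-pool idea (hypothetical, not MXNet code): recycle
// fixed-shape buffers across predict calls instead of re-allocating them.
class FixedShapePool[T](create: () => T) {
  private val pool = new ConcurrentLinkedQueue[T]()

  // Take a pooled instance, or create a fresh one if the pool is empty.
  def acquire(): T = Option(pool.poll()).getOrElse(create())

  // Return an instance to the pool for reuse by a later predict call.
  def release(obj: T): Unit = pool.offer(obj)
}
```

In the real API, the pooled objects would be the off-heap NDArrays of the model's input shape.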

3. An abstract class that initializes the Module object and implements the predict routines using MXNet.

The predict routines will perform a runtime check of the input type against the DType defined in the input descriptor and throw a runtime exception if there is a mismatch.

There might be a way to perform this check at compile time (Scala macros, etc.); this needs to be researched.
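As a sketch of the runtime check described above (the helper `DTypeCheck` and its dtype-name mapping are assumptions, not MXNet code), the element class of the caller's array can be compared against the descriptor's declared dtype:

```scala
import scala.reflect.ClassTag

// Hypothetical helper: compare the JVM element type of the caller's input
// with the dtype declared in the input descriptor and fail fast on mismatch.
object DTypeCheck {
  // Map boxed JVM number classes to the dtype names they correspond to.
  private val jvmToDType: Map[Class[_], String] = Map(
    classOf[java.lang.Float]   -> "float32",
    classOf[java.lang.Double]  -> "float64",
    classOf[java.lang.Integer] -> "int32"
  )

  def validate[T](descriptorDType: String)(implicit tag: ClassTag[T]): Unit = {
    val actual = jvmToDType.getOrElse(tag.runtimeClass, tag.runtimeClass.getName)
    if (actual != descriptorDType)
      throw new IllegalArgumentException(
        s"Input element type $actual does not match descriptor dtype $descriptorDType")
  }
}
```

A ClassTag-based check like this still happens at runtime; moving it to compile time is what the macro research would address.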

/**
 * The abstract class that all predictors and classifiers will specialize.
 * @param model_prefix Prefix from where to load the model. This will support loading from
 *                     file://, https://, hdfs:// and s3://.
 *                     Example: file://model-dir/resnet-152 (contains resnet-152-symbol.json,
 *                     resnet-152-XXXX.params and optionally synset.txt).
 * @param input_descriptors Descriptors defining the input node names, shape, layout and type parameters.
 * @param output_descriptors Descriptors defining the output node names, shape, layout and type parameters.
 */
abstract class MXNetInference(model_prefix: String, input_descriptors: Traversable[DataDesc],
                              output_descriptors: Traversable[DataDesc]) extends MXInferBase {
  // Initialize the Module as in Appendix 1, associate a thread from the handler
  // and use that thread for every operation.

  def predict[T <: Number](input: Traversable[Array[T]]): Traversable[Array[T]] = { ... }

  def predict(input: Traversable[NDArray]): Traversable[NDArray] = { ... }
}

4. A generic Classifier that also returns the labels associated with the results.

/**
 * Class for classification tasks.
 * @param model_prefix Prefix from where to load the model.
 * @param input_desc Descriptors defining the input node names, shape, layout and type parameters.
 * @param output_desc Descriptors defining the output node names, shape, layout and type parameters.
 */
class MXNetClassifier(model_prefix: String, input_desc: Traversable[DataDesc],
                      output_desc: Traversable[DataDesc])
  extends MXNetInference(model_prefix, input_desc, output_desc) {

  /**
   * Takes a Traversable of arrays of numbers and returns the corresponding labels.
   * @param input Traversable of arrays of numbers.
   * @param top_k (Optional) How many top elements to return; if not passed, returns unsorted output.
   * @tparam T Number
   * @return Traversable of sequences of (Label, Score) tuples.
   */
  def classify[T <: Number](input: Traversable[Array[T]], top_k: Option[Int] = None): Traversable[Seq[(String, T)]] = { ... }

  /**
   * Takes a batch of inputs as NDArrays and returns (Label, Score) tuples.
   * @param input Traversable of MXNet NDArrays.
   * @param top_k (Optional) How many top elements to return; if not passed, returns unsorted output.
   * @return Traversable of sequences of (Label, Score) tuples; the score is returned as an NDArray.
   */
  def classify(input: Traversable[NDArray], top_k: Option[Int] = None): Traversable[Seq[(String, NDArray)]] = { ... }
}
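The classify contract above (top-k pairs when `top_k` is passed, unsorted pairs otherwise) can be sketched with a small hypothetical helper that pairs synset labels with raw scores; `TopK` and its signature are assumptions for illustration:

```scala
// Hypothetical helper: combine a model's raw scores with synset labels
// and apply the optional top_k parameter as described above.
object TopK {
  def labelScores(labels: IndexedSeq[String],
                  scores: Array[Float],
                  topK: Option[Int]): Seq[(String, Float)] = {
    val paired = labels.zip(scores)
    topK match {
      // With top_k, return the k highest-scoring (label, score) pairs.
      case Some(k) => paired.sortBy(-_._2).take(k)
      // Without top_k, return the pairs unsorted, as the spec describes.
      case None    => paired
    }
  }
}
```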

5. Predictor - a concrete subclass of MXNetInference for regression use cases.

class MXNetPredictor(model_prefix: String, input_desc: Traversable[DataDesc],
                     output_desc: Traversable[DataDesc])
  extends MXNetInference(model_prefix, input_desc, output_desc) { ... }

6. ImageClassifier - Specializes Classifier.

// TODO: consider providing preprocess, postprocess and common image operations.

class MXNetImageClassifier(model_prefix: String, input_desc: Traversable[DataDesc],
                           output_desc: Traversable[DataDesc])
  extends MXNetClassifier(model_prefix, input_desc, output_desc) {

  /**
   * Image classifier that takes an image and returns a sequence of (class, score) tuples.
   * @param input Java image (java.awt.Image).
   * @param top_k (Optional) How many top elements to return; if not passed, returns unsorted output.
   * @return Traversable of sequences of class and score (represented as Double).
   */
  def classify(input: Image, top_k: Option[Int] = None): Traversable[Seq[(String, Double)]] = { ... }

  /**
   * Takes a batch of images and returns sequences of (class, score) tuples.
   * @param input_batch A batch of Java images (java.awt.Image).
   * @return Traversable of sequences of class and score (represented as Double).
   */
  def classify(input_batch: Traversable[Image]): Traversable[Seq[(String, Double)]] = { ... }
}
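As a sketch of the preprocessing the image classifier would need (the helper `ImagePreprocess` is an assumption, not a proposed API), a java.awt BufferedImage can be flattened into the NCHW Float array the lower-level predict call expects:

```scala
import java.awt.image.BufferedImage

// Hypothetical preprocessing: flatten a BufferedImage into a channel-major
// (NCHW) Array[Float] of raw 0-255 channel values.
object ImagePreprocess {
  def toNCHW(img: BufferedImage): Array[Float] = {
    val (h, w) = (img.getHeight, img.getWidth)
    val out = new Array[Float](3 * h * w)
    for (y <- 0 until h; x <- 0 until w) {
      val rgb = img.getRGB(x, y)
      // Channel-major layout: all R values first, then G, then B.
      out(0 * h * w + y * w + x) = ((rgb >> 16) & 0xFF).toFloat
      out(1 * h * w + y * w + x) = ((rgb >> 8) & 0xFF).toFloat
      out(2 * h * w + y * w + x) = (rgb & 0xFF).toFloat
    }
    out
  }
}
```

A real implementation would also resize the image to the descriptor's shape and possibly apply mean subtraction; those steps are omitted here.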



7. ObjectDetector - Specializes Predictor

class MXNetImageDetector(model_prefix: String, input_desc: Traversable[DataDesc],
                         output_desc: Traversable[DataDesc])
  extends MXNetPredictor(model_prefix, input_desc, output_desc) { ... }

8. Specialized Inference APIs for RNN.

To add value for RNN users, we need to support the Bucketing Module, which is commonly used when the input is of variable length. RNN users also need support for mapping input to embeddings. Embeddings contain a forward and backward dictionary and projections of input characters/words into a vector space; these parameters are learned in a prior training phase (e.g., Word2Vec). MXNet embeddings do not yet have a standard format (JSON/binary). I will consider this after the first iteration.

Approaches to handling Thread Safety


1. The MXNet backend should provide thread safety at least for a single model, or allow the client to instantiate the same model across threads while letting the Modules (executors) share the same set of weights; in the case of inference these weights are not changed. The temporary workspace used by the operators can be managed efficiently with memory pools, which works well because the shapes of these objects do not change during inference.
2. The client (users of the Inference APIs) can also fork an MXNet process per model and manage a process pool; however, the client then has to manage inter-process communication between its own process and the MXNet processes. Although we hope that fork will copy-on-write, this is not guaranteed when the forked process's memory is not contiguous, which leads to cloning of the parameters and graph and can explode memory usage.
3. TBD - Test using MXNet naive threads.
4. TBD - Better ways to handle this.
5. The initial approach will be to use a singleton dispatcher thread to handle all calls to MXNet, as suggested by Calum. This approach has some overhead and does not guarantee latency for the APIs.

Create a singleton MXNetHandlerFactory object that encapsulates interaction with MXNet.
This is a private object and is not exposed to the user.
The initial implementation will use the dispatcher pattern of having a single thread run all MXNet-related tasks.

a. The key will be a Module object, which maps to the thread associated with it.

private[infer] object MXNetThreadFactory {
  private val threadMap: Map[Module, Thread] = ListMap.empty

  /**
   * Returns the thread associated with the given symbol file; the corresponding
   * Module should already be initialized.
   */
  def getMXNetThread(symbol_file: File): Thread = {
    // return the thread associated with this symbol
  }
}
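The dispatcher pattern itself can be sketched independently of MXNet. The block below is a simplification of the factory above (one global thread rather than one per Module, and the name `MXNetDispatcher` is hypothetical): all calls are funneled through a single-threaded executor, so the engine only ever sees one calling thread:

```scala
import java.util.concurrent.{Callable, Executors, TimeUnit}

// Sketch of the singleton dispatcher: every MXNet call is submitted to a
// single-threaded executor, satisfying the one-thread-per-process restriction.
// (In the real code this would be private[infer].)
object MXNetDispatcher {
  private val executor = Executors.newSingleThreadExecutor()

  // Run `task` on the dedicated MXNet thread and block for its result.
  def submit[T](task: () => T): T =
    executor.submit(new Callable[T] {
      override def call(): T = task()
    }).get()

  def shutdown(): Unit = {
    executor.shutdown()
    executor.awaitTermination(5, TimeUnit.SECONDS)
  }
}
```

The blocking `get()` is where the latency overhead mentioned in approach 5 comes from: callers on many threads serialize behind one worker.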


Appendix 1. Inference using current Scala APIs

val dataDesc = DataDesc("data", Shape(1, 3, 224, 224))
// 1. Load the model using the Module API
val mod = Module.loadCheckpoint(prefix, 0)

// 2. Bind the input shape to the module
mod.bind(IndexedSeq(dataDesc), forTraining = false)
// 3. The user has data in a Scala primitive such as a Scala Array
// 4. Convert it to an NDArray
// 5. Create a DataIterator
// 6. Call predict and convert the result to a Scala Array

def predict(input: Array[Float]): Array[Float] = {
  val shape = Shape(1, 3, 224, 224)
  val a = NDArray.array(input, shape)
  val iter = new NDArrayIter(IndexedSeq(a))
  val result = mod.predict(iter)
  result.head.toArray
}

Appendix 2. New Inference APIs

// define input and output descriptors
val input_descriptor = DataDesc("data", Shape(1, 3, 224, 224))
val output_descriptor = DataDesc("softmax", Shape(1, 1000))
// create an inference object
val infer_object = new MXNetImageClassifier(model_prefix, Seq(input_descriptor), Seq(output_descriptor))
val input_image = new BufferedImage(..)
// run the classification inference task
val result = infer_object.classify(input_image, Some(5))
