Problem

There are two phases to applying deep learning to an ML problem. In the first phase, a neural network is created and trained using training data to generate a pre-trained model. In the second phase, this pre-trained model is put to work by running inference (forward pass) on new data in the customer’s application in production. Model creation and training is typically performed by data scientists, who prefer Python as their primary language because of its rich set of libraries (NumPy, pandas, Pillow, etc.) for setting up the training pipeline. MXNet already has very good support for Python for quickly prototyping and developing models.

Inference, on the other hand, is run and managed by software engineers in a production ecosystem that is built with tools and frameworks that use Java or Scala as the primary language.

Inference on a trained model has two different use-cases:

  1. Real time or Online Inference - tasks that require immediate feedback, such as fraud detection
  2. Batch or Offline Inference - tasks that don't require immediate feedback; use cases where you have massive amounts of data and want to run inference or pre-compute inference results


Batch Inference is performed on big data platforms such as Spark using Scala or Java, while Real time Inference is typically performed and deployed on popular web frameworks such as Tomcat, Netty, and Jetty, which use Java.

With this project, we want to build a new set of APIs that are Java friendly, compatible with Java 7+, easy to use for inference, and lower the entry barrier to consuming MXNet for production use cases.

Goals

As a user, I’d like to have a Java Inference API that allows me to use deep learning models from my existing Java application.

As a user, I'd like for the new Java Inference API to be thread safe.

As a user, I’d like for the new Java inference API to be idiomatic and easy to use so that I can quickly learn to deploy models.

As a user, I’d like for the new Java Inference API to introduce as few dependencies as possible so that it’s easy to add into my existing environment.

As a user, I’d like for the new Java API to have full test coverage so that I can be confident in its stability.

As a user, I’d like for the Java API to be as performant as possible with the performance results measured and available so that I can compare implementations and make informed decisions.

As a user, I would like for the new Java Inference API to support RNNs so that the models I’ve trained can be deployed to our production environments.

As a user already familiar with MXNet, I’d like for the new API to be similar to existing implementations so that it’s easy for me to use.

As a user, I'd like to have examples and tutorials available to help learn how to use the new Java Inference API.

Proposed Approach

The proposed implementation for the new Java API is to create a Java-friendly wrapper around the existing Scala API. The Scala API is already fully implemented and is undergoing significant improvements (most notably simplifying the management of off-heap memory). By utilizing the existing Scala API, the development effort required for the new Java API is greatly decreased. Additionally, the Java API would automatically (or with minimal work) benefit from new features and code improvements, allowing development efforts to remain focused. This is very similar to the approach Apache Spark took when developing its Java API.

Since both Java and Scala are JVM languages, it is already possible to call the Scala bindings from Java code by loading the jar into the classpath. Due to differences between the languages, however, this is currently very painful for users. Most notably, the difficulty comes from Java not supporting the default argument values used liberally in the Scala code, and from converting between Java and Scala collections. To improve this experience, a Java wrapper would be created that calls the Scala bindings. The wrapper would be designed to abstract away the complexities of the Java/Scala interaction by automating the conversions, simplifying the method calls, and making the API more idiomatic for the Java inferencing use case.
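To make the idea concrete, below is a minimal sketch of such a wrapper. The class and method names are illustrative only (not the final API), and it assumes the Scala Predictor's predictWithNDArray takes and returns an IndexedSeq of NDArrays.

import java.util.List;
import org.apache.mxnet.NDArray;
import scala.collection.JavaConverters;

// Illustrative sketch only: a Java-facing wrapper that holds the existing
// Scala Predictor and exposes plain Java collections to the caller.
public class JavaPredictor {
    private final org.apache.mxnet.infer.Predictor predictor;   // existing Scala class

    public JavaPredictor(org.apache.mxnet.infer.Predictor predictor) {
        this.predictor = predictor;
    }

    // Accepts and returns java.util.List; all Scala collection handling is hidden here.
    public List<NDArray> predictWithNDArray(List<NDArray> inputBatch) {
        scala.collection.IndexedSeq<NDArray> scalaInput =
            JavaConverters.asScalaBufferConverter(inputBatch).asScala().toIndexedSeq();
        return JavaConverters.seqAsJavaListConverter(
            predictor.predictWithNDArray(scalaInput)).asJava();
    }
}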

  1. Advantages
    • Fastest time to market, requiring the least amount of engineering effort.
    • Interaction with the native code is already done.
    • The Scala API is already designed and decided. Implementing a wrapper limits the design decisions that need to be made and keeps the APIs consistent.
    • Allows development to remain focused on a single JVM implementation which can be utilized by other JVM languages.
    • Implementation, adding new features, and maintenance would be greatly simplified.
    • The implementation is not one way. In the future we retain the ability to walk this decision back and go with another implementation.
  2. Disadvantages
    • Interaction with the Scala code could be complicated due to differences in the languages. Known issues are forming Scala collections and leveraging default values. Changes to the Scala API or the introduction of builder functions seem like the most obvious solutions here.
    • Some overhead in converting collections should be expected.
    • The JAR files will be larger than they would be without Scala in the middle. Theoretically, this could be an issue for some memory constrained edge devices.

Planned Release Milestones

Milestone 1: Initial release with support for all existing Scala Inference APIs. Includes integration into the existing CI, working examples, tutorials, documentation, benchmarking, and integration into the Maven distribution pipeline.

Milestone 2: General improvements to the Inference API, improved support for specific use cases, and sparse support (required for RNNs).

Milestone 3+: To be determined (ideas include autograd, exposing the Module API, and control flow support).

Known Difficulties

Converting Java collections into Scala collections - Scala and Java use different collections. Generally, these can be converted through the scala.collection.JavaConverters library. Ideally, this will be done automatically on behalf of the user. The Java methods should take Java collections, do the necessary conversion, then call the corresponding Scala method. 
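A minimal sketch of that conversion pattern is shown below; the converter names come from scala.collection.JavaConverters as it exists in Scala 2.11/2.12 (they differ slightly in later Scala versions).

import java.util.Arrays;
import java.util.List;
import scala.collection.JavaConverters;

public class CollectionConversionExample {
    public static void main(String[] args) {
        List<Float> javaInput = Arrays.asList(0.1f, 0.2f, 0.3f);

        // Java -> Scala: java.util.List to a scala.collection.Seq
        scala.collection.Seq<Float> scalaSeq =
            JavaConverters.asScalaBufferConverter(javaInput).asScala();

        // Scala -> Java: scala.collection.Seq back to a java.util.List
        List<Float> backToJava = JavaConverters.seqAsJavaListConverter(scalaSeq).asJava();
        System.out.println(backToJava);   // prints [0.1, 0.2, 0.3]
    }
}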

Java doesn’t support methods with default arguments - The current Scala implementation makes liberal use of default arguments. For class instantiation, a simple builder pattern will work. Class methods with default values will likely need to be overloaded.
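As an illustration, a builder along these lines could supply the defaults that Java callers cannot get from Scala's default arguments. This is a sketch with hypothetical names; Predictor, DataDesc, and Context refer to the classes described in this document.

import java.util.Collections;
import java.util.List;

// Hypothetical builder sketch: defaults that the Scala API supplies via default
// arguments are provided here instead, so Java callers only set what they need.
public class PredictorBuilder {
    private final String modelPathPrefix;
    private final List<DataDesc> inputDescriptors;
    private List<Context> contexts = null;   // default applied in build()
    private int epoch = 0;                   // assumed default epoch

    public PredictorBuilder(String modelPathPrefix, List<DataDesc> inputDescriptors) {
        this.modelPathPrefix = modelPathPrefix;
        this.inputDescriptors = inputDescriptors;
    }

    public PredictorBuilder setContexts(List<Context> contexts) {
        this.contexts = contexts;
        return this;
    }

    public PredictorBuilder setEpoch(int epoch) {
        this.epoch = epoch;
        return this;
    }

    public Predictor build() {
        List<Context> ctx = (contexts != null)
            ? contexts
            : Collections.singletonList(Context.cpu());   // mirror the Scala default (CPU)
        return new Predictor(modelPathPrefix, inputDescriptors, ctx, epoch);
    }
}

A Java caller could then write new PredictorBuilder(modelPathPrefix, inputDescriptors).build() and accept the defaults for contexts and epoch.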

Limited by existing Scala Inference API - The current Scala Inference API lacks support for some models such as RNNs. Since this API will be utilized by the new Java Inference API, it will be necessary to improve and expand the Scala Inference API. This work can be done in parallel and should undergo its own design process. On the plus side, this will serve as a forcing function to improve the Scala API.

Performance

Performance should be very similar to Scala. Since both are JVM languages, inference from Java will execute the same bytecode as it does from Scala. The only known issue that will cause a performance difference is converting Java collections into Scala collections. Preliminary testing with simple models shows negligible to nonexistent impact on performance. Java performance should be measured via benchmark scripts in a manner similar to how it's measured in Scala. More details on the Scala benchmarks are available here.
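For illustration, a minimal latency measurement could look like the sketch below. This is not the actual benchmark script; JavaPredictor refers to the hypothetical wrapper sketched earlier, and output NDArray disposal is omitted for brevity.

import java.util.List;
import org.apache.mxnet.NDArray;

public class SimpleLatencyBenchmark {
    // Returns the average single-inference latency in milliseconds.
    public static double averageLatencyMs(JavaPredictor predictor, List<NDArray> inputBatch,
                                          int warmupRuns, int timedRuns) {
        for (int i = 0; i < warmupRuns; i++) {
            predictor.predictWithNDArray(inputBatch);   // warm up the JIT and caches
        }
        long start = System.nanoTime();
        for (int i = 0; i < timedRuns; i++) {
            predictor.predictWithNDArray(inputBatch);
        }
        return (System.nanoTime() - start) / 1e6 / timedRuns;
    }
}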

Preliminary comparison results are a WIP and will be added soon.

Distribution

The new Java inference API can be distributed alongside the existing Scala API. Currently, the Scala API is distributed via a jar file using a Maven repository. There is ongoing work to automate this process and ideally this work will include the new Java API as well. The design for the Automated Scala Release is available here. Releases for the Java Inference API will be aligned with the MXNet release schedule and follow the same versioning.

Improving Scala Inference API

The existing Scala Inference API will need to be expanded and improved. These changes will need to undergo their own design process and can easily be incorporated into the new Java API. Although these improvements are not a requirement to begin working on the Java API, ideally they will be done in parallel so that the Java API will be more useful upon release.

Known improvements that could be made to the Scala API include:

  • Support for RNNs
  • Adding domain specific use cases
  • Improving interface of existing APIs (for example, it should be possible to do batch inference using just an NDArray)

Existing Scala Infer API Class Diagram

Sequence Diagram

Java Inference API Design for Predictor Class

The Java Inference API will be a wrapper around the high level Scala Inference interface. Here is an example of what the Java wrapper will look like for the Scala inference Predictor class.

Predictor
/**
 * Implementation of prediction routines.
 *
 * @param modelPathPrefix     Path prefix from where to load the model artifacts.
 *                            These include the symbol, parameters, and synset.txt
 *                            Example: file://model-dir/resnet-152 (containing
 *                            resnet-152-symbol.json, resnet-152-0000.params, and synset.txt).
 * @param inputDescriptors    Descriptors defining the input node names, shape,
 *                            layout and type parameters
 *                            <p>Note: If the input Descriptors is missing batchSize
 *                            ('N' in layout), a batchSize of 1 is assumed for the model.
 * @param contexts            Device contexts on which you want to run inference; defaults to CPU
 * @param epoch               Model epoch to load; defaults to 0
 */
Predictor(String modelPathPrefix, List<DataDesc> inputDescriptors,
                List<Context> contexts, int epoch)
/**
 * Predict using NDArray as input
 * This method is useful when the input is a batch of data
 * Note: User is responsible for managing allocation/deallocation of input/output NDArrays.
 *
 * @param inputBatch        List of NDArrays
 * @return                  Output of predictions as NDArrays
 */
List<NDArray> predictWithNDArray(List<NDArray> inputBatch)

Java Inference API usage

A primary goal of the Java Inference API is to provide a simple way for Java users to load and run inference on an existing model. Ideally, this will be as simple as defining the context (CPU vs. GPU) to be used, describing what the input will look like, and setting up the model. After setting up the model like this, it should be simple to run inference on it.

/*
 * Pseudocode for how the ObjectDetector class can be used to do SSD detection.
 * A full working SSD example will be included in the release.
 */

// Set the context to be used
List<Context> context = new ArrayList<Context>();
context.add(Context.cpu());

// Define the shape and data type of the input
Shape inputShape = new Shape(new int[] {1, 3, 512, 512});
List<DataDesc> inputDescriptors = new ArrayList<DataDesc>();
inputDescriptors.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));


// Instantiate the object detector with the model, input descriptors, context, and epoch
JavaObjectDetector objDetector = new JavaObjectDetector(modelPathPrefix, inputDescriptors, context, 0);


// Load an image and run inference on it
BufferedImage img = JavaImageClassifier.loadImageFromFile(inputImagePath);
objDetector.imageObjectDetect(img, 3);


Open Questions

How do we deal with Option[T] fields in Java when calling into Scala?

On Java side:

  • Option 1: Create a wrapper class in Java that allows users to use Scala Option fields smoothly, something like this (see the sketch after this list).

On Scala side:

  • Option 1: Use null in place of optional fields defined in Scala to match Java's needs
  • Option 2: Add overloaded methods in Scala that Java users can call without supplying the optional values
  • Option 3: Create a builder on the Scala side that allows Java users to bypass the optional fields
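For the Java-side wrapper idea (Option 1 above), a minimal sketch could look like the following; the class and method names are placeholders, not a proposed API.

import scala.Option;

// Hypothetical helper that hides scala.Option from Java callers.
public final class ScalaOptions {
    private ScalaOptions() {}

    // Wrap a possibly-null Java value into a scala.Option before calling into Scala.
    public static <T> Option<T> of(T value) {
        return Option.apply(value);   // Option.apply(null) yields None
    }

    // Unwrap a scala.Option returned from Scala, falling back to a default value.
    public static <T> T getOrElse(Option<T> option, T fallback) {
        return option.isDefined() ? option.get() : fallback;
    }
}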


SCALA/JAVA INTEGRATION TIP

Construct interfaces in Java that define all types that will be passed between Java and Scala. Place these interfaces into a project that can be shared between the Java portions of code and the Scala portions of code. By limiting the features used in the integration points, there won’t be any feature mismatch issues. (from "Scala in Depth", page 242)
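A small sketch of that tip, using hypothetical names: the types crossing the boundary are plain Java interfaces placed in a module that both the Java and the Scala code can see.

import java.util.List;

// Hypothetical shared interface; both the Java callers and the Scala
// implementation depend only on these plain Java types.
public interface ImageClassifierBridge {

    // Returns the top-k (label, probability) pairs for the given image bytes.
    List<Classification> classify(byte[] imageBytes, int topK);

    interface Classification {
        String label();
        float probability();
    }
}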

Possible Alternative Approaches

Writing a Java Inference API that directly calls the native code - This would mean designing and implementing a Java Inference API that interacts with the native code through the existing JNI code. The API would be designed to make Java inferencing simple and idiomatic. The existing JNI code could be shared by both the existing Scala API and the new Java Inference API. The biggest drawback to this approach is that it involves a significant amount of duplicate work that would be very difficult to maintain with current resources.
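For illustration only, such an API would declare native methods in Java that are implemented in the C/C++ layer. The method names and library name below are placeholders and do not correspond to the actual MXNet JNI code.

// Hypothetical sketch of a direct JNI binding.
public class NativeInference {
    static {
        System.loadLibrary("mxnet-java-jni");   // assumed native library name
    }

    // Load a model from a path prefix and return an opaque native handle.
    public static native long loadModel(String modelPathPrefix, int epoch);

    // Run a forward pass on a flattened float input; shape handling is omitted here.
    public static native float[] predict(long modelHandle, float[] input);

    // Release the native resources associated with the handle.
    public static native void freeModel(long modelHandle);
}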

  1. Advantages
    • No overhead from converting collections.
    • No surprises from interacting with the Scala code.
    • Would likely be the faster implementation, since benchmarks generally show that Java outperforms Scala.

  2. Disadvantages
    • Duplication of efforts between this and the Scala API (this means reimplementing executor, NDArray, module, etc., which is a significant effort).
    • Will have to reimplement off-heap memory management.
    • Added design effort to decide the Java API.

Adopt Java as the primary JVM language - This approach would be to spend a very significant effort rewriting the entire Scala API in Java. After that is done, support for other JVM languages could be added using Java as a base, and eventually the current Scala API could be deprecated. Obviously this involves a very significant upfront effort. Long term, it would be reasonable to expect improved performance across all JVM languages (since benchmarks generally show Java to outperform Scala), and it would likely be easier to add support for other JVM languages. However, the performance gains would likely be largely offset by the fact that most of the workload is done in the C++ native code and not in the JVM.

  1. Advantages
    • Likely to see better performance across all JVM languages.
    • Easier to add support for more JVM languages in the future.
  2. Disadvantages
    • Tremendous amount of upfront effort.
    • The Scala API already exists and has been well received.
    • Scala is a popular data science language.
    • Apache Spark is a Scala-first project and a popular analytics engine.


15 Comments

  1. In general, I am hesitant to create Yet Another MXNet Language binding. I recommend and prefer to reuse the existing Scala work and write Java wrappers to make it easy for Java developers

    • Write JNI code and call MXNet backend in one place
    • Write code that generates all the MXNet APIs in Java; we did this using Scala macros, and we would now have to do it using AspectJ or similar.
    • Off-heap memory management in one place – this is already a problem in the Scala API and we don't want to solve it in 2 different places.
    • MXNet Interaction in one place
    • Also dilutes the community focus and adds more maintenance to support 2 different APIs.

    That being said, if we truly have to write all-new Java Inference APIs, I would like to decouple the MXNet internals and dependencies from the new APIs so we don't carry and cater to the design limitations of MXNet (in Scala we are already having issues generating 300+ APIs with the insane number of parameters that each API takes). I tried to do that in the Scala Inference APIs, but I am sure we can do better.


  2. Just FYI, we can create a 'java' folder in the current Scala package and write the Java APIs there (if we do need to). You can take a look at how Spark does that. Spark is a Scala project but has pure-Java APIs, which are built right inside the Scala project.

  3. String modelSource
    Does this handle the case where we have multiple files? Most MXNet models contain a .params and a .json file
  4. I support the 1st approach rather than writing Java bindings. There is value in fixing whatever issues we have with the Scala APIs - it will serve Scala users as well as new users who want to use the Java APIs.

    So, rather than duplicating efforts, I feel focusing on the Scala APIs and making them better is a good option.

  5. void predict(Collection<E> modelInput)
    should also support structs stored in native memory – if the user wants to use the same input for multiple operations (for example, different models); the same applies to the output
  6. What are limited dependencies?

    Is it possible to rethink the interface to the C++ engine so that it is common to both the Java and Scala inference solutions?

    Were any other options considered as an alternative to JNI?

    Should inference models be more specific to applications? CNN/RNN is too generic.

  7. Why make totally new APIs rather than reusing Module?

    How to use NDArray?

  8. Please mention in the document what version of Java you plan to support

  9. In Problems - we should add abstracting by domain as one of the goals

  10. Data structures are not clear yet. Can we design the interface and module to be more generic so that we can have different backend engines? E.g. TVM / Module / the current Scala engine / a small C++ binary for inference on Android or other small devices.

    This will help us keep the top API layer consistent, application specific, and decoupled from core DL concepts.

  11. getPredictions?
    Can the user not wrap it in a future instead of calling a separate API?
  12. Predictor(String modelPathPrefix, List<DataDesc> inputDescriptors, List<Context> contexts, int epoch)
    List<List<Float>> predict(List<List<Float>> input) – takes one-dimensional arrays as input, creates the NDArrays needed for inference, and reshapes them based on the input descriptors (a List is needed when the model has more than one input)
    List<NDArray> predictWithNDArray(List<NDArray> inputBatch)
    How do you plan to expose NDArray and the related operators?
  13. List<List<Float>> predict(List<List<Float>> input)
    Is this a blocking call?
  14. epoch
    Since these are Inference APIs, we should probably move away from using epoch numbers and instead just use a "param file name". Not sure if this is intuitive outside of training.