This document reviews the detailed design of the Keras-MXNet library in its current state. By the end of this document, we aim to give a clear understanding of the design of Keras, the Keras-MXNet integration design, the design differences between Keras and MXNet, and the technical challenges in MXNet that block Keras-MXNet from being fully native to the Keras core engine and from being merged into keras-team/keras.

Note: This document does not propose solutions to be implemented in MXNet for overcoming the design differences with the Keras engine. Instead, it aims to lay the foundation for a clear understanding of the Keras library and of the detailed technical challenges. In the next part of the document, we will propose a path towards merging keras-mxnet into keras-team/keras.

Keras-MXNet Class Diagram

  1. A “Model” in Keras “is a Container”, “composed of Layers” that are connected to each other via “Nodes”.
  2. A “specialized Model” consisting of a linear stack of Layers is called a “Sequential Model”.
  3. A “Model is associated” with “Initializers” for initializing weights, biases and other tensors.
  4. A “Model is associated” with “losses”, “metrics”, “optimizers” and “callbacks” for managing the model training and inference workflows.
  5. “MXNet Model is a specialized Model” for managing the training and inference workflows for the MXNet backend. This specialization is necessary because MXNet does not support shared tensors for in-place updates, symbolic optimizers, or a low-level functional API for graph execution. More about this in the Keras and MXNet Design Differences section later in the document.
  6. “MXNet Optimizer is a specialized Optimizer” for exposing a Keras optimizer as an MXNet optimizer. This specialization is necessary because MXNet does not support symbolic optimizers or a low-level functional API for graph execution.
  7. “MXNet Model is associated” with MXNet's “BucketingModule”, which manages the training and inference workflows for the MXNet backend through the Module forward, backward and update APIs.
  8. “Keras Symbol” is a tensor data structure used in Keras-MXNet that “is composed” of tensor data (NDArray) and a symbol (Symbol). This data structure was introduced for compatibility with the Keras tensor concept, which includes both the data and the symbol.
  9. “MXNet Backend” is a collection of Keras low-level operator implementations that use the MXNet symbolic and NDArray APIs.

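Let us look at a simple implementation of the Keras sigmoid operator in the MXNet backend, using the MXNet Symbolic API.

import mxnet as mx

def sigmoid(x):
    """Element-wise sigmoid.

    # Arguments
        x: A tensor or variable (a KerasSymbol).

    # Returns
        A tensor (a new KerasSymbol).
    """
    # KerasSymbol is defined in the MXNet backend module itself.
    return KerasSymbol(mx.sym.Activation(data=x.symbol, act_type='sigmoid'))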

  10. All “MXNet Core Components” use the low-level operator implementations in the “Backend”.
Let us look at a simple Dense layer implementation in Keras core that uses low-level operators implemented in the backend.
Recall that the forward propagation of a Dense layer is:
Output = activation(X.W + b)

from keras import backend as K  # K => the active backend (MXNet here)
from keras.engine import Layer


class Dense(Layer):
    # Excerpt: __init__() and build(), which create self.kernel, self.bias,
    # self.use_bias and self.activation, are omitted.

    def call(self, inputs):
        # X.W (the kernel has shape (input_dim, units))
        output = K.dot(inputs, self.kernel)

        # X.W + b
        if self.use_bias:
            output = K.bias_add(output, self.bias)

        # activation(X.W + b)
        if self.activation is not None:
            output = self.activation(output)
        return output


# Sigmoid activation implementation (keras/activations.py):
# it simply delegates to the backend operator.
def sigmoid(x):
    return K.sigmoid(x)
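To make the Dense formula concrete, here is a dependency-free sketch of the same forward pass for a single sample, with plain Python lists in place of backend tensors. The helper names (`dense_forward`, `sigmoid`) and the numbers are illustrative only; the real layer calls `K.dot`, `K.bias_add` and the activation on backend tensors.

```python
import math


def sigmoid(z):
    """Scalar logistic function 1 / (1 + e^-z)."""
    return 1.0 / (1.0 + math.exp(-z))


def dense_forward(x, kernel, bias, activation=sigmoid):
    """Compute activation(x . kernel + bias) for one sample.

    x:      list of input_dim floats
    kernel: input_dim x units nested list (Keras kernel layout)
    bias:   list of units floats
    """
    out = []
    for j in range(len(bias)):
        # x . W for output unit j
        z = sum(x[i] * kernel[i][j] for i in range(len(x)))
        # + b
        z += bias[j]
        # activation(...)
        out.append(activation(z) if activation is not None else z)
    return out


# 2 inputs -> 2 units
x = [1.0, 2.0]
kernel = [[0.5, -1.0],
          [0.25, 0.5]]
bias = [0.0, 0.0]
print(dense_forward(x, kernel, bias))  # approximately [0.731, 0.5]
```

The pre-activation values here are 1.0 and 0.0, so the outputs are sigmoid(1.0) and sigmoid(0.0).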

Keras and MXNet Design Differences

  1. Symbolic Gradient Operator: given a graph, return a graph that computes the gradients (the graph for the backward pass). MXNet currently does not support this as an operator.
    1. In Keras, the core engine uses this operator to prepare the graph that computes the gradients (backward pass). This graph is then concatenated with the symbolic graph of the forward pass, so the Keras engine ends up with one symbolic graph that represents both the forward and the backward pass.
  2. Symbolic Optimizer: MXNet (KVStore) does not support symbolic optimizers. KVStore uses imperative optimizer logic for gradient updates.
    1. In Keras, all optimizers are symbolic, i.e., an optimizer prepares a symbolic graph that updates the weight symbols using the gradient symbols and the optimizer logic. This optimizer graph is then concatenated with the already combined graph, giving one graph = network symbolic graph + gradient symbolic graph + optimizer weight-update symbolic graph.
  3. Shared Tensors (Variables) / In-place update ops: MXNet does not support shared symbols that point to the same data (NDArray), i.e., two symbols pointing to the same NDArray that can be modified from two different symbolic graphs.
    1. In Keras, weights are symbolic tensors that are used in the network graph (forward pass) and later in the optimizer graph to update the weights with the gradients.
  4. Low-Level Functional API for graph execution: MXNet Module is a high-level construct capable of handling neural network training, evaluation and inference. It is composed of an Optimizer, a KVStore and more. However, MXNet does not provide a low-level graph execution interface.
    1. In Keras, the core engine prepares one symbolic graph (computation_graph + gradient_computation + optimizer_updates). This graph is passed to the low-level functional API and executed for each input_batch.
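The idea behind a symbolic gradient operator (in the spirit of Theano's `grad` or TensorFlow's `tf.gradients`) can be shown with a toy, dependency-free sketch: expressions are graph nodes, and `grad` returns a new graph rather than numbers, so the forward graph, the gradient graph and an optimizer update graph can be stacked and executed together. All classes and functions here are hypothetical, and this is plain rule-based symbolic differentiation, not how Theano or TensorFlow implement it internally.

```python
class Expr:
    """Base class for graph nodes; + and * build new nodes, not numbers."""

    def __add__(self, other):
        return Add(self, other)

    def __mul__(self, other):
        return Mul(self, other)


class Var(Expr):
    def __init__(self, name):
        self.name = name


class Const(Expr):
    def __init__(self, value):
        self.value = value


class Add(Expr):
    def __init__(self, a, b):
        self.a, self.b = a, b


class Mul(Expr):
    def __init__(self, a, b):
        self.a, self.b = a, b


def grad(expr, wrt):
    """Return a NEW graph that computes d(expr)/d(wrt)."""
    if isinstance(expr, Var):
        return Const(1.0 if expr is wrt else 0.0)
    if isinstance(expr, Const):
        return Const(0.0)
    if isinstance(expr, Add):
        return grad(expr.a, wrt) + grad(expr.b, wrt)
    if isinstance(expr, Mul):
        # product rule: (a * b)' = a' * b + a * b'
        return grad(expr.a, wrt) * expr.b + expr.a * grad(expr.b, wrt)
    raise TypeError(f"unsupported node: {expr!r}")


def evaluate(expr, env):
    """Execute a graph, given values for its Var leaves."""
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, Const):
        return expr.value
    if isinstance(expr, Add):
        return evaluate(expr.a, env) + evaluate(expr.b, env)
    if isinstance(expr, Mul):
        return evaluate(expr.a, env) * evaluate(expr.b, env)
    raise TypeError(f"unsupported node: {expr!r}")


# Forward graph: loss = w * x + b
w, x, b = Var("w"), Var("x"), Var("b")
loss = w * x + b

# Backward graph, built symbolically from the forward graph
dloss_dw = grad(loss, w)

# Optimizer graph stacked on top: w_new = w - lr * dloss/dw
lr = Const(0.5)
w_new = w + Const(-1.0) * lr * dloss_dw

env = {"w": 3.0, "x": 2.0, "b": 1.0}
print(evaluate(loss, env), evaluate(dloss_dw, env), evaluate(w_new, env))
# 7.0 2.0 2.0
```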

Keras-MXNet Specific Classes

Below, we zoom into specialized data structures introduced in the MXNet Backend.

NOTE: Not all components, attributes and operations are represented in the class diagram above; only a few critical components are shown to illustrate the design. See the Appendix below for a more detailed Keras Layers design.

In Keras, a tensor represents a symbol together with its data. In MXNet, a Symbol is only a symbol and has no data (NDArray) attached, so we create a new data structure, KerasSymbol, that is composed of both the data (NDArray) and the symbol (Symbol).

KerasSymbol is a node in the symbolic graph of Keras-MXNet. It maintains a list of neighboring symbols in the computation graph, so starting from an input KerasSymbol you can traverse the complete computation graph. As described earlier, KerasSymbol also contains data, i.e., it provides operations for binding an NDArray value to the Symbol.
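A minimal sketch of such a wrapper is shown below. It is not the Keras-MXNet source: `ToySymbol` is a hypothetical stand-in for `mx.sym.Symbol` so the example runs without MXNet, a plain list stands in for an NDArray, and the method names (`bind`, `add_neighbor`, `traverse`) are illustrative assumptions.

```python
class ToySymbol:
    """Hypothetical stand-in for mx.sym.Symbol: a graph node with no data."""

    def __init__(self, name):
        self.name = name


class KerasSymbol:
    """Sketch: a symbol plus optional bound data plus graph neighbors."""

    def __init__(self, symbol, data=None):
        self.symbol = symbol      # graph node (mx.sym.Symbol in MXNet)
        self.data = data          # bound value (mx.nd.NDArray in MXNet)
        self.neighbors = []       # downstream KerasSymbols in the graph

    def bind(self, data):
        """Attach a concrete value to this symbol and return self."""
        self.data = data
        return self

    def add_neighbor(self, other):
        self.neighbors.append(other)

    def traverse(self):
        """Yield every KerasSymbol reachable from this one, depth-first."""
        seen, stack = set(), [self]
        while stack:
            node = stack.pop()
            if id(node) in seen:
                continue
            seen.add(id(node))
            yield node
            # reversed() keeps neighbors in insertion order when popped
            stack.extend(reversed(node.neighbors))


# input -> dense -> sigmoid
x = KerasSymbol(ToySymbol("x")).bind([1.0, 2.0])
h = KerasSymbol(ToySymbol("dense"))
y = KerasSymbol(ToySymbol("sigmoid"))
x.add_neighbor(h)
h.add_neighbor(y)
print([node.symbol.name for node in x.traverse()])  # ['x', 'dense', 'sigmoid']
```

Keeping the symbol and the data in one object is what lets backend operators accept and return a single "tensor" value, as the Keras backend contract expects.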

Keras-MXNet Optimizer

Keras-MXNet Optimizer extends both the Keras optimizer and the MXNet optimizer, as shown in the class diagram. This specialization is necessary because MXNet does not support symbolic optimizers or in-place updates. The optimizer is then attached to the Module in the MXNet Model.
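The dual inheritance can be sketched as follows. To keep the example self-contained, `KerasOptimizerBase` and `MXNetOptimizerBase` are hypothetical stubs standing in for `keras.optimizers.Optimizer` and `mx.optimizer.Optimizer`; the `create_state(index, weight)` and `update(index, weight, grad, state)` methods mirror the shape of MXNet's imperative optimizer interface, and plain SGD is used only as the simplest example of imperative update logic.

```python
class KerasOptimizerBase:
    """Stub for keras.optimizers.Optimizer (Keras-facing config API)."""

    def get_config(self):
        return {}


class MXNetOptimizerBase:
    """Stub for mx.optimizer.Optimizer (imperative update API)."""

    def create_state(self, index, weight):
        return None

    def update(self, index, weight, grad, state):
        raise NotImplementedError


class MXNetSGD(KerasOptimizerBase, MXNetOptimizerBase):
    """Keras-facing config, MXNet-style imperative in-place update."""

    def __init__(self, lr=0.01):
        self.lr = lr

    def get_config(self):
        return {"lr": self.lr}

    def update(self, index, weight, grad, state):
        # Update the weight buffer in place, once per parameter per step,
        # the way a Module/KVStore would call an imperative optimizer.
        for i, g in enumerate(grad):
            weight[i] -= self.lr * g


opt = MXNetSGD(lr=0.5)
w = [1.0, 2.0]
opt.update(0, w, [0.5, 1.0], opt.create_state(0, w))
print(w, opt.get_config())  # [0.75, 1.5] {'lr': 0.5}
```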

MXNet Model

MXNet Model is a specialized Keras Model. It uses MXNet's BucketingModule for managing training and inference workflows, and maintains three buckets, 1) train, 2) test and 3) pred, for the corresponding phases. MXNet Model also maintains arg_params and aux_params so that the weights associated with the current model (module) can be extracted.
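MXNet's `BucketingModule` takes a `sym_gen(bucket_key)` callback that returns `(symbol, data_names, label_names)` for each key, plus a `default_bucket_key`. The sketch below shows one way the three phase buckets could map onto that callback. Strings stand in for real symbols so it runs without MXNet, and the per-phase graph contents are an illustrative assumption, not the Keras-MXNet code; the commented-out lines show how it would be wired into MXNet.

```python
def make_sym_gen(forward_graph, loss_graph):
    """Return a sym_gen(bucket_key) callback keyed by training phase."""

    def sym_gen(bucket_key):
        if bucket_key == "train":
            # network output + loss, used for forward/backward/update
            return loss_graph, ("data",), ("label",)
        if bucket_key == "test":
            # evaluation also needs labels for loss and metrics
            return loss_graph, ("data",), ("label",)
        if bucket_key == "pred":
            # inference: network output only, no labels
            return forward_graph, ("data",), ()
        raise KeyError(bucket_key)

    return sym_gen


sym_gen = make_sym_gen("dense->sigmoid", "dense->sigmoid->loss")

# With MXNet installed, this would be wired up roughly as:
#   mod = mx.mod.BucketingModule(sym_gen, default_bucket_key="train",
#                                context=[mx.gpu(0)])
for phase in ("train", "test", "pred"):
    print(phase, sym_gen(phase))
```

In this sketch the train and test buckets share a graph; the difference between them would lie in whether the forward pass runs in training mode and whether backward/update are called.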

Keras Training Workflow

Below, we provide the Keras training workflow sequence diagram, in which we create a “Sequential Model”, add a “Dense Layer” to the network, and then compile and fit the model.

Keras-MXNet Training Workflow

Above, we saw the sequence diagram of the native Keras training workflow. For the same use case, a Sequential Model with one Dense layer, below we provide the sequence diagram of the training workflow in Keras-MXNet. Observe the highlighted differences in the Model compile and Model fit stages: we use the specialized MXNetModel to prepare an MXNet module for training the model (forward, backward, update).
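The fit stage boils down to repeatedly calling forward, backward and update on each batch. Below is a dependency-free sketch of that loop for a one-unit Dense + sigmoid model, trained with binary cross-entropy and plain SGD. `ToyModule` and `fit` are hypothetical; in Keras-MXNet, the three calls go to MXNet's `Module.forward`, `Module.backward` and `Module.update`.

```python
import math


class ToyModule:
    """Hypothetical stand-in for an MXNet Module: Dense(1) + sigmoid."""

    def __init__(self, n_inputs, lr=0.5):
        self.w = [0.0] * n_inputs
        self.b = 0.0
        self.lr = lr

    def forward(self, x):
        self.x = x
        z = sum(wi * xi for wi, xi in zip(self.w, x)) + self.b
        self.y = 1.0 / (1.0 + math.exp(-z))
        return self.y

    def backward(self, label):
        # d(binary cross-entropy)/dz for a sigmoid output is (y - label)
        dz = self.y - label
        self.grad_w = [dz * xi for xi in self.x]
        self.grad_b = dz

    def update(self):
        # plain SGD step using the gradients from backward()
        self.w = [wi - self.lr * gi for wi, gi in zip(self.w, self.grad_w)]
        self.b -= self.lr * self.grad_b


def fit(module, samples, epochs):
    for _ in range(epochs):
        for x, label in samples:  # batch size 1 for brevity
            module.forward(x)
            module.backward(label)
            module.update()


# Learn the logical OR of two inputs
data = [([0.0, 0.0], 0.0), ([0.0, 1.0], 1.0),
        ([1.0, 0.0], 1.0), ([1.0, 1.0], 1.0)]
mod = ToyModule(n_inputs=2)
fit(mod, data, epochs=200)
print([round(mod.forward(x)) for x, _ in data])  # [0, 1, 1, 1]
```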

Performance Tuning in Keras-MXNet

Handling conv kernels and image_data_format

Problem

  1. Keras creates convolution kernels in 'channels_last' format irrespective of 'image_data_format'.
  2. This causes many costly transpose operations when users choose the 'channels_first' image_data_format, which is also the preferred format for best performance on the MXNet backend.

Solution

  1. Update the Keras convolution layers to respect the user's choice of 'image_data_format': create kernels, filters and other tensors in 'channels_first' format if the user's image_data_format is 'channels_first'.
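For reference, Keras stores a 2D convolution kernel as `(kernel_h, kernel_w, in_channels, out_channels)`, whereas MXNet's `Convolution` operator expects its weight as `(out_channels, in_channels, kernel_h, kernel_w)`. The sketch below computes the axis permutation between the two layouts (the `axes` tuple one would pass to a transpose such as `mx.sym.transpose`) and applies it to a shape. The helper names are hypothetical; creating the kernel directly in the second layout avoids that transpose entirely.

```python
# Axis order of a 2D conv kernel in each layout.
KERAS_LAYOUT = ("kh", "kw", "in", "out")  # Keras default kernel layout
MXNET_LAYOUT = ("out", "in", "kh", "kw")  # MXNet Convolution weight layout


def transpose_axes(src=KERAS_LAYOUT, dst=MXNET_LAYOUT):
    """Return the permutation p such that dst[i] == src[p[i]]."""
    return tuple(src.index(axis) for axis in dst)


def permute_shape(shape, axes):
    """Reorder a shape tuple according to a transpose permutation."""
    return tuple(shape[a] for a in axes)


axes = transpose_axes()
keras_kernel_shape = (3, 3, 64, 128)  # 3x3 kernel, 64 -> 128 channels
print(axes, permute_shape(keras_kernel_shape, axes))
# (3, 2, 0, 1) (128, 64, 3, 3)
```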

Performance Gain

  1. ~15% improvement in Convolutional Layers.

Batchnorm channels_last

Problem

  1. On GPU, the MXNet BatchNorm operator does not use the cuDNN implementation for 'channels_last' data (axis=-1). (GitHub Issue)

Solution

  1. Since image_data_format is the user's choice, we handle this situation with warnings that suggest changing 'image_data_format' to 'channels_first'.
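Such a check can be sketched with Python's standard `warnings` module. The function name and message are hypothetical, not the actual Keras-MXNet code; the layer still runs with `axis=-1`, but the user is told that `channels_first` (`axis=1`) allows the faster cuDNN path.

```python
import warnings


def check_batchnorm_axis(axis, ndim, image_data_format):
    """Warn when BatchNorm normalizes over the last (channels_last) axis.

    Hypothetical helper; returns True when the fast layout is in use.
    """
    normalized_axis = axis % ndim  # e.g. -1 -> ndim - 1
    if image_data_format == "channels_last" and normalized_axis == ndim - 1:
        warnings.warn(
            "MXNet BatchNorm with axis=-1 (channels_last) does not use cuDNN "
            "on GPU and can be much slower. Consider setting "
            "image_data_format='channels_first' and using axis=1.",
            UserWarning,
        )
        return False
    return True


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_batchnorm_axis(-1, 4, "channels_last")  # emits the warning
print(len(caught))  # 1
```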

Performance Gain

  1. We observe that BatchNorm in channels_first (axis=1) is up to 10X faster than BatchNorm in channels_last (axis=-1).

Handling multi_gpu_model training

Problem

  1. The Keras multi_gpu_model API creates a clone of the model on each device (GPU) and one on the CPU. During training, it divides each batch of input data across all the devices, which compute the gradients. The CPU copy of the model is used to aggregate the gradients, update the weights and sync them with the other devices. This process is synchronous and CPU bound, which keeps GPU utilization low and the communication cost per epoch high.

Solution

  1. In Keras-MXNet, we specialize the Keras Model with MXNet Model. From the multi_gpu_model Keras API, we set the right context (multiple GPU IDs) in the module and let the module handle model training efficiently (KVStore, parameter server).
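Conceptually, running the module on several contexts means splitting each batch along axis 0, computing gradients on each device, and aggregating them before a single weight update, which is the role KVStore plays. The dependency-free sketch below shows only the split and the aggregation (here: averaging per-device gradient sums over the whole batch); `split_batch` and `average_gradients` are hypothetical helpers. In MXNet itself, passing a list of contexts such as `context=[mx.gpu(i) for i in range(8)]` to the Module is what enables this data parallelism.

```python
def split_batch(batch, num_devices):
    """Split a batch (list of samples) into near-equal contiguous slices."""
    size, extra = divmod(len(batch), num_devices)
    slices, start = [], 0
    for d in range(num_devices):
        end = start + size + (1 if d < extra else 0)
        slices.append(batch[start:end])
        start = end
    return slices


def average_gradients(per_device_grads, per_device_counts):
    """Combine per-device gradient sums into one batch-mean gradient."""
    total = sum(per_device_counts)
    n_params = len(per_device_grads[0])
    return [sum(g[i] for g in per_device_grads) / total
            for i in range(n_params)]


batch = list(range(10))          # 10 samples
slices = split_batch(batch, 4)   # 4 "devices"
# Toy per-sample gradient g(x) = [x, 1.0], summed on each device
grads = [[float(sum(s)), float(len(s))] for s in slices]
print(slices)
print(average_gradients(grads, [len(s) for s in slices]))  # [4.5, 1.0]
```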

Performance Gain

  1. In the benchmarks, we have seen close to linear scaling across multiple GPUs with Keras-MXNet (7.2X on 8 GPUs).
  2. Keras-MXNet is up to 3X faster than Keras-TF on multi-GPU training.
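Assuming the 7.2X figure is the speedup of 8 GPUs over a single GPU, it corresponds to a parallel scaling efficiency (speedup divided by the number of GPUs) of about 90%:

```latex
E = \frac{S_8}{8} = \frac{7.2}{8} = 0.9 = 90\%
```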

Other Challenges

  1. Keras is designed around TF and Theano conventions and operators. A few fundamental conventions, such as shared tensors and operator signatures, are different in MXNet.
  2. Many fundamental operators are missing in MXNet, for example depthwise_convolution, separable_convolution, local_convolution, “SAME” mode pooling/padding, cumsum and more.
  3. Control flow operators are a major missing building block in MXNet for RNN use cases.
  4. Not all Keras operators map 1:1 to MXNet symbolic operators. For example, the Conv operators require preparing the padding, kernels and filters.
  5. Handling of channels_first and channels_last could be more graceful in the MXNet Conv operators.
  6. MXNet low-level operators are not optimized; for example, mx.sym.dot() and mx.sym.broadcast_add() are slower. (GitHub Issue)
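As an example of the padding preparation mentioned above: TensorFlow-style "SAME" padding keeps `out = ceil(in / stride)` and may need different padding on the two sides of an axis, while the `pad` argument of MXNet's `Convolution` and `Pooling` operators is symmetric per spatial axis. The sketch computes the per-side padding for one spatial dimension; when the two sides differ, an explicit pad step (for example with `mx.sym.pad`) would be needed before the operator. The helper name is hypothetical.

```python
def same_padding(in_size, kernel, stride=1, dilation=1):
    """Return (pad_before, pad_after) for TF-style 'SAME' padding on one axis."""
    out_size = -(-in_size // stride)  # ceil(in / stride)
    effective_kernel = (kernel - 1) * dilation + 1
    pad_total = max((out_size - 1) * stride + effective_kernel - in_size, 0)
    pad_before = pad_total // 2       # any odd extra pixel goes after
    return pad_before, pad_total - pad_before


for args in [(5, 3), (4, 3, 2), (224, 7, 2)]:
    before, after = same_padding(*args)
    kind = "symmetric" if before == after else "needs explicit pad"
    print(args, (before, after), kind)
# (5, 3) (1, 1) symmetric
# (4, 3, 2) (0, 1) needs explicit pad
# (224, 7, 2) (2, 3) needs explicit pad
```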


We now understand the Keras design, the Keras-MXNet integration design, and the differences between the Keras core and MXNet designs. Below, we summarize the design differences and the functionality that MXNet needs to support as a path forward for merging Keras-MXNet into keras-team/keras.

  1. Symbolic Gradient Operator
  2. Shared Tensors, Update Ops
  3. Low-Level functional API for graph execution
  4. Symbolic optimizers in KVStore
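What a low-level functional API with shared variables provides can be sketched in plain Python, modeled loosely on Theano's `function(inputs, outputs, updates=...)` and the Keras backend's `K.function(inputs, outputs, updates=...)`: a compiled callable that, for each input batch, computes the outputs and then applies update rules to shared variables in place. `SharedVariable` and `function` below are toy stand-ins written for this illustration, with Python callables standing in for graphs.

```python
class SharedVariable:
    """A value that several graphs can read and update (toy)."""

    def __init__(self, value):
        self.value = value


def function(inputs, outputs, updates=()):
    """Toy 'compile' step.

    inputs:  list of input names
    outputs: list of callables f(feed) -> value (stand-ins for graphs)
    updates: list of (SharedVariable, callable f(feed) -> new value)
    """

    def run(values):
        feed = dict(zip(inputs, values))
        results = [out(feed) for out in outputs]
        # Compute all new values first, then assign, so every update
        # sees the pre-step values (as in one graph execution).
        new_values = [(var, rule(feed)) for var, rule in updates]
        for var, value in new_values:
            var.value = value
        return results

    return run


w = SharedVariable(0.0)
lr = 0.1


def loss(feed):
    return (w.value * feed["x"] - feed["y"]) ** 2


def sgd_rule(feed):
    grad = 2 * (w.value * feed["x"] - feed["y"]) * feed["x"]
    return w.value - lr * grad


train_step = function(["x", "y"], [loss], updates=[(w, sgd_rule)])
for _ in range(50):
    train_step([1.0, 3.0])  # fit w * 1.0 to 3.0
print(round(w.value, 3))  # 3.0
```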

References in other backends

  1. Theano Symbolic Gradient Operator - http://deeplearning.net/software/theano/library/gradient.html
  2. TF Symbolic Gradient API - https://www.tensorflow.org/api_docs/python/tf/gradients
  3. Theano Functional API - http://deeplearning.net/software/theano/library/compile/function.html
  4. TensorFlow Session.run is used as the low-level graph execution API in Keras.

Other Minor Differences

  1. MXNet does not support scalars.
  2. MXNet does not support booleans.
  3. MXNet does not support control flow operators (work in progress by Zheng Da and team).
  4. MXNet does not natively support a few critical operators: SeparableConv, DepthwiseConv, and “SAME” mode padding/pooling.

Appendix

Keras Layers Design

NOTE: Not all components, attributes and operations are represented in the class diagram above; only a few critical components are shown to illustrate the design.
