In this design document, we describe the implementation of Keras RNN layers with the MXNet backend. This is not a new API design in Keras-MXNet; it is a design to improve the usability of RNN layers.

All functionality discussed here is purely symbolic and uses the MXNet Symbol API.

1. Feature Shepherd:

Sandeep Krishnamurthy @Kalyanee Chendke

2. Problem:


With the MXNet backend, users can add Keras RNN layers and specify whether to unroll them:

  • unroll: Boolean (default False). If True, the network will be unrolled, else a symbolic loop will be used. Unrolling can speed up an RNN, although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.

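For example, a typical LSTM network construction in Keras is as follows:

from keras.models import Sequential
from keras.layers import Dense, Embedding, LSTM
max_features = 20000  # vocabulary size (illustrative value)

model = Sequential()
model.add(Embedding(max_features, 128))
model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))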

Here, unroll in the LSTM layer is set to False by default. To unroll the LSTM layer for a speed-up, users can write:

model.add(LSTM(128, dropout=0.2, recurrent_dropout=0.2, unroll=True))

Problem statement

In previous Keras-MXNet releases (v2.2.2 and older), RNN support was limited and the user experience suffered:

  1. Unrolling was forced because of the lack of control flow operators, which can be very memory intensive for long sequences. An option to not unroll is needed.
  2. Variable-length inputs were not supported; users had to specify the input length and manually pad input sequences to the same length.
  3. Keras RNN unit tests were disabled because they used unroll=False, which was not supported. Dropout in RNN cells did not work either.

As MXNet added control flow operators (foreach, while_loop) in 1.3.0, we can use them to address the above issues.

3. Goals:

To resolve the problems above:

  1. Support RNN layers with unroll=False by default.
  2. Support variable-length inputs and dropout, so users do not have to change RNN code to use Keras-MXNet.
  3. Enable the Keras RNN unit tests that were disabled in v2.2.2 and previous releases (18 in total).

4. Open Questions:

  1. Can we change the foreach operator to accept a KerasSymbol?
  2. Can the KerasSymbol design be improved?

5. Prerequisites/Background Knowledge

The following background knowledge is needed before designing a solution to the problem. To keep this document short, we only cover the essential points; please follow the reference links for more details.

5.1 RNN Layers and RNN unrolling

  1. Keras RNN APIs
    1. Similar to the Gluon RNN API (RNN, LSTM, GRU), but purely symbolic
    2. Iterate over the time dimension of the input and apply a step function that depends on the type of RNN (a step function is a user-defined function, similar to hybrid_forward in a Gluon RNN cell)
    3. The outputs of the previous time step are inputs to the current time step, so there are cycles in the computation graph.
  2. What is RNN unrolling?
    1. Unrolling unfolds the graph into N copies (N = number of time steps) that execute in sequence, so the cycle is removed.
      1. Implemented by slicing the input (Symbol) along the time dimension, iterating with a Python for loop, and concatenating the results at the end
    2. Not unrolling means using a symbolic loop (a symbolic control flow operator) to iterate in place over the input (Symbol)
      1. This is what this design implements
      2. A symbolic loop turns a dynamic, Python-like loop into a static graph (see the MXNet Symbol control flow design)

[Figure: RNN with cycle (left) and unrolled RNN (right)]
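To make the difference concrete, the following plain-Python sketch mimics what unrolling does, with floats standing in for MXNet Symbols. The names simple_step and unrolled_rnn and the weight values are hypothetical; the sketch only shows the slice, Python-loop, and concatenate pattern described in 5.1. With real Symbols, each loop iteration would add another copy of the step's nodes to the graph, which is why unrolling becomes memory-intensive for long sequences.

```python
import math


def simple_step(x_t, h_prev, w=0.5, u=0.5):
    """Toy SimpleRNN step: h_t = tanh(w * x_t + u * h_prev). Weights are illustrative."""
    h_t = math.tanh(w * x_t + u * h_prev)
    return h_t, h_t  # (output, new_state)


def unrolled_rnn(step, inputs, state):
    """Unroll over time: slice the sequence, call step once per time step
    with a Python for loop, then collect ("concatenate") the per-step outputs.

    With symbols instead of floats, every iteration would emit a fresh copy of
    the step's nodes, so the graph would grow linearly with len(inputs).
    """
    outputs = []
    for t in range(len(inputs)):   # Python control flow, resolved at graph-build time
        x_t = inputs[t]            # slice along the time dimension
        out_t, state = step(x_t, state)
        outputs.append(out_t)
    return outputs, state          # all per-step outputs and the last state


outputs, last_state = unrolled_rnn(simple_step, [1.0, 0.0, -1.0], 0.0)
print(len(outputs), outputs[-1] == last_state)  # 3 True
```

By contrast, a symbolic loop such as foreach keeps a single copy of the step in the graph and iterates over it at run time.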

5.2 Control Flow Operators in MXNet

We will use the foreach operator in our RNN implementation. In short, it runs a for loop of user-defined computation over Symbols along the time dimension. Here, body is the same as step: it is the user-defined Python function (called 'body' in MXNet and 'step' in Keras) that is applied to every time step of the input. Please refer to the control flow design document for more details.

The following is the function signature from MXNet:

"foreach" is a special form of loops. It iterates over the time dimension of the input NDArray/Symbol, so the number of iterations is determined before entering the loop.

foreach(body, input, state)

Input arguments:

  • "input" is a symbol/NDArray or a list of symbols/NDArrays.
  • "body" is a Python function that defines computation for each iteration.
  • "state" is a list of symbols/NDArrays passed to "body" as part of the inputs for the first iteration.

Return values:

  • A tuple of (out_data, state), where "out_data" is a symbol/NDArray or a list of symbols/NDArrays that is a concatenation of all outputs from "body" and "state" is the output state in the last iteration.

The signature of "body" is

def body(input, state): output, new_state

"input" is a symbol/NDArray or a list of symbols/NDArrays that is a slice from the input arrays of "foreach"; "state" is a list of symbols/NDArrays that represent data from the previous iteration; "output" is a symbol/NDArray or a list of symbols/NDArrays that contains the output data generated in this iteration; "new_state" is a list of symbols/NDArrays that contain data passed to the next iteration. All "output" from this function are concatenated as the output of "foreach". As such, the shape and type of "output" from each iteration should always be the same. "body" is invoked once to generate a symbol that represents the computation in the function.

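These semantics can be modeled in a few lines of plain Python. The function foreach_reference below is a hypothetical stand-in written for illustration, not MXNet's implementation. The real mx.sym.contrib.foreach invokes body once to build a subgraph, whereas this model simply calls body for each slice. That makes the input/state/output contract easy to check.

```python
def foreach_reference(body, data, init_states):
    """Plain-Python model of foreach semantics (illustration only).

    body(slice, states) -> (output, new_states)
    Returns (list of per-step outputs, states after the last iteration).
    """
    states = list(init_states)
    outputs = []
    for data_t in data:                   # iterate over the leading (time) dimension
        out_t, states = body(data_t, states)
        outputs.append(out_t)             # every output must have the same shape/type
    return outputs, states


def running_sum(x_t, states):
    """Example body: the state carries a running total across iterations."""
    total = states[0] + x_t
    return total, [total]


outs, final_states = foreach_reference(running_sum, [1, 2, 3, 4], [0])
print(outs, final_states)  # [1, 3, 6, 10] [10]
```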
5.3 Keras-MXNet Design

Because of design differences between TensorFlow and MXNet, the MXNet backend is designed differently from the other backends. The key concept this design requires is the KerasSymbol:

"KerasSymbol" is a tensor data structure used in Keras-MXNet that is composed of tensor data (an NDArray) and a symbol (a Symbol). It was introduced for compatibility with the Keras tensor concept, which includes both the data and the symbol.

In summary, it is a wrapper class that bridges the Keras front end (which requires tensors with both a symbol and data) and the MXNet backend (where a Symbol carries no data; data is introduced at the binding stage).

A common technique in the Keras-MXNet backend is to convert a KerasSymbol to an MXNet Symbol, apply an MXNet operation, and convert the returned MXNet Symbol back to a KerasSymbol. For example, the K.square operator is defined as:

def square(x):
    """Element-wise square.

    # Arguments
        x: Tensor or variable.

    # Returns
        A tensor.
    """
    # unwrap to an MXNet Symbol, apply the MXNet op, wrap back into a KerasSymbol
    return KerasSymbol(mx.sym.square(data=x.symbol))

For more details refer to Keras-MXNet design.
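The wrap/unwrap pattern can be sketched without MXNet. In this minimal sketch, ToySymbol is a hypothetical stand-in for an MXNet Symbol that only records an expression string. ToyKerasSymbol mimics only the .symbol attribute of the real KerasSymbol; the real class also manages data and binding. toy_square follows the same three steps as K.square: unwrap, apply the backend op, re-wrap.

```python
class ToySymbol:
    """Hypothetical stand-in for an MXNet Symbol: just records an expression."""

    def __init__(self, expr):
        self.expr = expr


def toy_sym_square(data):
    """Stand-in for a backend operator such as mx.sym.square."""
    return ToySymbol(f"square({data.expr})")


class ToyKerasSymbol:
    """Mimics only the .symbol attribute of Keras-MXNet's KerasSymbol."""

    def __init__(self, symbol):
        if not isinstance(symbol, ToySymbol):
            raise TypeError("expected a backend symbol")
        self.symbol = symbol


def toy_square(x):
    # 1. unwrap the KerasSymbol, 2. apply the backend op, 3. wrap the result again
    return ToyKerasSymbol(toy_sym_square(data=x.symbol))


y = toy_square(ToyKerasSymbol(ToySymbol("x")))
print(type(y).__name__, y.symbol.expr)  # ToyKerasSymbol square(x)
```

This round trip happens once per operator call while the graph is being built, so it is cheap. Section 7 discusses what goes wrong when the same round trip is pushed inside a control flow operator's body.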

5.4 Keras-MXNet RNN Operator

Now let's look at the Keras-MXNet RNN operator. This is how the RNN operator is called in Keras:

When an RNN layer is added to a model, the RNN layer calls K.rnn (similar to F.operator in Gluon) and passes it the step function defined by Keras:

last_output, outputs, states = K.rnn(step,

The MXNet backend receives inputs and initial_state. It has to loop over the inputs and apply the step function at every time step, along the time dimension of inputs. A step function defines how inputs and hidden states are updated for each type of RNN cell (SimpleRNN, GRU, LSTM), similar to hybrid_forward in Gluon.

Ideally, we would just call the MXNet foreach operator and pass it the step function, inputs, and states. However, the Keras-MXNet design described in 5.3 makes this more complex and requires special handling, which is covered in detail in the proposed design.

This is the ideal implementation of the RNN operator in the MXNet backend:

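import mxnet as mx

def rnn(step, inputs, initial_state):
    # foreach iterates over the leading axis of inputs and threads the states through step
    outputs, states = mx.symbol.contrib.foreach(step, inputs, initial_state)
    return outputs, states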
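Keras expects three values from K.rnn (last_output, outputs, states), while foreach returns only (outputs, states). The plain-Python sketch below shows one way to bridge the two, with lists standing in for Symbols. rnn_sketch, _foreach and accumulate_step are hypothetical names, and _foreach is a simplified model of foreach, not the MXNet operator.

```python
def _foreach(body, data, init_states):
    """Minimal plain-Python model of foreach: (per-step outputs, final states)."""
    states, outputs = list(init_states), []
    for data_t in data:
        out_t, states = body(data_t, states)
        outputs.append(out_t)
    return outputs, states


def rnn_sketch(step, inputs, initial_states):
    """Adapt foreach's (outputs, states) to Keras' (last_output, outputs, states)."""
    outputs, states = _foreach(step, inputs, initial_states)
    last_output = outputs[-1]   # output of the final time step
    return last_output, outputs, states


def accumulate_step(x_t, states):
    """Toy Keras-style step: returns (output, [new_state, ...])."""
    h = states[0] + x_t
    return h, [h]


last_output, outputs, states = rnn_sketch(accumulate_step, [2, 3, 5], [0])
print(last_output, outputs, states)  # 10 [2, 5, 10] [10]
```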

6. Proposed Approaches

In summary, the key challenge this design solves is how to make the MXNet symbolic foreach operator accept a body (step function) that operates only on KerasSymbol.

6.1 Step function in Keras

In Keras, a step function is built from various K.* backend operators, such as K.sum and K.dot. All K.* operators operate on KerasSymbol, which holds an MXNet Symbol as its symbol and an MXNet NDArray as its data. The following is the step function defined in Keras SimpleRNNCell:

def call(self, inputs, states, training=None):
    prev_output = states[0]
    # create the input and recurrent dropout masks once, on first use
    if 0 < self.dropout < 1 and self._dropout_mask is None:
        self._dropout_mask = _generate_dropout_mask(
            K.ones_like(inputs),
            self.dropout,
            training=training)
    if (0 < self.recurrent_dropout < 1 and
            self._recurrent_dropout_mask is None):
        self._recurrent_dropout_mask = _generate_dropout_mask(
            K.ones_like(prev_output),
            self.recurrent_dropout,
            training=training)

    dp_mask = self._dropout_mask
    rec_dp_mask = self._recurrent_dropout_mask

    # input projection, with optional input dropout
    if dp_mask is not None:
        h = K.dot(inputs * dp_mask, self.kernel)
    else:
        h = K.dot(inputs, self.kernel)
    if self.bias is not None:
        h = K.bias_add(h, self.bias)

    # recurrent projection, with optional recurrent dropout
    if rec_dp_mask is not None:
        prev_output *= rec_dp_mask
    output = h + K.dot(prev_output, self.recurrent_kernel)
    if self.activation is not None:
        output = self.activation(output)

    # Properly set learning phase on output tensor.
    if 0 < self.dropout + self.recurrent_dropout:
        if training is None:
            output._uses_learning_phase = True
    return output, [output]
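The computation in SimpleRNNCell.call above reduces to a single recurrence, and a pure-MXNet step function for this cell would have to reproduce it. Here W = self.kernel, U = self.recurrent_kernel, b = self.bias, φ = self.activation, and m_x, m_h are the input and recurrent dropout masks. Treat the masks as all ones when dropout is disabled, because the code then skips the multiplication.

$$
h_t = \phi\left( (x_t \odot m_x)\, W + b + (h_{t-1} \odot m_h)\, U \right)
$$

The cell returns h_t both as the step output and as the single new state, which is why call ends with return output, [output].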

6.2  Make Keras Step Functions work with MXNet Foreach Operator

However, the MXNet symbolic foreach operator accepts only an MXNet Symbol or a list of MXNet Symbols, and the same applies to the step function it runs. It does not recognize a step function defined in Keras or a KerasSymbol.

To resolve this problem, we define the step function in pure MXNet Symbols for each RNN cell in Keras. To do that, we pass the Keras cell object to the MXNet backend, so the backend can retrieve the cell configuration and the related kernel weights. With this approach, the RNN operator can run control flow operators.
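The following is a plain-Python sketch of this idea. It assumes the cell object exposes its weights as attributes; Keras' SimpleRNNCell does have kernel, recurrent_kernel, bias and activation, as the code in 6.1 shows. ToySimpleRNNCell, make_backend_step and the list-based matrix helper are hypothetical. In the real backend, the step body would use MXNet Symbol operators instead of Python arithmetic. Dropout is omitted here.

```python
import math
from dataclasses import dataclass
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


@dataclass
class ToySimpleRNNCell:
    """Hypothetical cell exposing the attributes a backend step would read."""
    kernel: Matrix            # input_dim x units
    recurrent_kernel: Matrix  # units x units
    bias: Vector              # units
    activation: Callable[[float], float] = math.tanh


def _vec_mat(v: Vector, m: Matrix) -> Vector:
    """Row vector times matrix (v @ m)."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]


def make_backend_step(cell):
    """Pick a backend-native step function based on the cell type."""
    if isinstance(cell, ToySimpleRNNCell):
        def step(x_t: Vector, states: List[Vector]):
            h_prev = states[0]
            # weights are read from the cell object captured by the closure
            pre = [a + b + c for a, b, c in zip(_vec_mat(x_t, cell.kernel),
                                                 _vec_mat(h_prev, cell.recurrent_kernel),
                                                 cell.bias)]
            h_t = [cell.activation(z) for z in pre]
            return h_t, [h_t]
        return step
    # Every other cell type needs its own hand-written step (see 6.4).
    raise NotImplementedError(f"no backend step for {type(cell).__name__}")


cell = ToySimpleRNNCell(kernel=[[1.0, 0.0]],
                        recurrent_kernel=[[0.0, 0.0], [0.0, 0.0]],
                        bias=[0.0, 0.0], activation=lambda z: z)
step = make_backend_step(cell)
out, new_states = step([2.0], [[0.0, 0.0]])
print(out)  # [2.0, 0.0]
```

The isinstance dispatch keeps the Keras-side step function untouched. The cost is that each supported cell type needs its own backend step.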

The following design diagram shows how the step function is converted from Keras to MXNet Symbol in the MXNet backend.

6.3 Pseudo code of the implementation

Inside the Keras RNN layer, the invocation now also passes the cell object:

last_output, outputs, states = K.rnn(step,

In the MXNet backend, it becomes:

def rnn(step, inputs, initial_states, cell):
    # define a step function in pure MXNet operators for each cell type
    def _simple_rnn_cell_step(data, states):
        .... pure mxnet operators ....

    def _lstm_cell_step(data, states):
        .... pure mxnet operators ....

    def _gru_cell_step(data, states):
        .... pure mxnet operators ....

    # convert KerasSymbol to MXNet Symbol
    inputs_mx = inputs.symbol
    initial_states_mx = [s.symbol for s in initial_states]

    # choose the pure MXNet step function according to the cell type
    if isinstance(cell, SimpleRNNCell):
        _step_mx = _simple_rnn_cell_step
    elif isinstance(cell, LSTMCell):
        _step_mx = _lstm_cell_step
    elif isinstance(cell, GRUCell):
        _step_mx = _gru_cell_step
    else:
        # any other cell needs its own hand-written step function
        _step_mx = _custom_cell_step

    # call the MXNet foreach operator and pass MXNet Symbols
    outputs_mx, states_mx = mx.sym.contrib.foreach(_step_mx, inputs_mx, initial_states_mx)

    # wrap outputs in KerasSymbol and return them to Keras
    return KerasSymbol(outputs_mx), [KerasSymbol(s) for s in states_mx]

6.4 Drawbacks of the design

High maintenance cost: the step function has to be redefined in pure MXNet Symbols for each RNN cell in Keras, and support for each new RNN cell has to be added manually.

7. Alternative Approach: Conversion between Keras step function and MXNet step function

Note: This is the initial design I tried. It is a generic solution, but it does not work; this section explains why. You may skip this section.

What if we apply the same technique from 5.3 to the step function? We could create an MXNet version of the step function, called step_mx, that takes MXNet Symbols from the foreach operator. Inside, it converts the inputs back to KerasSymbol and calls the step function defined in Keras.

The pseudo code is as follows:

def rnn(step, inputs, initial_states):
    # convert KerasSymbol to MXNet Symbol
    inputs_mx = inputs.symbol
    initial_states_mx = [s.symbol for s in initial_states]

    # convert the Keras step function to an MXNet step function
    def step_mx(inputs_mx_i, states_mx):
        inputs_i = KerasSymbol(inputs_mx_i)
        states = [KerasSymbol(s) for s in states_mx]
        # call the step function defined in Keras
        output_i, new_states = step(inputs_i, states)
        # convert back to MXNet Symbols and return them to the foreach operator
        return output_i.symbol, [s.symbol for s in new_states]

    # call the MXNet foreach operator and pass MXNet Symbols
    outputs_mx, states_mx = mx.sym.contrib.foreach(step_mx, inputs_mx, initial_states_mx)
    # wrap outputs in KerasSymbol and return them to Keras
    return KerasSymbol(outputs_mx), [KerasSymbol(s) for s in states_mx]

Following is the flow chart of the above code:

Problem with this design:

There are two conversions between KerasSymbol and MXNet Symbol: one before calling the foreach operator, and one before calling the step function. The former happens only once, while the latter happens every time step_mx is applied, that is, once per time step in inputs. This results in too many conversions, and the conversion actually takes place inside the MXNet operator. This raises an error, because we are trying to create a KerasSymbol inside the MXNet foreach operator.

8. Addition of New APIs


9. Backward Compatibility

This design is backward compatible: it adds new functionality, and no existing functions are affected.

10. Test Plan

The proposed functionality is covered by Keras unit tests and integration tests on RNN layers. These RNN tests were disabled in our first release because RNN support was experimental. This design re-enables them.

11. Technical Challenges

The main technical challenge is that the KerasSymbol design in Keras-MXNet makes the foreach operator difficult to use. Keras requires tensors to be defined like TensorFlow tensors, which have both symbol and data attributes. TensorFlow tensors are accepted by both the Keras front end and the TensorFlow backend, so no conversion is needed. The KerasSymbol class was introduced to close this gap, but it adds the cost of converting between Keras tensors, MXNet Symbols, and KerasSymbol. This made it hard for MXNet control flow operators to accept Keras step functions, and the design had to be worked out carefully for it to work.

Open questions in section 4 were not explored.

12. Milestones


The implementation can be found in this PR:


With this support, we now have the following benefits:

  1. No forced unrolling when using RNNs in Keras
  2. Users do not need to specify the input length
  3. Dropout works out of the box according to the configuration the user defined in Keras
  4. All examples work out of the box with no extra modification
  5. 18 unit tests were enabled and are run in CI for PRs and in nightly tests.

Note: For the best performance, unrolling RNN cells is still recommended, as described in the Keras documentation:

13. References

1. Keras-MXNet v2.2.2 unsupported functionalities:

2. MXNet control flow operators:

3. Keras-MXNet RNN with forced unrolling:

4. Foreach operator in MXNet:

5. Keras-MXNet design:
