Currently, MXNet requires that every operator statically infer its output shapes from the input shapes. However, some operators don't meet this requirement, because their output shapes depend on the input data rather than only on the input shapes. Examples are:

Supporting this feature requires significant modifications to the core of MXNet, including graph binding, the MXNet executor, and the operator interface.

Graph partitioning

The Unified integration with external accelerators design introduces a general-purpose graph partitioning mechanism. To split a graph into parts that contain only static-shape operators and parts that contain only dynamic-shape operators, we need to define a new node selector:
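The real selector would implement the C++ SubgraphSelector interface of the subgraph API; the toy routine below is only a language-agnostic sketch of the selection rule, with the graph representation and all names invented for illustration. It grows connected groups of nodes that agree on one property: whether the operator supports static shape inference.

    DYNAMIC_SHAPE_OPS = {"boolean_mask", "unique"}   # hypothetical examples

    def partition(nodes):
        """nodes: topologically ordered list of (name, op, input_names)."""
        ops = {name: op for name, op, _ in nodes}
        subgraphs = []    # each entry is a list of node names
        assignment = {}   # node name -> index into subgraphs
        for name, op, inputs in nodes:
            dynamic = op in DYNAMIC_SHAPE_OPS
            target = None
            # Join an input's subgraph only if it is of the same kind, so
            # every subgraph is all-static or all-dynamic.
            for inp in inputs:
                if inp in assignment and (ops[inp] in DYNAMIC_SHAPE_OPS) == dynamic:
                    target = assignment[inp]
                    break
            if target is None:
                subgraphs.append([])
                target = len(subgraphs) - 1
            subgraphs[target].append(name)
            assignment[name] = target
        return subgraphs

    graph = [("conv", "Convolution",    ["data"]),
             ("act",  "Activation",     ["conv"]),
             ("mask", "boolean_mask",   ["act"]),
             ("fc",   "FullyConnected", ["mask"])]
    print(partition(graph))  # [['conv', 'act'], ['mask'], ['fc']]

A production selector must additionally guarantee that merging nodes never creates a cycle between the resulting subgraph nodes.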

Graph partitioning creates a node for each subgraph, and each subgraph is executed by the default subgraph operator (i.e., an extended CachedOp).

Imperative graph executor

To implement such an executor, we can add another mode (dynamic_shape) to CachedOp. This mode executes the nodes in a graph one by one, without static shape/dtype/storage inference and memory planning in advance. For simplicity, it doesn't even need to split the computation graph into segments before pushing them to the threaded engine, because after partitioning many of the nodes in the graph are themselves subgraphs. When executing an operator in a node, it creates empty NDArrays with engine variables and passes them to the operator as output arrays (basically, we need to reimplement mxnet::imperative::RunGraph).
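The toy model below (pure Python, with all names invented; the real code would be C++ inside CachedOp) illustrates the control flow: each output starts as a shapeless placeholder, standing in for an empty NDArray that already owns an engine variable, and only acquires a shape once its producing node has run.

    import numpy as np

    class Placeholder:
        """Stand-in for an empty NDArray: it has an identity (like an
        engine variable) before it has data or a shape."""
        def __init__(self, data=None):
            self.data = data
        @property
        def shape(self):
            return None if self.data is None else self.data.shape

    def run_graph(nodes, env):
        """nodes: topologically ordered (output_name, fn, input_names)."""
        for out_name, fn, input_names in nodes:
            env.setdefault(out_name, Placeholder())
            inputs = [env[name] for name in input_names]
            # Input shapes are known at this point only because the
            # producers already finished; where exactly to wait for them
            # is the question discussed next.
            env[out_name].data = fn(*[i.data for i in inputs])

    env = {"x": Placeholder(np.array([1, 2, 2, 3, 3, 3]))}
    run_graph([("y", np.unique, ["x"]),        # output shape depends on data
               ("z", lambda a: a * 2, ["y"])], env)
    print(env["z"].shape)  # (3,) -- known only after np.unique executed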

One of the biggest problems remains: when and where to run the shape inference code. Because we can't infer the shapes of the output arrays in advance, the input arrays of the next operator have no shape information when that operator is pushed to the engine. There are two options:

  1. Run shape/dtype inference for an operator in the main thread, as we do right now, but wait for the previous operator to complete first. This option complies with the current implementation of shape/dtype inference, and all inference computation runs in the same thread. The problem is that the synchronization prevents multi-GPU parallelism.
  2. Encapsulate the operator computation and its shape/dtype inference in a lambda and push everything to the engine threads. This option has two advantages: 1) it allows multi-GPU parallelism; 2) it parallelizes shape/dtype inference across the computation graph, reducing the inference overhead (dynamic shapes require shape inference in every batch). The problem is that the inference code no longer runs in the main thread and may run concurrently, which may break assumptions made by the existing inference code.

For simplicity, we can start with the first option and later experiment with the second for better performance; a sketch of the second option follows.
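This is a minimal sketch of the second option, modeling the threaded engine with concurrent.futures (nothing here is MXNet API): the closure pushed to the engine first waits for its dependencies, would then run shape/dtype inference, and finally computes, so the main thread never blocks.

    from concurrent.futures import ThreadPoolExecutor
    import numpy as np

    engine = ThreadPoolExecutor(max_workers=2)  # stand-in for the MXNet engine

    def push_op(fn, deps):
        def task():
            inputs = [d.result() for d in deps]  # wait on dependencies
            # Shape/dtype inference would run here, on the engine thread,
            # immediately before the computation.
            return fn(*inputs)
        return engine.submit(task)  # the main thread never blocks

    x = engine.submit(lambda: np.array([1, 2, 2, 3]))
    y = push_op(np.unique, [x])        # dynamic output shape
    z = push_op(lambda a: a * 2, [y])
    print(z.result())                  # [2 4 6]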

For symbol execution, we can define a new executor called DynamicShapeGraphExecutor, whose job is to execute the computation in the graph imperatively. In Executor::Bind, we first construct a graph from the given symbol and test whether the graph has operators that can't infer shapes statically. If such operators exist, we create a DynamicShapeGraphExecutor to run the graph; otherwise, we create a GraphExecutor as we do right now. DynamicShapeGraphExecutor uses CachedOp to perform the actual computation.
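The bind-time dispatch could look like the sketch below; the stub classes, the bind signature, and the set of dynamic-shape operators are all hypothetical.

    DYNAMIC_SHAPE_OPS = {"boolean_mask", "unique"}  # hypothetical examples

    class GraphExecutor:                       # stand-ins for the two executors
        def __init__(self, symbol):
            self.symbol = symbol

    class DynamicShapeGraphExecutor(GraphExecutor):
        """Executes the graph imperatively through CachedOp."""

    def bind(symbol, op_names):
        """op_names: the operators appearing in the symbol's graph."""
        if any(op in DYNAMIC_SHAPE_OPS for op in op_names):
            return DynamicShapeGraphExecutor(symbol)  # imperative path
        return GraphExecutor(symbol)                  # existing static path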

Symbol shape

Once the graph executor supports dynamic shapes, we can use the same mechanism to support symbol shapes. In this case, shape() of a symbol becomes a normal operator and returns an NDArray that stores the shape of the symbol. For each operator that takes a shape as input, we now need to add a C++ version that takes this NDArray as input and uses it to infer the shape of the output. Take mx.sym.ones as an example. We can create a hidden operator named _ones_shape, which takes an NDArray as the shape. We'll then need a Python wrapper that calls the original ones operator if the input shape is a tuple, and calls _ones_shape if the input shape is a symbol. This approach doesn't require significant changes in the backend (NNVM and the executor) and provides backward compatibility. However, we have to do the same thing for every front-end language.
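A sketch of the proposed wrapper (the registration path mx.sym._internal._ones_shape is hypothetical; only the dispatch logic matters here):

    import mxnet as mx

    def ones(shape, **kwargs):
        if isinstance(shape, mx.sym.Symbol):
            # The shape is the output of shape() and is only known at run
            # time: dispatch to the hidden dynamic-shape operator.
            return mx.sym._internal._ones_shape(shape, **kwargs)  # hypothetical
        # Backward-compatible path: the shape is an ordinary tuple.
        return mx.sym.ones(shape=shape, **kwargs)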

We can enable dynamic shape inference in four steps:

  1. Modify Imperative::Invoke to run operators that don't support static shape inference. In this way, we can invoke this kind of operator in imperative mode.
  2. Implement the dynamic_shape mode in CachedOp.
  3. Support symbol shape for some common operators.
  4. Create DynamicShapeGraphExecutor that partitions a graph and executes the operators in the new graph imperatively.