Problem statement

Besides compute-intensive operations like convolutions and fully connected layers, DL models contain many simple pointwise (aka elementwise) operations, such as elementwise addition. The performance of those operations is fully memory-bandwidth bound, so they limit the speedups obtainable from newer GPU hardware, which typically has a high compute-to-memory-bandwidth ratio. There are multiple ongoing efforts (e.g. TVM) to use compiler technology to deal with this and other performance problems. However, integrating e.g. TVM into MXNet is a long-term effort, and in the meantime there is a need for a simpler, more focused approach to this problem. This document proposes such an approach.
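Because the runtime of bandwidth-bound ops is roughly proportional to the bytes they move, a back-of-the-envelope count shows what fusion can buy. The sketch below counts global-memory traffic for the chain out = a + a * b. It assumes float32 tensors, an arbitrary element count, and that every kernel reads all its inputs from and writes all its outputs to global memory with no cache reuse between kernels. Under those assumptions, fusing the two ops into one kernel halves the traffic.

```python
# Rough global-memory traffic model for out = a + a * b.
# Illustrative assumptions: float32 tensors, every kernel reads all inputs from
# and writes all outputs to global memory, no cache reuse between kernels.
DTYPE_BYTES = 4
N = 100 * 500  # elements per tensor; arbitrary example size


def bytes_moved(kernels, num_elements=N, dtype_bytes=DTYPE_BYTES):
    """kernels: list of (num_inputs, num_outputs), one entry per kernel launch."""
    return sum((ins + outs) * num_elements * dtype_bytes for ins, outs in kernels)


unfused = bytes_moved([(2, 1),   # t = a * b   : read a, b; write t
                       (2, 1)])  # out = a + t : read a, t; write out
fused = bytes_moved([(2, 1)])    # one kernel  : read a, b; write out

print(unfused, fused)  # 1200000 600000
```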

Proposed Approach (High level)

  1. Introduce a new op (named _FusedOp) which holds a subgraph of pointwise ops to be fused (either forward or backward). That op does not have a gradient, so fusion on the forward and backward passes happens independently. At runtime, the first time the op is run, it generates CUDA kernel code from the subgraph, compiles it with NVRTC and launches the resulting function. On subsequent launches, the compiled function is reused as long as the input types do not change (a shape change should not require recompilation).
  2. Introduce graph passes that look for subgraphs made of compatible pointwise ops and replace them with the corresponding _FusedOp nodes.
  3. Fusion is guarded by the MXNET_USE_FUSION environment variable. Its default value is still to be decided.
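The caching policy in step 1 (compile once per set of input types, reuse across shapes) can be sketched as below. This is a hypothetical Python model, not MXNet code: `FusedOpSketch` is an invented name, Python's built-in `compile()`/`exec()` stand in for NVRTC, and the generated "kernel" is a Python loop rather than CUDA. What it does show faithfully is the key choice: dtypes are baked into the generated source and therefore key the cache, while the element count is a runtime argument, so a new shape reuses the existing kernel.

```python
# Hypothetical sketch of the _FusedOp caching policy (not MXNet code).
# compile()/exec() stand in for NVRTC; the "kernel" is a Python loop, not CUDA.


class FusedOpSketch:
    def __init__(self, subgraph, num_inputs):
        # subgraph: (out_name, binary_operator, lhs, rhs) tuples in topological
        # order; inputs are referred to as "in0", "in1", ...
        self.subgraph = subgraph
        self.num_inputs = num_inputs
        self.cache = {}          # input dtype signature -> compiled kernel
        self.compilations = 0

    def _generate_source(self, dtypes):
        args = ", ".join(f"in{i}" for i in range(self.num_inputs))
        src = [f"def kernel({args}, n):", "    out = []", "    for i in range(n):"]
        for i, dtype in enumerate(dtypes):
            # The dtype is baked into the generated code, hence a recompile per dtype.
            src.append(f"        x_in{i} = {dtype}(in{i}[i])")
        for out_name, op, lhs, rhs in self.subgraph:
            src.append(f"        x_{out_name} = x_{lhs} {op} x_{rhs}")
        src.append(f"        out.append(x_{self.subgraph[-1][0]})")
        src.append("    return out")
        return "\n".join(src)

    def __call__(self, inputs, dtypes):
        key = tuple(dtypes)  # only the types select a kernel, not the shapes
        if key not in self.cache:
            namespace = {}
            exec(compile(self._generate_source(dtypes), "<fused>", "exec"), namespace)
            self.cache[key] = namespace["kernel"]
            self.compilations += 1
        # The element count is a runtime argument, so a new shape reuses the kernel.
        return self.cache[key](*inputs, len(inputs[0]))


# out = a + a * b
op = FusedOpSketch([("t0", "*", "in0", "in1"), ("out", "+", "in0", "t0")], num_inputs=2)
r1 = op([[1, 2], [3, 4]], ["float", "float"])         # first run: compiles
r2 = op([[1, 2, 3], [1, 1, 1]], ["float", "float"])   # new shape: reuses the kernel
r3 = op([[1, 2], [3, 4]], ["int", "int"])             # new dtypes: compiles again
```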

Proposed Approach (Low level)

The approach described in the previous section works well for inference workloads. However, while writing proof-of-concept code for backward-pass fusion in training, we encountered problems with shape and type inference of backward nodes, which required additional changes and helper ops. This section describes the problem and the proposed solution.

Shape and type inference of backward nodes - intro to the problem

In MXNet, backward nodes generally do not have their own InferShape and InferType functions. Instead, they rely on their forward nodes (accessed via control dependencies in the NNVM graph) to get the attributes, since the input and output gradients of a backward node have the same attributes as, respectively, the outputs and inputs of the forward node. Moreover, it is not really possible to add those functions to every backward op. To illustrate this, take the Cast op. It has a dtype parameter, which is the type of its output. The type of its input has to come from the previous ops in the graph, which is fine. However, for the backward of Cast, the same dtype parameter describes the type of the backward node's input (the gradient of Cast's output), not of its output (the gradient of Cast's input), so the output type cannot be inferred from the backward node alone. The same problem arises for shape inference with the Reshape operator.
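The Cast example can be made concrete with a small hypothetical model. `Node`, `infer_type` and the op name `backward_Cast` are illustrative stand-ins, not NNVM's real data structures or MXNet's registered op names. The backward node cannot get its output type from its own dtype parameter, so it follows its control dependency to the forward node and takes that node's input type.

```python
# Hypothetical model (not NNVM's real structures): a backward node reaches its
# forward node through a control dependency to get type attributes.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    op: str
    params: dict = field(default_factory=dict)
    inputs: List["Node"] = field(default_factory=list)
    control_deps: List["Node"] = field(default_factory=list)
    in_types: List[Optional[str]] = field(default_factory=list)
    out_type: Optional[str] = None


def infer_type(node: Node) -> None:
    if node.op == "Cast":
        node.in_types = [i.out_type for i in node.inputs]  # from previous ops
        node.out_type = node.params["dtype"]               # from the parameter
    elif node.op == "backward_Cast":
        # The same dtype parameter only describes this node's *input* (the
        # gradient of Cast's output). Its output (the gradient of Cast's input)
        # can only be taken from the forward node's input type.
        fwd = node.control_deps[0]
        node.in_types = [fwd.out_type]
        node.out_type = fwd.in_types[0]


x = Node("var", out_type="float16")
cast = Node("Cast", params={"dtype": "float32"}, inputs=[x])
infer_type(cast)
bwd = Node("backward_Cast", params={"dtype": "float32"}, control_deps=[cast])
infer_type(bwd)
```

If `cast` were fused away, `bwd.control_deps[0]` would no longer carry these types, which is exactly the failure described next.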

This is an important problem when fusion is applied independently to the forward and backward passes, because the forward node that a backward node relies on for attribute inference may be fused away and no longer be accessible. To demonstrate these problems (and the proposed solution), let us look at two very simple networks: out = a + a * b and out = a * b.

The problem and solution, part 1

To illustrate the first part of the problem and its solution, in this section we assume that fusion happens only on the forward pass, and we use the network out = a + a * b.

The original graph (full, with both forward and backward) of this network looks like this (solid arrows are dataflow connections, dotted arrows are control dependencies):

The red edge on the forward pass (between mul and add) has exactly the same attributes as the red edge on the backward pass (between backward_add and backward_mul). To infer those attributes, the backward_add node follows its control dependency to the add node and reads that node's input attributes. However, after the forward pass is fused, the graph changes:


Now the red edge in the forward graph is fused away, and the backward nodes no longer have enough information to infer the attributes of their corresponding red edge, so InferShape and InferType fail.

The proposed solution to this problem is to introduce new helper nodes into the graph that contain a pointer to the fused op and the node_id of the corresponding forward node in the fused op's subgraph (the dashed line is a pointer, not an actual edge in the graph!):

Now, during the InferAttr pass, backward nodes can ask their helper node for attribute values. To answer such a request, the helper node asks the fused op for the attributes of the node with the stored node_id in the fused op's subgraph. That way backward_add can get information from the add node inside the fused op's subgraph and fill in the attributes of the red edge.
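A minimal Python sketch of this lookup, assuming simplified stand-ins: `FusedOp`, `FusedOpHelper` (the name used for this helper in the comments below) and `infer_backward_add` are hypothetical classes and functions, not the MXNet implementation. The helper holds a plain reference to the fused op plus an inner node_id, and resolves attribute requests by reading the fused op's per-node results for its subgraph.

```python
# Hypothetical Python stand-ins for the fused op and its helper node (part 1).


class FusedOp:
    def __init__(self, inner_ops):
        self.inner_ops = inner_ops   # node_id -> op name in the fused subgraph
        self.inner_in_shapes = {}    # node_id -> input shapes, filled by infer_shape

    def infer_shape(self, in_shape):
        # All inner ops are pointwise, so every inner edge has the input shape.
        for node_id, op in self.inner_ops.items():
            num_inputs = 2 if op in ("mul", "add") else 1
            self.inner_in_shapes[node_id] = [in_shape] * num_inputs
        return in_shape


class FusedOpHelper:
    """Metadata-only node: a pointer to the fused op plus an inner node_id."""

    def __init__(self, fused_op, node_id):
        self.fused_op = fused_op   # a plain reference, not a graph edge
        self.node_id = node_id

    def forward_in_shapes(self):
        return self.fused_op.inner_in_shapes[self.node_id]


def infer_backward_add(helper):
    # backward_add's output gradients have the shapes of add's inputs, one of
    # which is the fused-away red edge (the output of mul).
    return helper.forward_in_shapes()


fused = FusedOp({0: "mul", 1: "add"})     # out = a + a * b, fused on the forward pass
fused.infer_shape((2, 3))
helper = FusedOpHelper(fused, node_id=1)  # points at "add" inside the subgraph
grad_shapes = infer_backward_add(helper)
```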

The reason for using pointers instead of graph edges will become apparent in the next section.

The problem and solution, part 2

This time we look at the network out = a * b with full (forward and backward) fusion. The original graph:

Here fusion does not happen on the forward pass, because mul is the only pointwise operation there (so there is nothing to fuse). However, fusion can happen on the backward pass, since there are 2 operations there: identity and backward_mul. The fused graph looks as follows:

This time the problem occurs inside the subgraph of the fused op: the backward_mul node inside the subgraph has no connection to the mul node (since they are in separate graphs), and therefore it cannot use its attributes. Moreover, even if such a connection existed, the attributes are stored in the graph rather than in the node, so they would still be inaccessible from inside the subgraph.

The proposed solution here is similar to the one from the previous section.

  1. During construction of the subgraph, transfer all control dependencies from the fused nodes to the fused op.
  2. Inside the subgraph, introduce helper nodes which have a pointer to the fused op (they cannot have an edge, since they are not in the same graph) and the index of the transferred control dependency.
  3. During the InferAttr pass, when a fused op node is encountered, it requests the attribute values of all of its control dependencies and stores them.
  4. It then runs its internal InferAttr pass on the subgraph, where helper nodes access this stored information.
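The four steps can be sketched in Python for out = a * b with backward-only fusion. Everything here is a hypothetical simplification: `OuterNode`, `BackwardMul`, `FusedOp.absorb` and the `graph_attrs` dictionary are invented for illustration, and `FusedOpOutHelper` only borrows the name used in the comments below. The dictionary models the point that attributes live in the outer graph (keyed by node) rather than in the nodes themselves.

```python
# Hypothetical sketch of steps 1-4 for out = a * b with backward-only fusion.


class OuterNode:
    def __init__(self, name):
        self.name = name


class FusedOpOutHelper:
    """Inside the subgraph: pointer to the fused op + index of a transferred dep."""

    def __init__(self, fused_op, dep_index):
        self.fused_op = fused_op
        self.dep_index = dep_index

    def attrs(self):
        return self.fused_op.dep_attrs[self.dep_index]


class BackwardMul:
    def __init__(self, forward_node):
        self.control_deps = [forward_node]
        self.out_shapes = None

    def infer(self):
        # Gradients of a and b have the shapes of mul's inputs.
        self.out_shapes = list(self.control_deps[0].attrs()["in_shapes"])


class FusedOp:
    def __init__(self):
        self.control_deps = []   # outer-graph nodes
        self.dep_attrs = []      # their attributes, cached during InferAttr
        self.inner_nodes = []

    def absorb(self, node):
        helpers = []
        for dep in node.control_deps:
            self.control_deps.append(dep)                                        # step 1
            helpers.append(FusedOpOutHelper(self, len(self.control_deps) - 1))  # step 2
        node.control_deps = helpers
        self.inner_nodes.append(node)

    def infer_attr(self, graph_attrs):
        # Step 3: attributes are stored in the outer graph, keyed by node.
        self.dep_attrs = [graph_attrs[dep.name] for dep in self.control_deps]
        # Step 4: internal pass; inner nodes read the cached values via helpers.
        for node in self.inner_nodes:
            node.infer()


mul = OuterNode("mul")
graph_attrs = {"mul": {"in_shapes": [(2, 3), (2, 3)], "out_shapes": [(2, 3)]}}
bwd_mul = BackwardMul(mul)
fused = FusedOp()
fused.absorb(bwd_mul)
fused.infer_attr(graph_attrs)
```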

The resulting graph:


Here, during the outer graph's InferAttr pass, the fused op gets the attribute values of all its control dependencies (there is only one, mul) and then runs its own internal InferAttr pass. There, backward_mul requests the values from its helper, which fetches the values for dependency 0 from the fused op.

Ghost node property

The helper nodes are metadata-only nodes: they have no inputs, no outputs and not even an FCompute function. Therefore they should not be considered when preparing the graph for execution. A new op attribute, TIsGhost, was added to NNVM to mark ops that should not appear in the IndexedGraph. Since all the attribute passes, as well as execution, rely solely on the IndexedGraph and not on the actual Graph, this is enough to make those nodes invisible to the rest of MXNet.
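The effect of the ghost property on indexing can be sketched as follows. This is a hypothetical Python model: the `IS_GHOST` table plays the role of the TIsGhost op attribute, the op names are illustrative, and dropping control dependencies that point at ghost nodes is a choice made by this sketch, not a statement about how NNVM handles them. Ghost nodes simply receive no index, so any pass that iterates over the indexed graph never sees them.

```python
# Hypothetical sketch: an op-attribute table in the spirit of NNVM's TIsGhost.
# Op names are illustrative.
IS_GHOST = {"FusedOpHelper": True, "FusedOpOutHelper": True}


def build_indexed_graph(topo_nodes):
    """Index only non-ghost nodes; in this sketch, deps on ghost nodes are dropped."""
    index, entries = {}, []
    for node in topo_nodes:
        if IS_GHOST.get(node["op"], False):
            continue  # ghost nodes get no index, so later passes never see them
        deps = [index[d] for d in node.get("control_deps", []) if d in index]
        index[node["name"]] = len(entries)
        entries.append({"name": node["name"], "control_deps": deps})
    return entries


graph = [
    {"name": "a", "op": "var"},
    {"name": "b", "op": "var"},
    {"name": "fused", "op": "_FusedOp"},
    {"name": "helper0", "op": "FusedOpHelper"},
    {"name": "backward_add", "op": "backward_add", "control_deps": ["helper0"]},
]
indexed = build_indexed_graph(graph)
```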


Benchmark results gathered on a TitanV GPU:

  • I took a residual unit from
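import mxnet as mx
import numpy as np
import os
import mxnet.gluon as gluon
import time

n = 500
m = 100
l = 1500

cell = gluon.rnn.ResidualCell(gluon.rnn.GRUCell(n, prefix='rnn_'))
inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(2)]
outputs, _ = cell.unroll(2, inputs)
outputs = mx.sym.Group(outputs)

ctx = mx.gpu(0)
shapes = dict(rnn_t0_data=(m, n), rnn_t1_data=(m, n),
              rnn_i2h_weight=(l, n), rnn_i2h_bias=(l,),
              rnn_h2h_weight=(l, n), rnn_h2h_bias=(l,))

# Fusion is toggled with MXNET_USE_FUSION (assumed to be read at bind time)
os.environ['MXNET_USE_FUSION'] = '0'
orig = outputs.simple_bind(ctx=ctx, **shapes)
os.environ['MXNET_USE_FUSION'] = '1'
fused = outputs.simple_bind(ctx=ctx, **shapes)

# Head gradients for the 2 unrolled outputs, each of shape (m, n)
out_grads = [mx.nd.ones((m, n), ctx=ctx) for _ in range(2)]

def run(exe, iters):
    for i in range(iters):
        exe.forward(is_train=True)  # one training step: forward + backward
        exe.backward(out_grads)
    mx.nd.waitall()  # MXNet runs asynchronously: wait for the GPU before timing

for name, exe in (("Original", orig), ("Fused", fused)):
    run(exe, 1)  # warm-up, so that kernel generation/compilation is not timed
    t = time.time()
    run(exe, 500)
    print(name + " time: ", time.time() - t)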

I tested it with the original values from the test (50, 10, 150) and with values 10 times larger (500, 100, 1500). The first set gave 46 ms with fusion vs 90 ms without, and the larger set gave 117 ms vs 159 ms, i.e. over 35% speedup in both cases.

  • ResNet-50 (NHWC data layout for convolutions, batch size 128): 930 img/s without fusion, 1050 img/s with fusion.
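For reference, the speedups implied by these numbers are computed below. This assumes that in each timing pair the smaller number is the fused run, which is consistent with the stated speedup; for ResNet-50 the ratio is of throughputs, so larger means faster.

```python
# Speedups implied by the reported numbers.
gru_small = 90 / 46      # unfused time / fused time (ms), sizes (50, 10, 150)
gru_large = 159 / 117    # unfused time / fused time (ms), sizes (500, 100, 1500)
resnet50 = 1050 / 930    # fused throughput / unfused throughput (img/s)

print(round(gru_small, 2), round(gru_large, 2), round(resnet50, 3))  # 1.96 1.36 1.129
```

That is roughly 96% and 36% faster for the two GRU residual unit sizes, and about 13% higher throughput for ResNet-50.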



  1. Can you describe the helper node in more detail? What is the difference between FusedOpOutHelper and FusedOpHelper?

    Do you mean the op _identity_with_attr_like_rhs? Is this the op that needs to be fused and optimized for the backward pass?

    How does the algorithm decide what to fuse? Is there a list of ops to be fused that needs to be maintained?

  2. The difference between the FusedOpHelper (which I guess should be named FusedOpInHelper to be consistent) and FusedOpOutHelper:

    • FusedOpHelper is a helper that is used by backward nodes to access forward nodes that have been fused away (so it reaches into the fusion subgraph to get the information)
    • FusedOpOutHelper is a helper that is used by backward nodes that were fused to access forward nodes outside the fusion (so it reaches outside the fusion subgraph to get the information)

    In the examples I gave, the id op is _identity_with_attr_like_rhs, but only because I was looking for the simplest possible examples. There are other cases where the fused forward and backward passes do not match. For example, the popular ResNet architecture contains (on the forward pass) an add+relu+split block. In the backward pass this corresponds to add (backward of split) + relu_backward + split (backward of add). You do not want to fuse split, because then your fused op would have to populate 2 outputs (whereas split just sends the same output in 2 directions). So in the end you get:

    forward:      (add+relu)+split

    backward:    split+(backward_relu+add)

    which do not match.

    Pointwise fusion is fairly simple, because these are all memory-bandwidth-bound operations: the more you fuse the better, and there is no need for mechanisms like tuning. There is a list of ops that are eligible for fusion, which is also a mapping from MXNet op name to the code that needs to be generated for it.
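A hypothetical Python sketch of such a table and of a greedy "fuse as much as possible" pass. The op-to-template entries, `find_fusion_groups` and `generate_expression` are illustrative, not the actual list or pass from the implementation. For simplicity the sketch only handles linear chains and ignores intermediates consumed outside a group, which a real pass must account for. Applied to the ResNet-style add+relu+split block, it fuses add and relu into one expression and leaves split out.

```python
# Hypothetical table: MXNet op name -> CUDA expression template
# (entries are illustrative, not the actual list used by the implementation).
FUSABLE_OPS = {
    "elemwise_add": "({0} + {1})",
    "elemwise_mul": "({0} * {1})",
    "relu": "fmaxf({0}, 0.0f)",
    "sigmoid": "(1.0f / (1.0f + expf(-{0})))",
}


def find_fusion_groups(nodes):
    """nodes: topologically sorted (name, op, input_names) tuples.

    Greedily grows a group while each eligible node consumes a value produced
    inside the current group; ineligible ops (e.g. split) end the group.
    """
    groups, current = [], []
    for name, op, inputs in nodes:
        produced = {n for n, _, _ in current}
        if op in FUSABLE_OPS and (not current or any(i in produced for i in inputs)):
            current.append((name, op, inputs))
            continue
        if len(current) > 1:  # a single op gains nothing from fusion
            groups.append(current)
        current = [(name, op, inputs)] if op in FUSABLE_OPS else []
    if len(current) > 1:
        groups.append(current)
    return groups


def generate_expression(group):
    """Inline each op's template into a single expression for one kernel."""
    exprs = {}
    for name, op, inputs in group:
        exprs[name] = FUSABLE_OPS[op].format(*(exprs.get(i, i) for i in inputs))
    return exprs[group[-1][0]]


# ResNet-style forward block: add + relu + split.
block = [("s", "elemwise_add", ["x", "y"]), ("r", "relu", ["s"]), ("sp", "split", ["r"])]
groups = find_fusion_groups(block)
print(generate_expression(groups[0]))  # fmaxf((x + y), 0.0f)
```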

  3. Did you create the first graph by hand? I would expect the inputs of backward_add to contain "a" and the output of "a*b".

    Also the "id" node is not clear to me. Is this a real node that we insert? I haven't seen this before.

    Can't you use control deps instead of adding new nodes?

    Does this cause a problem if you don't retain the graph and it is freed? Do you have guarantees that the pointers of the new nodes remain valid?

    I think the solution is certainly clever. I'm concerned about the increase in complexity of code that I'm already not a big fan of. I would much rather have proper, ideally immutable, graph transformations instead of more mutations and a spaghetti of pointers. But I guess that's why you say it's a temporary solution...

  4. The backward of add just propagates the incoming gradient to other nodes; it does not need any other inputs.

    The "id" node is the _identity_with_attr_like_rhs op, which takes the input gradient and reshapes it to match the output of the forward pass. Yes, it is a real op in the graph generated by MXNet.

    I'm not sure what you mean by "Can't you use control deps instead of adding new nodes?". If you mean "can't we have a control dep from the backward node to the fusion instead", then the answer is no. You need the additional information of "which is my corresponding forward node in the fusion". Even the current (non-fusion) MXNet solution is very brittle in that regard: it relies on the fact that the gradient of a given op contains only 1 instance of a given backward op. This is OK if you use a 1:1 forward/backward mapping (even though NNVM does not force you to; you can write an FGradient function that creates a subgraph as the gradient), but it is not OK for fusion.

    What do you mean by "don't retain a graph"? Like in Gluon non-hybridized mode? In Gluon, this fusion works only with the full graph inside CachedOp, i.e. when hybridized without inlining (so basically static_alloc=True).

    The code change to core MXNet is actually not that big: the only meaningful difference is in the InferAttr pass (for symbolic). For Gluon, I needed to slightly reorganize where different things are kept, to ensure that fusion happens only in the CUDA context (so more things are kept in per-context state rather than in the global CachedOp).