Currently, MXNet FFI overhead is significantly larger than that of official NumPy [1]. MXNet passes FFI arguments mainly via ctypes: all keyword arguments are serialized to strings in the frontend and deserialized from strings in the backend, which is slow. Moreover, the number of APIs exposed across the dll boundary is very large, which makes them difficult to maintain, and for now only a few of them have cython interfaces, which further worsens FFI performance.

Admittedly, FFI overhead is negligible for computation-intensive ops, but for widely used small ops like zeros, it can still take up a significant amount of time. For example, when pretraining BERT [2], zeroing the gradient arrays takes around 5% of the total time. This pushes developers to implement efficient aggregated ops, which takes extra effort.
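To see why fixed per-call overhead matters for small ops, here is a quick, illustrative measurement using plain NumPy as a stand-in for MXNet (the exact numbers are machine-dependent; this is a sketch, not a claim about MXNet's internals):

```python
import timeit

import numpy as np

# Per-call cost of a tiny zeros: fixed dispatch/FFI-style overhead dominates.
n = 100_000
t_small = timeit.timeit(lambda: np.zeros((3, 4)), number=n) / n

# Per-call cost of a large zeros: the actual work dominates.
t_big = timeit.timeit(lambda: np.zeros((3000, 4000)), number=50) / 50

# Per-element cost is far higher for the tiny array, because most of its
# time is the fixed call overhead, not the zeroing itself.
print(t_small / 12, t_big / (3000 * 4000))
```

This is exactly why aggregated ops help: one call amortizes the fixed overhead over many elements.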


The TVM FFI module serves as a good replacement, as suggested by @tqchen (thanks also to him for proposing a feasible workflow). It has been successfully adopted in DGL [3]. It has a few advantages:

  • It is fast when accelerated by cython. The performance of TVM FFI is firmly supported by benchmark results [4]. I also built a POC (see the Performance section) that illustrates this.
  • It exposes a fixed set of APIs across the dll boundary, and new APIs (not exposed to the dll) are built upon them. More specifically, new APIs are registered as PackedFuncs in a global registry, and any of these PackedFuncs can be called via one exposed API.


  • Adapt TVM registry for MXNet PackedFunc registration in backend

  • Adapt TVM function as MXNet PackedFunc wrapper in frontend

  • Extend TVMValue for customized and efficient MXNet argument passing.


A demo op, np.zeros, is implemented with the new FFI interface as an example. The files mentioned in the example can be found in the POC mentioned above. Reading both the frontend and the backend registration will help if one wants to register other ops.

First, in the frontend, np.zeros is registered in python/mxnet/ndarray/numpy/ as usual:

Front End
def zeros(shape, dtype=None, order='C', ctx=None):  # pylint: disable=redefined-outer-name
    if order != 'C':
        raise NotImplementedError
    # If the following code (4 lines) regarding ctx is removed,
    # np.zeros((3, 4)) can be as fast as 4.96 us
    if ctx is None:
        ctx = str(current_context())
    else:
        ctx = str(ctx)
    if dtype is not None and not isinstance(dtype, str):
        dtype = _np.dtype(dtype).name
    return _api_internal.zeros(shape, dtype, ctx)

Second, in the backend, _npi.zeros is registered in src/api/operator/numpy/ as follows.

Back End
MXNET_REGISTER_API("_npi.zeros")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  // Part 1: populate NodeAttrs
  using namespace runtime;
  const nnvm::Op* op = Op::Get("_npi_zeros");
  nnvm::NodeAttrs attrs;
  op::InitOpParam param;
  if (args[0].type_code() == kDLInt) {
    param.shape = TShape(1, args[0].operator int64_t());
  } else {
    param.shape = TShape(args[0].operator ObjectRef());
  }
  if (args[1].type_code() == kNull) {
    param.dtype = mshadow::kFloat32;
  } else {
    param.dtype = String2MXNetTypeWithBool(args[1].operator std::string());
  }
  attrs.parsed = std::move(param);
  attrs.op = op;
  if (args[2].type_code() != kNull) {
    attrs.dict["ctx"] = args[2].operator std::string();
  }
  // Part 2: invoke
  int num_outputs = 0;
  auto ndoutputs = Invoke(op, &attrs, 0, nullptr, &num_outputs, nullptr);
  *ret = ndoutputs[0];
});

The first part of _npi.zeros is mainly responsible for populating NodeAttrs, and the second part invokes the op imperatively, similar to the API MXImperativeInvokeImpl.

As shown, _npi.zeros is not directly exposed across the dll boundary, yet it can still be invoked from the frontend. If we look into operator TShape(), we see that the TShape is initialized from an ADTObj, which in this case behaves like an array of integers. Since this works without any (de)serialization between string and int, some speedup is obtained here.
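The gain from skipping the string round-trip can be illustrated with a small, self-contained sketch (the helpers below are illustrative stand-ins, not MXNet code):

```python
import ast
import timeit

shape = (3, 4)

def legacy_roundtrip():
    # Legacy path: serialize the kwarg to a string in the frontend,
    # then parse it back in the backend.
    s = str(shape)
    return ast.literal_eval(s)

def direct_pass():
    # ADTObj-like path: hand over the integers as-is.
    return tuple(shape)

n = 100_000
t_legacy = timeit.timeit(legacy_roundtrip, number=n)
t_direct = timeit.timeit(direct_pass, number=n)
print(t_legacy, t_direct)  # the round-trip is markedly slower
```

Both paths produce the same shape; only the string round-trip cost differs.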

Also, here are a few things to note:

  • In the original np.zeros, the default value for dtype is specified in the frontend, while in np.zeros with the new FFI interface shown above it is specified in the backend, which speeds it up slightly. In general I think it may be better to move such code to the backend, as Python is sometimes slow.
  • ctx still gets (de)serialized. Joint effort with engine overhead optimization may be required, as ctx is deserialized in the engine.
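For concreteness, the ctx string takes a form like "cpu(0)" (as in the benchmarks below), so the backend must parse the device type and id back out. A minimal, illustrative sketch of that deserialization (not the actual engine code):

```python
import re

def parse_ctx(s):
    # Hypothetical parser for a context string like "cpu(0)" or "gpu(3)":
    # split it into a (device_type, device_id) pair.
    m = re.fullmatch(r"(\w+)\((\d+)\)", s)
    if m is None:
        raise ValueError("malformed context string: %r" % s)
    return (m.group(1), int(m.group(2)))

parse_ctx("cpu(0)")  # → ('cpu', 0)
```

Passing the device type and id as plain integers instead would remove this parse from the hot path, which is the joint optimization hinted at above.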

Finally, I would like to walk through the invocation process from Python to C++, which may help when code navigation fails to work for FFI.

  1. python/mxnet/ndarray/numpy/ def zeros. The python entry.
  2. python/mxnet/_ffi/ class Function. The front end wrapper to hide the details about ctypes and cython.
  3. python/mxnet/_ffi/_cython/function.pxi: cdef class FunctionBase. The cython entry.
  4. python/mxnet/_ffi/_cython/function.pxi: def __call__. Here we call make_ret to convert MXNetValue into python object.
  5. python/mxnet/_ffi/_cython/function.pxi: cdef inline int FuncCall. Here we call make_arg to convert python object into MXNetValue.
  6. src/runtime/ int MXNetFuncCall. The C++ entry. It is an API exposed to the dll. It wraps the ABI-compatible C struct MXNetValue into the easy-to-use C++ classes MXNetArgs and MXNetRetValue.
  7. include/mxnet/runtime/packed_func.h: inline void PackedFunc::CallPacked.
  8. src/api/operator/numpy/ MXNET_REGISTER_API("_npi.zeros"). Here we arrive at the user-defined backend function.
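The call chain above can be modeled end-to-end in a few lines of Python. All names here are illustrative stand-ins for the real ctypes/cython machinery:

```python
# Toy model of the FFI call chain: a frontend Function wrapper converts
# Python objects into "packed" (type_code, value) pairs, a single generic
# entry point dispatches by name, and the return value is converted back.
BACKEND_REGISTRY = {}

def make_arg(obj):
    # Step 5 analogue: Python object -> packed value.
    if obj is None:
        return ("null", None)
    if isinstance(obj, tuple):
        return ("adt", list(obj))  # like ADTObj for shapes
    return ("str", str(obj))

def make_ret(packed):
    # Step 4 analogue: packed return value -> Python object.
    _code, value = packed
    return value

class Function:
    """Step 2 analogue: frontend wrapper hiding the ctypes/cython details."""
    def __init__(self, name):
        self.name = name

    def __call__(self, *args):
        packed_args = [make_arg(a) for a in args]
        # Steps 6-8 analogue: one exposed entry dispatches by name.
        return make_ret(BACKEND_REGISTRY[self.name](packed_args))

def _npi_zeros(packed_args):
    # Step 8 analogue: the user-defined backend body.
    code, shape = packed_args[0]
    assert code == "adt"
    return ("ndarray", [[0.0] * shape[1] for _ in range(shape[0])])

BACKEND_REGISTRY["_npi.zeros"] = _npi_zeros
zeros = Function("_npi.zeros")
zeros((2, 2))  # → [[0.0, 0.0], [0.0, 0.0]]
```

The real implementation differs in that the packed values are C structs and the dispatch crosses the dll boundary, but the shape of the flow is the same.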


Performance

Benchmarked on c5n.4x, Ubuntu 16.04 LTS, with NaiveEngine and cython enabled.


| Operation | Current FFI (us) | TVM FFI (us) |
| --- | --- | --- |
| zeros((3, 4)) | | |
| zeros((3, 4), dtype='float64') | | |
| zeros((3, 4), ctx="cpu(0)", dtype='float64') | | |
| tensordot(a, b, ((1, 0), (0, 1))) | | |

[1] [RFC] MXNet Imperative Op Invocation Overhead
[2] [PR] Fix collect_params().zero_grad() in gluon numpy interface
[3] [DGL]
