Implementing a NumPy-compatible operator is similar to implementing a normal operator. Please read the tutorial on implementing normal operators before reading this one. The following sections illustrate the key differences between NumPy-compatible and normal MXNet operators.
NumPy-specific names for operators
NumPy-compatible operators should be registered with a prefix of _np_, _npi_, or _npx_ so that they appear in the frontend under the NumPy namespace.
Here _np_ is for an operator whose backend interface you simply want to expose directly to users.
_npi_ is for an operator that should be visible only under the _internal_ namespace (usually because the operator needs an extra wrapper in Python). For operators with the `_npi_` prefix, you will need to define wrapper functions under the modules mxnet.numpy, mxnet.ndarray.numpy, and mxnet.symbol.numpy, with signatures similar (if not identical) to the APIs in the official NumPy package. See the operator mxnet.numpy.mean for an example.
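The delegation pattern such wrappers follow can be sketched as below. This is a schematic, stdlib-only illustration, not MXNet's real implementation: the `_npi` class stands in for the internal namespace, and its toy `mean` averages a flat Python list instead of dispatching to a registered backend operator.

```python
class _npi:
    """Stand-in for the internal _npi namespace of backend operators."""
    @staticmethod
    def mean(a, axis=None, keepdims=False):
        # In real MXNet this dispatches to the registered _npi_mean
        # backend operator; here we just average a flat Python list.
        return sum(a) / len(a)

def mean(a, axis=None, dtype=None, out=None, keepdims=False):
    """Frontend wrapper with a NumPy-like signature that delegates to _npi."""
    if out is not None:
        raise NotImplementedError("'out' is not supported in this sketch")
    return _npi.mean(a, axis=axis, keepdims=keepdims)

print(mean([1.0, 2.0, 3.0]))  # 2.0
```

The wrapper's job is to present the public NumPy-style signature (including arguments the backend may not accept, such as `out`) and translate it into a call on the internal operator.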
NumPy-specific code paths
All code for NumPy-compatible operators lives under src/operator/numpy, and all tests for NumPy operators go in tests/python/unittest/test_numpy_op.py.
The declaration of Python wrappers of _npi_ operators for mxnet.ndarray.numpy will be in python/mxnet/ndarray/numpy/_op.py.
The declaration of Python wrappers of _npi_ operators for mxnet.symbol.numpy will be in python/mxnet/symbol/numpy/_symbol.py.
The declaration of Python wrappers of _npi_ operators for mxnet.numpy will be in python/mxnet/numpy/multiarray.py.
Re-using the existing implementations
If an existing MXNet operator already implements the same semantics as its NumPy counterpart, you can re-use the code by simply adding an alias for that operator, with the following caveats:
NumPy-compatible InferShape functions
Previously, MXNet did not support zero-size and zero-dim tensors; the new NumPy-compatible interface must handle such cases, especially in InferShape functions. If you are re-using an existing operator, make sure you revisit its InferShape function and examine it against zero-size and zero-dim tensors.
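The kind of check involved can be sketched with a hypothetical shape-inference helper for a broadcasting binary operator. All names here are illustrative; shapes are plain tuples, where `()` denotes a zero-dim (scalar) tensor and a dimension of 0 marks a zero-size tensor.

```python
def infer_broadcast_shape(lhs, rhs):
    """Hypothetical InferShape helper that tolerates zero-dim and zero-size shapes."""
    if lhs == ():          # a zero-dim (scalar) tensor broadcasts against anything
        return rhs
    if rhs == ():
        return lhs
    ndim = max(len(lhs), len(rhs))
    lhs = (1,) * (ndim - len(lhs)) + lhs   # right-align the two shapes
    rhs = (1,) * (ndim - len(rhs)) + rhs
    out = []
    for l, r in zip(lhs, rhs):
        if l == r or r == 1:
            out.append(l)   # note: a dimension of 0 is legal and broadcasts like in NumPy
        elif l == 1:
            out.append(r)
        else:
            raise ValueError("incompatible shapes")
    return tuple(out)

print(infer_broadcast_shape((), (2, 3)))      # (2, 3)
print(infer_broadcast_shape((0, 4), (1, 4)))  # (0, 4)
```

Legacy InferShape functions often treated a 0 dimension or an empty shape as "unknown", which is exactly the assumption that breaks here.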
NumPy-compatible InferType functions
In the official NumPy, float64 is the default dtype, while in deep learning float32 is the default. MXNet should stay self-consistent wherever a dtype takes a default value: for example, np.ones((2, 2)) should return an mxnet.numpy.ndarray with dtype float32.
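The convention can be sketched as follows; the function name and the dict return value are illustrative stand-ins, not MXNet's real implementation.

```python
def ones(shape, dtype=None):
    """Sketch of the default-dtype rule: float32, not NumPy's float64."""
    if dtype is None:
        dtype = 'float32'   # deep-learning default, unlike official NumPy
    size = 1
    for dim in shape:
        size *= dim
    return {'shape': shape, 'dtype': dtype, 'data': [1.0] * size}

print(ones((2, 2))['dtype'])                    # float32
print(ones((2, 2), dtype='float64')['dtype'])   # float64
```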
A NumPy operator may behave differently with respect to input-output type relationships in certain cases. Make sure you compare the documentation of both.
NumPy-compatible FCompute functions
With the new support for zero-size and zero-dim tensors, existing FCompute functions sometimes need to be changed to accommodate those cases. There are two cases:
If the operator was originally written with the mshadow library, there is a good chance it is not compatible with zero-size and zero-dim tensors.
If the operator did not use the mshadow library but used mxnet_op::Kernel instead, you need to prevent it from launching 0 threads on GPUs, as that is not legal GPU behavior.
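The guard can be sketched in Python as below. `launch_kernel` is a stand-in for mxnet_op::Kernel's Launch; the point is that the FCompute function returns early for zero-size tensors instead of ever issuing a 0-thread launch.

```python
def launch_kernel(nthreads, body):
    """Stand-in for a GPU kernel launch; a 0-thread launch is illegal."""
    assert nthreads > 0, "0-thread launches are not legal on GPUs"
    for i in range(nthreads):
        body(i)

def fcompute_copy(in_data, out_data):
    """Toy FCompute: copy input to output, guarded for zero-size tensors."""
    if len(out_data) == 0:   # zero-size tensor: skip the launch entirely
        return
    launch_kernel(len(out_data),
                  lambda i: out_data.__setitem__(i, in_data[i]))

out = [0, 0]
fcompute_copy([7, 8], out)
print(out)             # [7, 8]
fcompute_copy([], [])  # no error thanks to the guard
```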
New implementation based on existing kernels
Sometimes a NumPy operator can be implemented as a combination of existing kernels in MXNet. For example, some NumPy operators share exactly the same computation.
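As an illustration of the composition idea, a hypothetical subtract could reuse existing "negate" and "add" kernels instead of introducing a brand new kernel. The functions below are pure-Python stand-ins; real kernels live in src/operator/numpy.

```python
def negate_kernel(x):
    """Stand-in for an existing elementwise negation kernel."""
    return [-v for v in x]

def add_kernel(a, b):
    """Stand-in for an existing elementwise addition kernel."""
    return [u + v for u, v in zip(a, b)]

def subtract(a, b):
    """New operator built purely by composing the kernels above."""
    return add_kernel(a, negate_kernel(b))

print(subtract([5.0, 3.0], [2.0, 4.0]))  # [3.0, -1.0]
```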
Brand new operators
The components you need to implement from scratch in this case are the same as those of a normal MXNet operator, and, as stated in the sections above, you need to pay attention to zero-size and zero-dim cases in all components.
Writing a test
Tests should cover:
- Gluon forward and backward
- Gluon hybridized and un-hybridized
- Coverage of scalar and zero-size tensors
- Consistency check with NumPy
- Interoperability, via a separate test in tests/python/unittest/test_numpy_interoperability.py
A template for unit tests:
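The structural sketch below uses stdlib-only stand-ins so it stays self-contained; a real MXNet test would use the @use_np decorator, mx.np arrays, a Gluon HybridBlock wrapper, backward() for gradient checks, and assert_almost_equal against official NumPy.

```python
def np_add(a, b):
    """Reference implementation (stands in for official NumPy)."""
    return [x + y for x, y in zip(a, b)]

def mx_add(a, b):
    """Candidate implementation (stands in for the new operator)."""
    return [x + y for x, y in zip(a, b)]

def test_np_add():
    sizes = [0, 1, 5]                    # include zero-size inputs
    for hybridize in [False, True]:      # hybridized and un-hybridized
        for n in sizes:
            a = [float(i) for i in range(n)]
            b = [float(2 * i) for i in range(n)]
            # In MXNet: build the block, optionally call net.hybridize(),
            # run forward and backward, and compare results (and gradients)
            # against official NumPy.
            assert mx_add(a, b) == np_add(a, b)

test_np_add()
```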
How to implement operators using TVM: https://github.com/hgt312/misc/blob/master/TVMOp%20Tutorial.ipynb