Problem Description

Currently, Clojure's NDArray and Symbol APIs are generated based on reflection of Java classes. These classes contain function definitions that are generated by the Scala package based on the C/C++ API using JNI.

There are several issues with using reflection to generate the functions within Clojure:

  1.  Function argument count and types are unclear
  2.  Documentation is hard to generate and maintain
  3.  Specs are hard to write

All these issues pose a challenge to beginners, who may sometimes need to read the Scala or Python documentation, resort to trial and error, or in some cases read the Clojure unit tests to find out what works.

Take, for example, the ndarray activation function, whose signature is (activation & nd-array-and-params). This function is actually called as (activation ndarray "relu"), which cannot be discovered easily from the signature. The problem is further exacerbated for functions with a large number of arguments, such as convolution, whose signature is similar to that of activation.
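To make the discoverability problem concrete, here is a minimal sketch. The function activation-variadic is a hypothetical stand-in (it is not the generated MXNet function) that shares only the argument vector of the real one; the metadata Clojure exposes for it says nothing about how many arguments are expected or what they mean.

```clojure
;; Hypothetical stand-in that mimics only the argument vector of the
;; reflection-generated function; it does no real computation.
(defn activation-variadic
  [& nd-array-and-params]
  (vec nd-array-and-params))

;; This is all that `doc` or an editor can show for the function.
(def variadic-arglists
  (:arglists (meta #'activation-variadic)))

;; Any argument count is accepted, so a wrong call is not rejected
;; by the signature itself.
(def accepted-anyway
  (activation-variadic "relu"))
```

Here variadic-arglists holds only the single variadic argument vector, and the call with a missing ndarray goes through without complaint.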

The Scala package also provides a new typesafe API that can potentially be used to generate new ndarray-api and symbol-api packages for Clojure. This was recently attempted in the NDArray/Symbol API PR. However, such an approach quickly becomes cumbersome, especially when applied to functions with a lot of arguments (e.g. convolution or rnn). For example, the signature of the convolution function using the new approach becomes (convolution ndarray ndarray-1 ndarray-2 shape option option-1 option-2 num option-3 option-4 option-5 option-6 option-7 option-8 option-9). Although this approach clarifies the arguments to the function somewhat better, it has many of the same problems as before.


Unlike Scala's focus on types, the ideal Clojure APIs should embrace the data structures available within the language, allow the definition of specs, and describe the arguments better (e.g. using better names and documentation).

Let's take the simple ndarray activation function as an example to see what the ideal state would look like:

(s/def ::ndarray #(instance? NDArray %))
(s/def ::act-type #{"relu" "sigmoid" ...})
(s/fdef activation ...)

(defn activation
  "Applies an activation function element-wise to the input.
  The following activation functions are supported:
    - `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
    - `sigmoid`: :math:`y = \\frac{1}{1 + exp(-x)}`
    - `tanh`: Hyperbolic tangent, :math:`y = \\frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
    - `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
    - `softsign`: :math:`y = \\frac{x}{1 + abs(x)}`"
  ([ndarray act-type {:keys [out] :as opts}] ...)
  ([ndarray act-type] (activation ndarray act-type {})))

instead of just being (activation & nd-array-and-params).
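As a rough illustration of what the spec-backed version buys us, the following sketch fills in the elided parts with stand-ins. It assumes a plain vector of numbers in place of an MXNet NDArray so that it runs without the native library, and the name activation-sketch, the ::opts spec and the explicit validation loop are choices made for this example only.

```clojure
(require '[clojure.spec.alpha :as s])

;; Assumption: a vector of numbers stands in for an MXNet NDArray so the
;; sketch runs without the native library.
(s/def ::ndarray (s/coll-of number? :kind vector?))
(s/def ::act-type #{"relu" "sigmoid" "tanh" "softrelu" "softsign"})
(s/def ::opts map?)

(defn activation-sketch
  "Applies an activation function element-wise to the input (sketch)."
  ([ndarray act-type]
   (activation-sketch ndarray act-type {}))
  ([ndarray act-type opts]
   ;; Reject bad arguments up front with a readable explanation.
   (doseq [[spec v] [[::ndarray ndarray] [::act-type act-type] [::opts opts]]]
     (when-not (s/valid? spec v)
       (throw (ex-info (s/explain-str spec v) {:spec spec :value v}))))
   ;; The :out option is accepted but ignored in this sketch.
   (let [f (case act-type
             "relu"     #(max % 0)
             "sigmoid"  #(/ 1.0 (+ 1.0 (Math/exp (- (double %)))))
             "tanh"     #(Math/tanh (double %))
             "softrelu" #(Math/log (+ 1.0 (Math/exp (double %))))
             "softsign" #(/ % (+ 1.0 (Math/abs (double %)))))]
     (mapv f ndarray))))

;; The function spec documents both arities in one place.
(s/fdef activation-sketch
  :args (s/cat :ndarray ::ndarray
               :act-type ::act-type
               :opts (s/? ::opts)))
```

With this shape, a misspelled activation type such as "rleu" fails immediately with a spec explanation, rather than surfacing as an error from a lower layer.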

Similarly the ndarray convolution function would look like:

(defn convolution
  "Compute *N*-D convolution on *(N+2)*-D input."
  ([data weight bias kernel num-filter
    {:keys [stride dilate pad num-group workspace
            no-bias cudnn-tune cudnn-off layout out]
     :as opts}] ...)
  ([data weight bias kernel num-filter]
   (convolution data weight bias kernel num-filter {})))
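One practical benefit of naming the options explicitly is that a mistyped key can be caught at the call site. The sketch below is only an illustration of that idea: convolution-call is a hypothetical stand-in that returns a description of the call instead of running a convolution, and rejecting unknown keys is a design choice assumed here, not something the proposal mandates.

```clojure
;; Option keys taken from the convolution signature above.
(def convolution-opt-keys
  #{:stride :dilate :pad :num-group :workspace
    :no-bias :cudnn-tune :cudnn-off :layout :out})

(defn convolution-call
  "Hypothetical stand-in: returns a description of the call instead of
   running a convolution."
  ([data weight bias kernel num-filter]
   (convolution-call data weight bias kernel num-filter {}))
  ([data weight bias kernel num-filter opts]
   (let [unknown (remove convolution-opt-keys (keys opts))]
     ;; Fail early on keys that are not part of the signature.
     (when (seq unknown)
       (throw (ex-info "Unknown convolution option(s)"
                       {:unknown (vec unknown)})))
     {:required {:data data :weight weight :bias bias
                 :kernel kernel :num-filter num-filter}
      :opts opts})))
```

A call then reads as (convolution-call data weight bias [3 3] 64 {:stride [1 1] :pad [1 1]}), and a typo such as :strde is reported by name.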

This also applies to the symbol API functions, which currently have the following (fixed) signatures:

(function-name sym-name kwargs-map symbol-list kwargs-map-1)
(function-name sym-name attr-map kwargs-map)
(function-name sym-name kwargs-map-or-vec-or-sym)
(function-name kwargs-map-or-vec-or-sym)

Instead it would look like the following for the symbol fully-connected function:

(defn fully-connected
  [{:keys [data weight bias num-hidden no-bias flatten name attr] :or {attr {}} :as opts}] ...)

Although this function is called the same way today, the main difference is that the API is more explicit.

Scala => Clojure Translation

We should also encourage the usage of Clojure data structures as function arguments:

Scala type                                    Clojure equivalent
Int | Float | Double | Boolean | String       int | float | double | bool | string
Array[Int] | Array[...] | Array[NDArray]      vec-of-ints | vec-of-... | vec-of-ndarrays
scala.collection.immutable.Map                map (i.e. {...})
org.apache.mxnet.Shape                        vec-of-ints (e.g. [2 3])

Scala:

function(req_arg: ..., opt_shape: Option[Shape] = Some(...), ...) { ... }

Clojure:

(defn function
  [req-arg
   {:keys [opt-shape]
    :or {opt-shape [3 5]}
    :as opts}] ...)
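The behaviour of :or is worth spelling out, since it differs slightly from a Scala default argument. The following sketch (function-sketch is a name made up for this example) returns its bindings so that they can be inspected.

```clojure
(defn function-sketch
  [req-arg {:keys [opt-shape]
            :or {opt-shape [3 5]}
            :as opts}]
  {:req-arg req-arg
   :opt-shape opt-shape   ; default applied only when the key is absent
   :opts opts})           ; the map exactly as the caller passed it
```

Two details follow from how Clojure destructuring works: the default is applied only to the opt-shape binding and not to the opts map, and an explicit nil value for :opt-shape is kept as nil rather than replaced by the default. Generated code that forwards opts to the backend would need to take both into account.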

Potential Approach(es)

We can potentially use the Scala package's GeneratorBase to create the functions described above. The current challenge is that the generator cannot be called directly from within Clojure and hence we have to rely on reflection instead.

In the future, we can directly use JNI to generate the functions based on the C/C++ API. However, this approach may be more time-consuming to implement and would duplicate effort that has already been expended on the excellent Scala package.
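Whichever source the operator metadata comes from, the generation step itself could be small. The sketch below assumes the metadata has already been obtained and reshaped into a map with :op-name, :doc, :required and :optional keys; that map shape, the kebab helper and op->defn-form are all hypothetical names for this example, and the generated body is a stand-in that returns its arguments instead of calling the backend.

```clojure
(require '[clojure.string :as string])

(defn kebab
  "Converts names such as \"FullyConnected\" or \"num_hidden\" to kebab-case."
  [s]
  (-> s
      (string/replace #"([a-z0-9])([A-Z])" "$1-$2")
      (string/replace "_" "-")
      string/lower-case))

(defn op->defn-form
  "Builds a defn form from a hypothetical operator description map."
  [{:keys [op-name doc required optional]}]
  (let [fn-sym (symbol (kebab op-name))
        req    (mapv (comp symbol kebab) required)
        opt    (mapv (comp symbol kebab) optional)]
    `(defn ~fn-sym
       ~doc
       (~req (~fn-sym ~@req {}))
       ([~@req {:keys ~opt :as ~'opts}]
        ;; Stand-in body: a real generator would call into the backend here.
        {:op ~op-name :args ~req :opts ~'opts}))))

;; Hypothetical metadata for one operator.
(def activation-meta
  {:op-name  "Activation"
   :doc      "Applies an activation function element-wise to the input."
   :required ["data" "act_type"]
   :optional ["out"]})

;; eval returns the var created by the generated defn.
(def generated-activation (eval (op->defn-form activation-meta)))
```

The generated function carries explicit arglists and a docstring, which is exactly the information that the reflection-based approach loses.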

  1. Kedar Bellare I think this proposal looks great and is a good step forward. I think the next step is to spike out what it would look like using the lower level functions to LIB. I don't think we will need to have the Scala package modify anything for use since we can get the LibInfo (that can do calls for us across the JNI bindings) from the Base. Here is an example of getting all the operator names:

    Scala snippet from GeneratorBase's get BackendFunctions:

    val opNames = ListBuffer.empty[String]
    _LIB.mxListAllOpNames(opNames)
    opNames.map(opName => {
      val opHandle = new RefLong
      _LIB.nnGetOpHandle(opName, opHandle)
      ...
    })
    And Clojure interop to just get the function names for now:

    (import '(org.apache.mxnet Base))
    (import '(scala.collection.mutable ListBuffer))
    (require '[org.apache.clojure-mxnet.util :as util])
    (require '[t6.from-scala.core :refer [$]]) ; $ comes from the from-scala interop library
    (def libinfo (Base/_LIB))
    (def mylist ($ ListBuffer/empty))
    (do (.mxListAllOpNames libinfo mylist))
    (take 5 (util/buffer->vec mylist))
    ;=> ("Activation" "BatchNorm" "BatchNorm_v1" "BilinearSampler" "BlockGrad")

    We will need to figure out all the interop code to get the parameters too, but I **think** it's possible. Happy to help work with you to figure this part out.

  2. Carin Meier sweet! i was looking for `Base` within `org.apache.mxnet.init` (because of [...]) but couldn't find it. it is much more straightforward once i have `libinfo` (i just checked that this works).