Problem Description

Currently, Clojure's NDArray and Symbol APIs are generated based on reflection of Java classes. These classes contain function definitions that are generated by the Scala package based on the C/C++ API using JNI.

There are several issues with using reflection to generate the functions within Clojure:

  1.  Function argument count and types are unclear
  2.  Documentation is hard to generate and maintain
  3.  Specs are hard to write

All of these issues pose a challenge to beginners, who may sometimes need to read the Scala/Python documentation, resort to trial and error, or in some cases read the Clojure unit tests to find out what works.

Take, for example, the ndarray activation function, whose signature is (activation & nd-array-and-params). This function is actually called as (activation ndarray "relu"), which cannot be discovered easily from the signature. The problem is further exacerbated for functions with a large number of arguments, such as convolution, whose signature is similar to that of activation.
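To make the discoverability problem concrete, the following minimal sketch defines a stand-in with the same variadic signature and inspects its metadata. The function body is hypothetical and only echoes its arguments; it is not the reflection-generated MXNet code.

```clojure
;; Hypothetical stand-in for a reflection-generated function.
;; It only echoes its arguments and does not call into MXNet.
(defn activation
  [& nd-array-and-params]
  (vec nd-array-and-params))

;; The argument list is all that tooling (doc, editors) can show the user.
;; It says nothing about how many arguments are expected or what they mean.
(def arglists (:arglists (meta #'activation)))

(println arglists)
```

Any number of arguments of any type is accepted by this signature, so mistakes only surface once the call reaches the underlying implementation.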

The Scala package also provides a new type-safe API that can potentially be used to generate new ndarray-api and symbol-api packages for Clojure. This was recently attempted in the NDArray/Symbol API PR. However, such an approach quickly becomes cumbersome, especially when applied to functions with many arguments (e.g. convolution or rnn). For example, the signature of the convolution function using the new approach becomes (convolution ndarray ndarray-1 ndarray-2 shape option option-1 option-2 num option-3 option-4 option-5 option-6 option-7 option-8 option-9). Although this approach clarifies the arguments to the function somewhat better, it has many of the same problems as before.

Proposal

Unlike Scala's focus on types, the ideal Clojure APIs should embrace the data structures available within the language, allow the definition of specs, and describe the arguments better (e.g. using better names and documentation).

Let's take the simple ndarray activation function as an example to see what the ideal state would look like:

(s/def ::ndarray ...)
(s/def ::act-type #{"relu" "sigmoid" ...})
(s/fdef activation ...)

(defn activation
  "Applies an activation function element-wise to the input.
  The following activation functions are supported:
    - `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
    - `sigmoid`: :math:`y = \\frac{1}{1 + exp(-x)}`
    - `tanh`: Hyperbolic tangent, :math:`y = \\frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
    - `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
    - `softsign`: :math:`y = \\frac{x}{1 + abs(x)}`"
  ([ndarray act-type {:keys [out] :as opts}] ...)
  ([ndarray act-type] (activation ndarray act-type {})))

instead of just being (activation & nd-array-and-params).
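As a rough illustration of how such specs might be filled in, the sketch below uses clojure.spec.alpha with a plain vector of numbers standing in for an NDArray. The spec names ::out, ::opts and ::activation-args, the toy relu body, and the use of a vector are assumptions made for this example; the set of activation types is taken from the docstring above.

```clojure
(require '[clojure.spec.alpha :as s])

;; Assumption: a vector of numbers stands in for a real NDArray.
(s/def ::ndarray (s/coll-of number? :kind vector?))
(s/def ::act-type #{"relu" "sigmoid" "tanh" "softrelu" "softsign"})
(s/def ::out ::ndarray)
(s/def ::opts (s/keys :opt-un [::out]))

;; Required arguments first, then an optional trailing options map.
(s/def ::activation-args
  (s/cat :ndarray ::ndarray
         :act-type ::act-type
         :opts (s/? ::opts)))

(defn activation
  "Toy element-wise activation. Only relu is implemented in this sketch."
  ([ndarray act-type] (activation ndarray act-type {}))
  ([ndarray act-type opts]
   (case act-type
     "relu" (mapv #(max % 0) ndarray)
     (throw (ex-info "Not implemented in this sketch" {:act-type act-type})))))

(s/fdef activation
  :args ::activation-args
  :ret ::ndarray)

(println (s/valid? ::activation-args [[1 -2] "relu"]))   ; valid call
(println (s/valid? ::activation-args [[1 -2] "rleu"]))   ; misspelled type
(println (activation [1 -2] "relu"))
```

With specs like these, a misspelled activation type can be detected with s/valid? or explained with s/explain, and clojure.spec.test.alpha/instrument could be used during development to check calls against the fdef.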

Similarly the ndarray convolution function would look like:

(defn convolution
  "Compute *N*-D convolution on *(N+2)*-D input.
    ..."
  ([data weight bias kernel num-filter
    {:keys [stride dilate pad num-group workspace
            no-bias cudnn-tune cudnn-off layout out]
     :as opts}]
   ...)
  ([data weight bias kernel num-filter]
   (convolution data weight bias kernel num-filter {})))

This also applies to the symbol API functions which currently have the following (fixed) signatures:

(function-name sym-name kwargs-map symbol-list kwargs-map-1)
(function-name sym-name attr-map kwargs-map)
(function-name sym-name kwargs-map-or-vec-or-sym)
(function-name kwargs-map-or-vec-or-sym)

Instead it would look like the following for the symbol fully-connected function:

(defn fully-connected
  "..."
  [{:keys [data weight bias num-hidden no-bias flatten name attr] :or {attr {}} :as opts}] ...)

This function is already called the same way today; the main difference is that the argument list makes the API explicit.

Scala => Clojure Translation

We should also encourage the usage of Clojure data structures as function arguments:

Scala                                       Clojure
Int | Float | Double | Boolean | String     int | float | double | bool | string
Array[Int] | Array[...] | Array[NDArray]    vec-of-ints | vec-of-... | vec-of-ndarrays
scala.collection.immutable.Map              map (i.e. {...})
org.apache.mxnet.Shape                      vec-of-ints (e.g. [2 3])

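For example, a Scala function with an optional argument that has a default value:

def function(req_arg: ..., opt_shape: Option[Shape] = Some(...), ...) {
    ...
}

would become a Clojure function that takes the optional argument as a key in a trailing options map:

(defn function
  [req-arg
   {:keys [opt-shape]
    :or {opt-shape [3 5]}
    :as opts}] ...)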
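The behaviour of the :or default can be checked with a small self-contained sketch. The function scaled-shape and its arguments are hypothetical and exist only for this illustration; the shape is represented as a vector of ints, as suggested in the translation table.

```clojure
;; Hypothetical function, not part of the MXNet API.
(defn scaled-shape
  "Multiplies every dimension of opt-shape by req-arg."
  ([req-arg] (scaled-shape req-arg {}))
  ([req-arg {:keys [opt-shape]
             :or {opt-shape [3 5]}
             :as opts}]
   (mapv #(* req-arg %) opt-shape)))

(println (scaled-shape 2))                     ; default shape [3 5] is used
(println (scaled-shape 2 {:opt-shape [2 3]}))  ; caller-supplied shape
```

One design detail worth keeping in mind: :or only applies when the key is absent from the map. An explicit {:opt-shape nil} binds opt-shape to nil, so a generated function may need to decide how nil relates to Scala's None.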
 

Potential Approach(es)

We can potentially use the Scala package's GeneratorBase to create the functions described above. The current challenge is that the generator cannot be called directly from within Clojure and hence we have to rely on reflection instead.
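As a minimal sketch of what such generation could look like on the Clojure side, the code below builds a defn form from a small description of an operator. The description map and its keys (:op-name, :docstring, :required, :optional) are invented for this example and do not reflect what GeneratorBase actually produces; the generated body is a placeholder that records the call instead of invoking a native operator.

```clojure
;; Hypothetical op description. The keys are assumptions for this sketch.
(def activation-op
  {:op-name   "activation"
   :docstring "Applies an activation function element-wise to the input."
   :required  ["data" "act-type"]
   :optional  ["out"]})

(defn op->defn-form
  "Builds a `defn` form (as plain data) from an op description map."
  [{:keys [op-name docstring required optional]}]
  (let [fn-sym   (symbol op-name)
        req-syms (mapv symbol required)
        opt-syms (mapv symbol optional)]
    `(defn ~fn-sym
       ~docstring
       ;; Short arity: delegate with an empty options map.
       (~req-syms (~fn-sym ~@req-syms {}))
       ;; Full arity: placeholder body that records the call.
       ([~@req-syms {:keys ~opt-syms :as ~'opts}]
        {:op ~op-name :args ~req-syms :opts ~'opts}))))

(def generated-form (op->defn-form activation-op))

;; Evaluating the form defines the function and returns its var.
(def generated-var (eval generated-form))

(println (:arglists (meta generated-var)))
(println (generated-var [1 -2] "relu" {:out [0 0]}))
```

A real generator would more likely emit source files or use a macro rather than eval, but the shape of the output (named required arguments followed by a destructured options map) would be the same.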

In the future, we could use JNI directly to generate the functions from the C/C++ API. However, this approach may be more time-consuming to implement and would duplicate effort that has already been expended on the excellent Scala package.
