...

The new Scala API design is shown below:

Code Block
languagescala
def Activation(data : Option[Symbol], act_type : String, name : Option[String], attr : Option[Map[String, String]]) : Symbol = {
    val map : mutable.Map[String, Any] = mutable.Map()
    if (!data.isEmpty) map("data") = data.get
    map("act_type") = act_type
    val currName = name.getOrElse("")
    val currAttr = attr.getOrElse(Map())
    createSymbolGeneral("Activation", currName, currAttr, Seq(), map.toMap)
}
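
As a rough illustration of how a call to such a function behaves, the sketch below swaps the real Symbol class and createSymbolGeneral for hypothetical stand-ins (StubSymbol and a recording createSymbolGeneral, neither of which is part of MXNet) so that it runs without the native backend. It spells out Map.empty[String, String] as the fallback for attr so that the element types are explicit.

```scala
import scala.collection.mutable

// Hypothetical stand-in for org.apache.mxnet.Symbol: it only records the call.
final case class StubSymbol(op: String, name: String,
                            attr: Map[String, String], kwargs: Map[String, Any])

object NewApiSketch {
  // Hypothetical stand-in for createSymbolGeneral; the real one calls the backend.
  def createSymbolGeneral(op: String, name: String, attr: Map[String, String],
                          args: Seq[StubSymbol], kwargs: Map[String, Any]): StubSymbol =
    StubSymbol(op, name, attr, kwargs)

  // Close to the function shown above, with StubSymbol in place of Symbol.
  def Activation(data: Option[StubSymbol], act_type: String,
                 name: Option[String], attr: Option[Map[String, String]]): StubSymbol = {
    val map: mutable.Map[String, Any] = mutable.Map()
    // Optional arguments are added only when the caller supplied them.
    if (!data.isEmpty) map("data") = data.get
    // Required arguments are always added.
    map("act_type") = act_type
    val currName = name.getOrElse("")
    val currAttr = attr.getOrElse(Map.empty[String, String])
    createSymbolGeneral("Activation", currName, currAttr, Seq(), map.toMap)
  }
}
```

With this shape, the compiler checks that act_type is present and is a String, while optional arguments are passed as Some(...) or None.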


This is a very simple example of how MXNet can be used from Scala. What happens behind the call can be explained with this flowchart, which shows how a Scala API is called and executed (the Python API takes a similar approach).


The current implementation of this API is very limited: all parameters have to be put into a map as key-value args and sent to the backend for processing. Users have no way of knowing which parameters each function requires. The implementation of the old API is shown below:

Code Block
languagescala
def Activation(name : String = null, attr : Map[String, String] = null)
 (args : Symbol*)(kwargs : Map[String, Any] = null) : Symbol = {
    createSymbolGeneral("Activation",name,attr,args,kwargs)
}
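
To make the limitation concrete, here is a minimal sketch of a call through a function of this shape. StubSymbol and the recording createSymbolGeneral are hypothetical stand-ins, not MXNet classes, and the misspelled key is deliberate: the call compiles, and nothing on the Scala side notices that act_type was never supplied.

```scala
// Hypothetical stand-in for org.apache.mxnet.Symbol: it only records the call.
final case class StubSymbol(op: String, name: String, kwargs: Map[String, Any])

object OldApiSketch {
  // Hypothetical stand-in for createSymbolGeneral; the real one calls the backend.
  def createSymbolGeneral(op: String, name: String, attr: Map[String, String],
                          args: Seq[StubSymbol], kwargs: Map[String, Any]): StubSymbol =
    StubSymbol(op, name, kwargs)

  // Same shape as the old API shown above.
  def Activation(name: String = null, attr: Map[String, String] = null)
                (args: StubSymbol*)(kwargs: Map[String, Any] = null): StubSymbol =
    createSymbolGeneral("Activation", name, attr, args, kwargs)

  // "act_typ" is a typo for "act_type", yet this line type-checks.
  val misspelled: StubSymbol = Activation("act1")()(Map("act_typ" -> "relu"))
}
```

Because kwargs is a Map[String, Any], both the key names and the value types are checked only when the backend processes the call.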


We are proposing a new design of the API that gives each function its own arguments, together with a full implementation on the backend side to execute them.

...


We use Scala Macros to generate the code here. The “mxSymbolGetAtomicSymbolInfo” function gives us all of the information about the functions used in Symbol and NDArray. The old Scala API takes only the name and creates a universal argument field for all of the APIs. This was a consequence of the poor code generation support in older Scala versions: we needed to write a lot of parser code to translate a function in the code generator.

New Scala API macros

To solve this problem, we studied the current Scala version closely and found a feature that simplifies the process and helps us build better functions: quasiquotes. A quasiquote is a powerful tool that automatically parses a string into the function structure as shown above. With it, the implementation of the Activation function, still following the universal argument field, can be simplified as shown below:

...

Code Block
languagescala
val funcName = symbolfunction.name
val tName = TermName(funcName)
     q"""
         def $tName(name : String = null, attr : Map[String, String] = null)
         (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
          : org.apache.mxnet.Symbol = {
           createSymbolGeneral($funcName,name,attr,args,kwargs)
           }
      """

It gives developers less code to write and a much clearer view of what code will be generated. Inspired by this method, we are creating a brand new pipeline in Macros to generate code:

...
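
For readers who have not used quasiquotes, the following minimal sketch shows the mechanism outside of a macro, using the runtime reflection universe. It assumes Scala 2 with scala-reflect on the classpath; the operator name and the function body are made up for illustration and are not what the MXNet macro emits.

```scala
import scala.reflect.runtime.universe._

object QuasiquoteSketch {
  // Hypothetical operator name; the real macro reads names from the backend.
  val funcName = "Activation"
  val tName = TermName(funcName)

  // The q"..." interpolator parses the template into a syntax tree.
  // $tName splices in a name, $funcName splices in a String literal.
  val tree: Tree =
    q"""
      def $tName(name : String = null)(kwargs : Map[String, Any] = null) : String =
        $funcName + ":" + name
    """

  def main(args: Array[String]): Unit = {
    // Render the generated definition back to source code for inspection.
    println(showCode(tree))
  }
}
```

The resulting tree is an ordinary DefDef node, so it can be inspected or combined with hand-built trees.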

As discussed previously, we need to use the API provided by C++ to generate the corresponding Scala APIs. The C++ side provides only a type description rather than an actual type, so that developers of other language bindings can translate it into the correct type. The argType is displayed in the following format:

<C++ Type>, <Required/Optional>, <Default=>

C++ to Scala type converter

Code Block
languagescala
  // Convert C++ Types to Scala Types
  def typeConversion(in : String, argType : String = "") : String = {
    in match {
      case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
      case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol"
      case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
      => "Array[org.apache.mxnet.Symbol]"
      case "float" | "real_t" | "floatorNone" => "org.apache.mxnet.Base.MXFloat"
      case "int" | "intorNone" | "int(non-negative)" => "Int"
      case "long" | "long(non-negative)" => "Long"
      case "double" | "doubleorNone" => "Double"
      case "string" => "String"
      case "boolean" => "Boolean"
      case "tupleof<float>" | "tupleof<double>" | "ptr" | "" => "Any"
      case default => throw new IllegalArgumentException(
        s"Invalid type for args: $default, $argType")
    }
  }

This conversion tool catches all of the types defined in C++ and translates each one into something that can be used in Scala. It is not fully finished: as you can see, some types are still mapped to “Any”.

Code Block
languagescala
def argumentCleaner(argType : String) : (String, Boolean) = {
  val spaceRemoved = argType.replaceAll("\\s+", "")
  var commaRemoved : Array[String] = new Array[String](0)
  // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
  if (spaceRemoved.charAt(0)== '{') {
    val endIdx = spaceRemoved.indexOf('}')
    commaRemoved = spaceRemoved.substring(endIdx + 1).split(",")
    commaRemoved(0) = "string"
  } else {
    commaRemoved = spaceRemoved.split(",")
  }
  // Optional Field
  if (commaRemoved.length >= 3) {
    // arg: Type, optional, default = Null
    require(commaRemoved(1).equals("optional"))
    require(commaRemoved(2).startsWith("default="))
    (typeConversion(commaRemoved(0), argType), true)
  } else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
    val tempType = typeConversion(commaRemoved(0), argType)
    val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
    (tempType, tempOptional)
  } else {
    throw new IllegalArgumentException(
      s"Unrecognized arg field: $argType, ${commaRemoved.length}")
  }
}


The code above performs the following translation:

...
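
As a hedged illustration of what argumentCleaner returns, the sketch below repeats its logic next to a reduced typeConversion and applies it to a few description strings. The sample strings are assumptions written in the `<C++ Type>, <Required/Optional>, <Default=>` format described earlier. They are not copied from the backend output.

```scala
object ArgTypeSketch {
  // Reduced copy of typeConversion: only the cases the samples need.
  def typeConversion(in: String, argType: String = ""): String = in match {
    case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
    case "Symbol" | "NDArray" | "NDArray-or-Symbol" => "org.apache.mxnet.Symbol"
    case "int" | "intorNone" | "int(non-negative)" => "Int"
    case "string" => "String"
    case "boolean" => "Boolean"
    case other => throw new IllegalArgumentException(
      s"Invalid type for args: $other, $argType")
  }

  // Same logic as argumentCleaner: returns (Scala type name, isOptional).
  def argumentCleaner(argType: String): (String, Boolean) = {
    val spaceRemoved = argType.replaceAll("\\s+", "")
    val fields: Array[String] =
      if (spaceRemoved.charAt(0) == '{') {
        // An enumeration such as {'a', 'b'} is treated as a String.
        val rest = spaceRemoved.substring(spaceRemoved.indexOf('}') + 1).split(",")
        rest(0) = "string"
        rest
      } else {
        spaceRemoved.split(",")
      }
    if (fields.length >= 3) {
      // Type, optional, default=...
      require(fields(1) == "optional")
      require(fields(2).startsWith("default="))
      (typeConversion(fields(0), argType), true)
    } else if (fields.length == 2 || fields.length == 1) {
      // Symbol inputs are treated as optional even without the keyword.
      val tempType = typeConversion(fields(0), argType)
      (tempType, tempType == "org.apache.mxnet.Symbol")
    } else {
      throw new IllegalArgumentException(
        s"Unrecognized arg field: $argType, ${fields.length}")
    }
  }

  // Hypothetical sample descriptions, chosen to exercise each branch.
  val samples: Seq[String] = Seq(
    "NDArray-or-Symbol",
    "{'relu', 'sigmoid', 'tanh'}, required",
    "int, optional, default='1'",
    "Shape(tuple), optional, default=[]"
  )

  def main(args: Array[String]): Unit =
    samples.foreach(s => println(s"$s => ${argumentCleaner(s)}"))
}
```

Each result pairs the Scala type name with a flag that says whether the generated parameter should be wrapped in Option.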

Once the cleaning part is done, we can start building the functions in the class. The goal is to create a series of functions that live in an existing object. However, there is no toolkit that lets us build a Scala API conveniently. The input that Scala accepts looks like this:

...

Code Block
languagescala
DefDef(functionScope, TermName(newName), List(),
        List(
          List(
            ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("name"),
              Ident(TypeName("String")), Literal(Constant(null))),
            ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("attr"),
              AST_TYPE_MAP_STRING_STRING, Literal(Constant(null)))
          ),
          List(
            ValDef(Modifiers(), TermName("args"), AST_TYPE_SYMBOL_VARARG, EmptyTree)
          ),
          List(
            ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("kwargs"),
              AST_TYPE_MAP_STRING_ANY, Literal(Constant(null)))
          )
        ), TypeTree(),
        Apply(
          Ident(TermName("createSymbolGeneral")),
          List(
            Literal(Constant(funcName)),
            Ident(TermName("name")),
            Ident(TermName("attr")),
            Ident(TermName("args")),
            Ident(TermName("kwargs"))
          )
        )
      )

Let's walk through this part of the code. “DefDef()” returns a function definition, with functionScope (the function's modifiers) as its first argument. Next comes a “TermName” that holds the name of the function, followed by a two-dimensional list that describes the ()()() parameter structure. The first inner list contains the args “name : String = null” and “attr : Map[String, String] = null”, the second contains “args : Symbol*”, and the third contains “kwargs : Map[String, Any] = null”. Finally, we define the body of the function, where Apply() is the expression that is returned. This function returns a Symbol.

def $tName(name : String = null, attr : Map[String, String] = null)
  (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
  : org.apache.mxnet.Symbol = {
    createSymbolGeneral($funcName,name,attr,args,kwargs)
}

This is what the previous code generates: a function with a complete implementation.
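
The same kind of tree can be assembled and inspected outside of a macro with the runtime reflection universe, which may help when experimenting with the structure. The sketch below is a cut-down, hypothetical variant with two parameter lists. It assumes Scala 2 with scala-reflect on the classpath. The callee name describe and the AnyRef parameter type are placeholders, not MXNet names.

```scala
import scala.reflect.runtime.universe._

object DefDefSketch {
  // One parameter of the form `paramName : <tpe> = null`.
  def defaultNullParam(paramName: String, tpe: Tree): ValDef =
    ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName(paramName),
      tpe, Literal(Constant(null)))

  // Hand-built tree for a definition with two parameter lists:
  //   def Activation(name: String = null)(kwargs: AnyRef = null) = describe(name, kwargs)
  val tree: DefDef = DefDef(
    Modifiers(),                 // no modifiers on the function itself
    TermName("Activation"),      // function name
    List(),                      // no type parameters
    List(                        // one inner list per parameter list
      List(defaultNullParam("name", Ident(TypeName("String")))),
      List(defaultNullParam("kwargs", Ident(TypeName("AnyRef"))))
    ),
    TypeTree(),                  // return type left to inference
    Apply(                       // body: a single call expression
      Ident(TermName("describe")),
      List(Ident(TermName("name")), Ident(TermName("kwargs"))))
  )

  def main(args: Array[String]): Unit = {
    // Render the hand-built definition as source code for inspection.
    println(showCode(tree))
  }
}
```

Writing the tree by hand this way makes the cost visible: every parameter needs its own ValDef, which is the boilerplate that the quasiquote form avoids.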

However, we decided to change the current implementation to a dynamic one, in which the argument list and the implementation body are generated dynamically for each function.

...

We place this API in the namespace Symbol.api.<Function name>. This namespace contains every function generated by the new Macros.

Code Block
languagescala
package org.apache.mxnet

@AddSymbolAPIs(false)
/**
  * typesafe Symbol API: Symbol.api._
  * Main code will be generated during compile time through Macros
  */
object SymbolAPI extends SymbolAPIBase {
}

API Docs generation

Due to a limitation of Macros generation, the API docs do not come as part of the generated code. To produce them, we call the same method (mxSymbolGetAtomicSymbolInfo) to obtain all of the information about each function and its arguments. Currently, this is implemented by writing the definitions to files as an abstract class, using a plugin that runs at the Macros compile stage. This uses the same C++ to Scala conversion functions to generate the function body.

Code Block
languagescala
def $tName(name : String = null, attr : Map[String, String] = null)
         (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
          : org.apache.mxnet.Symbol = {
           createSymbolGeneral($funcName,name,attr,args,kwargs)
           }