Problem Statement

Users want to take an FP32 model and convert it to a mixed precision model to run inference on it. They want to use pretrained models from the model zoo in Python and other frontends. FP16 inference can be achieved in Gluon by casting the inputs and parameters, but mixed precision inference, where certain layers run in FP16 while others run in FP32, cannot be achieved in a trivial way. It also cannot be done easily for symbolic models (json and params). This document proposes adding APIs to convert FP32 models to mixed precision models.

There is some nice ongoing work to add automatic mixed precision (AMP) support for training to MXNet [1]. Among other things, it automatically adds cast layers for conversion to FP16 or FP32 based on the operator. Specific operator lists are maintained for ops that should always run in FP16, ops that should always run in FP32, and ops that should run in FP16 or FP32 depending on the widest type among their inputs. It also takes into account operators that should run in a specific precision only if a condition is met (for example, Activation with act_type set to softrelu).

I think we can reuse some of the ideas from AMP to add an API that converts a model to a mixed precision model, and add it under the AMP namespace. The proposal is elaborated below:

API Addition


def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dtype_ops=None,
                  fp32_ops=None, conditional_fp32_ops=None, excluded_sym_names=None):
    """API for converting an FP32 model to a mixed precision model.
    MXNet tries to convert the FP32 model to a mixed precision model by adding
    cast layers using amp_cast and amp_multicast operators. The decision on
    which cast layer to add is based on hardcoded lists for Automatic Mixed Precision
    in MXNet. These lists can be overridden by the user by providing their own lists
    using: target_dtype_ops, fp32_ops, conditional_fp32_ops

    Parameters
    ----------
    sym : str or Symbol
        Defines the structure of a neural network for FP32 types.
    arg_params : dict
        Dictionary of name to `NDArray`.
    aux_params : dict
        Dictionary of name to `NDArray`.
    target_dtype : str
        Currently only supports float16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_dtype_ops : list of strs
        Override the list of operator names casted to target_dtype.
        If None, uses the framework's default list to be casted to target dtype.
    fp32_ops : list of strs
        Override the list of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
    conditional_fp32_ops : list of (string, string, list of string)
        Override the list of operators to be casted to FP32.
        The format of the list is
        (name of the operator, name of the parameter,
         list of values of the parameter that make the operator to be casted to FP32).
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being casted to a lower precision.
    """



target_dtype should decide which lists need to be overridden.
For example, bfloat16 support may be added in the future, in which case lists for operators running in bfloat16 will also be added to AMP.
In this case, target_dtype will allow users to choose the right dtype for the mixed precision model.
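
For illustration, a call that overrides the default lists might look like the following sketch (the operator names, the symbol name "fc_output" and the "resnet18" checkpoint prefix are hypothetical examples, not a recommended configuration):

import mxnet as mx

sym, arg_params, aux_params = mx.model.load_checkpoint("resnet18", 0)
result_sym, result_arg_params, result_aux_params = mx.contrib.amp.convert_model(
    sym, arg_params, aux_params,
    target_dtype="float16",
    target_dtype_ops=["Convolution", "FullyConnected"],               # force these ops to FP16
    fp32_ops=["softmax"],                                             # keep softmax in FP32
    conditional_fp32_ops=[("Activation", "act_type", ["softrelu"])],  # FP32 only when act_type is softrelu
    excluded_sym_names=["fc_output"])                                 # hypothetical symbol to exclude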


def convert_block(block, target_dtype="float16", target_dtype_ops=None,
                  fp32_ops=None, conditional_fp32_ops=None,
                  excluded_sym_names=None, input_names=['data']):
    """Given a hybrid block/symbol block representing a neural network of data type FP32 and target_dtype,
    return a block with mixed precision support

    Parameters
    ----------
    block : HybridBlock or SymbolBlock object
        FP32 HybridBlock or SymbolBlock object
    target_dtype : str or numpy
        currently only supports float16. The target dtype indicates to add cast layers
        when possible so that lower precision computation can be leveraged.
    target_precision_ops : list of strs
        Override the list of operator names casted to target_dtype.
        If None, uses the framework's default list to be casted to target dtype.
    fp32_ops : list of strs
        Override the lists of operator names casted to FP32.
        If None, uses the framework's default list to be casted to FP32.
    conditional_fp32_ops : list of (string, string, list of string)
        Override the list of functions casted to FP32.
        The format of the list is
        (name of the function, name of the parameter,
         list of values of the parameter that make the operator to be casted to
        fp32)
    excluded_sym_names : list of strs
        A list of strings that represent the names of symbols that users want to exclude
        from being quantized.
    input_names : list of strs
        A list of strings representing the names of input variables
	"""

The user experience will be similar to today's export API experience: users will have to call hybridize, followed by one forward pass, before calling convert_block.

Backend Changes

NNVM Pass

Add an NNVM pass to the backend. This pass would use the AMP lists based on the target_dtype.
It will perform a graph traversal and add amp_cast and amp_multicast layers for FP16 and FP32 ops based on the op whitelists and excluded_sym_names. Some of the ideas have been borrowed from the quantization pass added as part of quantization support [2].

Outline of algorithm:


1. Three additional data structures are used:

  1. map from a node to a copy node in the casted graph (mirror_map)
  2. map from an input entry to the corresponding fp32 casted entry (mirror_entry_fp32_map)
  3. map from an input entry to the target dtype (e.g. fp16) casted entry (mirror_entry_target_map) (please see below fig for why 2 and 3 are needed)

Consider the below script:


import mxnet as mx

data = mx.sym.var("data")
x = mx.sym.cos(data)
x2 = mx.sym.sin(data)
x3 = mx.sym.exp(data)
x4 = mx.sym.sqrt(data)
result = x + x2 + x3 + x4
casted_result = mx.contrib.amp._convert_symbol(result, target_dtype="float16",
                                               target_dtype_ops=["sin", "cos"],
                                               fp32_ops=["exp", "sqrt"])

[Figure: converted graph for the above script, illustrating the shared amp_cast nodes]

Without mirror_entry_fp32_map and mirror_entry_target_map there would have been three cast nodes instead of two: one amp_cast to float16 (shared by sin0 and cos0) and two separate amp_cast nodes providing FP32 inputs to exp0 and sqrt0. With the maps, exp0 and sqrt0 share a single FP32 cast, so the two additional data structures help optimize the graph by sharing common cast node inputs among different nodes.


2. Visit nodes of the graph in topologically sorted order and do the following when each node is visited:

  1. Create a copy node.
  2. Clear the inputs of the copy node.
  3. Handle the node based on which whitelist (if any) it belongs to (a Python sketch of the full traversal follows the outline below):
    1. If the node is a variable:
      1. Add a mapping of the node to the copy node in mirror_map.
    2. Else if the node is not in any whitelist:
      1. Find all inputs of the original node, find the corresponding mirror entries, and add them as inputs to the copy node.
    3. Else if the node's op_name is in the FP32 whitelist:
      1. Iterate through the inputs: if an input is already in mirror_entry_fp32_map, add the mapped node to the copy node's inputs. If not, insert a cast node between the copy node and the mirror node of the input, and add a mapping from the input node to the FP32 cast node in mirror_entry_fp32_map.
    4. Else if the node's op_name is in the target dtype whitelist:
      1. Iterate through the inputs:
        1. If an input is already in mirror_entry_target_map, add the mapped node to the copy node's inputs.
        2. If not, insert a cast node between the copy node and the mirror node of the input, and add a mapping from the input node to the target dtype cast node in mirror_entry_target_map.
    5. Else if the node's op_name is in the widest_precision_ops whitelist:
      1. Add an amp_multicast between the mirrors of the node's inputs and the copy of the current node.
  4. Add a mapping from the node to the copy node in mirror_map.

3. Create a new graph using the copy nodes obtained from mirror_map. 

4. Return the newly created graph. 
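
The following is a minimal, self-contained Python sketch of the traversal described above, using a toy dict-based graph representation rather than the actual NNVM C++ data structures; the whitelists and helper names are illustrative only.

# Minimal sketch of the mixed precision graph pass (illustrative, not the NNVM implementation).
TARGET_DTYPE_OPS = {"sin", "cos"}    # illustrative FP16 whitelist
FP32_OPS = {"exp", "sqrt"}           # illustrative FP32 whitelist
WIDEST_DTYPE_OPS = {"add_n"}         # illustrative widest-dtype whitelist

def make_node(op, name, inputs):
    return {"op": op, "name": name, "inputs": list(inputs)}

def convert_graph(outputs):
    """Return copies of the output nodes with amp_cast/amp_multicast inserted."""
    mirror_map = {}               # id(original node) -> copy node
    mirror_entry_fp32_map = {}    # id(original node) -> shared FP32 cast of its output
    mirror_entry_target_map = {}  # id(original node) -> shared FP16 cast of its output

    def cast_of(orig, dtype, cache):
        # Reuse an existing cast of this input if one was already created.
        if id(orig) not in cache:
            cache[id(orig)] = make_node("amp_cast",
                                        orig["name"] + "_amp_cast_" + dtype,
                                        [mirror_map[id(orig)]])
        return cache[id(orig)]

    def visit(node):
        if id(node) in mirror_map:
            return mirror_map[id(node)]
        copies = [visit(i) for i in node["inputs"]]  # visit inputs first (topological order)
        op = node["op"]
        if op in FP32_OPS:
            new_inputs = [cast_of(i, "float32", mirror_entry_fp32_map)
                          for i in node["inputs"]]
        elif op in TARGET_DTYPE_OPS:
            new_inputs = [cast_of(i, "float16", mirror_entry_target_map)
                          for i in node["inputs"]]
        elif op in WIDEST_DTYPE_OPS:
            # Simplified: the real amp_multicast has one output per input.
            new_inputs = [make_node("amp_multicast",
                                    node["name"] + "_amp_multicast", copies)]
        else:
            # Variables and ops not in any whitelist keep their mirrored inputs.
            new_inputs = copies
        copy = make_node(op, node["name"], new_inputs)
        mirror_map[id(node)] = copy
        return copy

    return [visit(out) for out in outputs]

# Toy usage mirroring the earlier script: sin0/cos0 share one FP16 cast of "data",
# exp0/sqrt0 share one FP32 cast of "data".
data = make_node("var", "data", [])
result = make_node("elemwise_add", "sum",
                   [make_node("sin", "sin0", [data]), make_node("cos", "cos0", [data]),
                    make_node("exp", "exp0", [data]), make_node("sqrt", "sqrt0", [data])])
converted, = convert_graph([result])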

Please take a look at the PoC[3] for more details on the NNVM pass.

Example Usage


import mxnet as mx

# Simple demo model
data = mx.sym.var("data")
data2 = mx.sym.var("data2")
data3 = mx.sym.var("data3")
x = mx.sym.exp(data)
x2 = mx.sym.sin(data)
x3 = mx.sym.cos(data)
sym = x + x2 + x3
result = mx.sym.add_n(sym, data2, data3)
casted_result = mx.contrib.amp._convert_symbol(result, target_dtype="float16",
                                               target_dtype_ops=["sin", "cos", "exp"],
                                               fp32_ops=["elemwise_add"],
                                               widest_dtype_ops=["add_n"],
                                               conditional_fp32_ops=None)

Before/After Conversion:

[Figure: model graph before and after the mixed precision conversion, showing the inserted amp_cast and amp_multicast nodes]

As you can see above, the converted graph has amp_cast and amp_multicast nodes which allow for appropriate casting of the inputs.


Symbol Changes

After the mixed precision pass is done and the amp_cast and amp_multicast layers are added, the symbolic representation needs to be modified to store the right dtype attrs for some of its inputs. This requires running the InferType pass after the ReducePrecision pass and then using the obtained information to set the data types of weights and auxiliary states.

This ensures that the dtype corresponding to each param or aux input is correct, by casting the arg_params and aux_params accordingly.

Thus, the symbol returned by the convert_model API will have amp_cast and amp_multicast symbols, and the "__dtype__" attribute of the weight and aux symbols will be updated. Also, the returned arg_params and aux_params ndarrays will have the same dtype as indicated by the "__dtype__" attribute in the returned symbol.
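
As a rough illustration of the idea only (the actual implementation runs InferType as an NNVM pass in the backend and updates the "__dtype__" attributes there), the parameter casting could look like the sketch below; it assumes the model's single data input is named "data" and that type inference succeeds from that input alone.

import numpy as np

def cast_params_to_inferred_dtypes(converted_sym, arg_params, aux_params):
    # Run type inference on the converted symbol; the FP32 "data" input is enough
    # here to propagate dtypes through the inserted amp_cast/amp_multicast nodes.
    arg_types, _, aux_types = converted_sym.infer_type(data=np.float32)
    arg_dtypes = dict(zip(converted_sym.list_arguments(), arg_types))
    aux_dtypes = dict(zip(converted_sym.list_auxiliary_states(), aux_types))
    # Cast each param/aux NDArray to the dtype the converted graph expects.
    arg_params = {k: v.astype(arg_dtypes[k]) for k, v in arg_params.items()}
    aux_params = {k: v.astype(aux_dtypes[k]) for k, v in aux_params.items()}
    return arg_params, aux_params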

Gluon Changes

For Gluon code, we need to add an internal API to retrieve sym, arg_params and aux_params from a HybridBlock. Following this, convert_model can be used to convert the symbol json, model params and auxiliary params. After conversion, the symbolic model (json, arg_params, aux_params) can be imported back into Gluon with SymbolBlock.imports. The returned SymbolBlock is ready to use for inference.
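
A sketch of this flow using today's public export/import APIs (the internal retrieval API mentioned above does not exist yet, so this version goes through files on disk; the file prefixes are illustrative):

import mxnet as mx
from mxnet.gluon.model_zoo.vision import resnet50_v1

net = resnet50_v1(pretrained=True)
net.hybridize()
net(mx.nd.ones((1, 3, 224, 224)))          # one forward pass so the graph is cached
net.export("resnet50_v1_fp32", epoch=0)    # writes json + params to disk

sym, arg_params, aux_params = mx.model.load_checkpoint("resnet50_v1_fp32", 0)
sym, arg_params, aux_params = mx.contrib.amp.convert_model(
    sym, arg_params, aux_params, target_dtype="float16")

# Save the converted model and import it back as a ready-to-use SymbolBlock.
mx.model.save_checkpoint("resnet50_v1_fp16", 0, sym, arg_params, aux_params)
net_fp16 = mx.gluon.SymbolBlock.imports("resnet50_v1_fp16-symbol.json", ["data"],
                                        "resnet50_v1_fp16-0000.params")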

Frontend Bindings

Need to add amp convert_model API support for different bindings like C++, Scala etc. 

Performance

Setup

EC2 Instance: p3.8xlarge

CUDNN: 7.4.2

CUDA: 10.0

Commit Hash: b3b952f9d5490ee2707209ab866e6c3f094e2046 (PoC changes made on top of this built from source)

Mixed Precision Models:

Resnet50_v1: JSON File, Params File

imagenet1k-resnet-152: JSON File, Params File

Results

| Model                 | Batch Size | Original Model (Samples/sec) | Mixed Precision Model (Samples/sec) | Original Model with Implicit Type Conversion, MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION=1 (Samples/sec) |
|-----------------------|------------|------------------------------|-------------------------------------|-----------------------------------------------------------------------------------------------------------|
| imagenet1k-resnet-152 | 1          | 85                           | 72                                  | 72                                                                                                          |
| imagenet1k-resnet-152 | 2          | 140                          | 140                                 | 142                                                                                                         |
| imagenet1k-resnet-152 | 4          | 240                          | 270                                 | 228                                                                                                         |
| imagenet1k-resnet-152 | 8          | 320                          | 470                                 | 261                                                                                                         |
| imagenet1k-resnet-152 | 16         | 405                          | 680                                 | 315                                                                                                         |
| resnet50_v1           | 1          | 215                          | 165                                 | 205                                                                                                         |
| resnet50_v1           | 2          | 370                          | 330                                 | 365                                                                                                         |
| resnet50_v1           | 4          | 560                          | 600                                 | 545                                                                                                         |
| resnet50_v1           | 8          | 760                          | 980                                 | 635                                                                                                         |
| resnet50_v1           | 16         | 935                          | 1400                                | 790                                                                                                         |


FAQ

Will the arg_params and aux_params be casted to FP16?

Inputs of ops running in FP16 will be casted. Other params may or may not be casted, based on the type inference logic.

How is this different from casting inputs to FP16 and casting params to FP16 in Gluon?

Casting the inputs and params to FP16 in Gluon ensures that you are able to execute the model in FP16 precision. Generally, there may be some ops which need to run in FP16 while others run in FP32, for accuracy and performance considerations. This is where the AMP APIs will be useful.
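
For contrast, the plain Gluon approach referred to above (everything runs in FP16, with no per-op control) looks roughly like this sketch:

import mxnet as mx
from mxnet.gluon.model_zoo.vision import resnet50_v1

net = resnet50_v1(pretrained=True, ctx=mx.gpu())
net.cast('float16')                                              # cast all params to FP16
x = mx.nd.ones((1, 3, 224, 224), ctx=mx.gpu(), dtype='float16')  # cast the input
out = net(x)                                                     # every layer runs in FP16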

Will the dtype attribute in the serialized model change after convert_model is called?

Yes, the dtype attribute in the serialized model can change after convert_model is called. This depends on how the whitelists affect the model in question and whether the type inference decides that certain params need to be in float16.
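
One way (illustrative, and assuming the serialized JSON stores node attributes under "attrs") to inspect which nodes got a "__dtype__" attribute after conversion:

import json

graph = json.loads(result_sym.tojson())          # result_sym returned by convert_model
for node in graph["nodes"]:
    attrs = node.get("attrs", {})
    if "__dtype__" in attrs:
        print(node["name"], attrs["__dtype__"])  # dtype stored as a type flag string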

Is there a need for hybridizing and running a forward pass on the AMP converted Gluon model?

No, there is no need to hybridize, since convert_block returns a SymbolBlock which is already hybridized.

What changes need to be made to an existing script to convert a model and run mixed precision inference?

Adding a call to amp.convert_model or amp.convert_block should be sufficient to convert a model and run inference on the resulting mixed precision model. Below are two user experience examples:

Module API

import mxnet as mx

sym, arg_params, aux_params = mx.model.load_checkpoint("resnet18", 0)

# Additional line below to convert to a mixed precision model. Everything else remains the same
result_sym, arg_params, aux_params = mx.contrib.amp.convert_model(sym, arg_params, aux_params, target_dtype="float16")

mod = mx.mod.Module(result_sym, data_names=['data'], label_names=None, context=mx.cpu())
mod.bind(data_shapes=[('data', (1, 3, 224, 224))])
mod.set_params(arg_params, aux_params)
mod.forward(mx.io.DataBatch(data=[mx.nd.ones((1, 3, 224, 224))], label=None))
result = mod.get_outputs()[0].asnumpy()

Gluon API

import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model

net = get_model(name="resnet50_v1", classes=1000, pretrained=True)
net.hybridize()
x = mx.nd.random.uniform(0, 1, shape=(1, 3, 224, 224))
out = net(x)

# Additional line below to convert to a mixed precision model. Everything else remains the same
net = mx.contrib.amp.convert_block(net, target_dtype="float16")

out = net(x)

References

  1. https://github.com/apache/incubator-mxnet/pull/14173
  2. https://github.com/apache/incubator-mxnet/pull/9552
  3. https://github.com/apache/incubator-mxnet/pull/14702

2 Comments

  1. Good proposal! Could you list which OPs are FP16 supported now?



  2. Hi Patric, thanks for your feedback! Do you mean which ops are FP16 supported for the AMP whitelist, or which of all the ops in MXNet are FP16 supported?

    If you are asking which ops can be included in the FP16 whitelist, it is any op which supports the FP16 dtype. If your question is which ops support the FP16 dtype, that may be a bigger question. By default most ops should support this dtype. Having said that, FP16 operations are very slow on CPU. Also, there may have been exceptions created for FP16 dtype support, for example topk and layernorm. Since finding which operators have these exceptions would require looking into each op implementation, I think we can do this based on customer feedback on which ops don't support FP16.

    Currently, from https://github.com/apache/incubator-mxnet/issues?utf8=%E2%9C%93&q=is%3Aissue+FP16+Support+label%3AFP16+ I only find requests for layernorm and topk.