...
2. To provide the weights from MXNet (NNVM) to the TensorRT graph converter before the symbol is fully bound (i.e. before memory is allocated, etc.), the `arg_params` and `aux_params` need to be provided to the symbol's `simple_bind` method. The weights and other values (e.g. moments learned from data by batch normalization, provided via `aux_params`) will be passed via the `shared_buffer` argument to `simple_bind` as follows:

```python
executor = sym.simple_bind(ctx=ctx, data=data_shape,
                           softmax_label=sm_shape, grad_req='null',
                           shared_buffer=all_params, force_rebind=True)
```
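The `ctx`, `data_shape`, `sm_shape`, and `all_params` names used above are assumed to be defined beforehand. As a minimal sketch (the GPU context and MNIST-style shapes here are illustrative assumptions, not fixed requirements), the setup could look like:

```python
import mxnet as mx

batch_size = 128                       # illustrative choice
ctx = mx.gpu(0)                        # TensorRT inference runs on a GPU context
data_shape = (batch_size, 1, 28, 28)   # (batch, channels, height, width), e.g. MNIST
sm_shape = (batch_size,)               # one softmax label per example
# all_params is the merged parameter dictionary built in step 3 below.
```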
3. To collect `arg_params` and `aux_params` from the dictionaries loaded by `mx.model.load_checkpoint()`, we need to combine them into one dictionary:

```python
def merge_dicts(*dict_args):
    result = {}
    for dictionary in dict_args:
        result.update(dictionary)
    return result

sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch)
all_params = merge_dicts(arg_params, aux_params)
```
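Incidentally, on Python 3.5 and later the same merge can be written with dictionary unpacking; this is just an equivalent shorthand, not something the original code uses:

```python
all_params = {**arg_params, **aux_params}
```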
This `all_params` dictionary can be seen in use in the `simple_bind` call in #2.

4. Once the symbol is bound, we need to feed the data and run the `forward()` method. Let's say we're using a test set data iterator called `test_iter`. We can run inference as follows:
```python
for idx, dbatch in enumerate(test_iter):
    data = dbatch.data[0]
    executor.arg_dict["data"][:] = data
    executor.forward(is_train=False)
    preds = executor.outputs[0].asnumpy()
    top1 = np.argmax(preds, axis=1)
```
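The loop above computes `top1` but doesn't score it yet. A minimal sketch for accumulating top-1 accuracy, assuming labels travel with each batch as `dbatch.label[0]` (an assumption about the iterator, not something the tutorial specifies here), might look like:

```python
import numpy as np

correct, total = 0, 0
for dbatch in test_iter:
    executor.arg_dict["data"][:] = dbatch.data[0]
    executor.forward(is_train=False)
    preds = executor.outputs[0].asnumpy()
    top1 = np.argmax(preds, axis=1)
    labels = dbatch.label[0].asnumpy()   # assumed: labels included in the batch
    correct += int((top1 == labels).sum())
    total += labels.shape[0]
print("Top-1 accuracy: %.2f%%" % (100.0 * correct / total))
```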
5. Note: One can choose between running inference with and without TensorRT by changing the state of the `MXNET_USE_TENSORRT` environment variable. Let's first write a convenience function to toggle this environment variable:
...
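The helper's body is elided above, but given how it is used below (`set_use_tensorrt(True)` / `set_use_tensorrt(False)`), a minimal sketch reconstructed from that usage might be:

```python
import os

def set_use_tensorrt(status=False):
    # MXNet checks this environment variable to decide whether to apply
    # the TensorRT graph pass; environment variable values must be strings.
    os.environ["MXNET_USE_TENSORRT"] = str(int(status))
```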
Now, assuming that the logic to bind a symbol and run inference in batches of `batch_size` on the dataset `dataset` is wrapped in a `run_inference` function, we can do the following:

```python
print("Running inference in MXNet")
set_use_tensorrt(False)
mx_pct = run_inference(sym, arg_params, aux_params, mnist,
                       all_test_labels, batch_size=batch_size)

print("Running inference in MXNet-TensorRT")
set_use_tensorrt(True)
trt_pct = run_inference(sym, arg_params, aux_params, mnist,
                        all_test_labels, batch_size=batch_size)
```
...