Note: an extended design for callbacks is available here.

Problem and Goals


Training a model in Gluon requires users to write the training loop themselves. This gives full control thanks to Gluon's imperative nature, but repeating the same loop across multiple models becomes tedious boilerplate. The training loop can also be overwhelming for users who are new to deep learning.
Users have asked for a simple Fit API, similar to the APIs available in scikit-learn and Keras (example forum ask), as a way to simplify model training and reduce boilerplate code and complexity.

Target Users

  1. Beginners who are new to deep learning and/or to Gluon.
  2. Applied scientists and deep learning practitioners training models with low complexity and non-custom requirements.

Goals and Deliverables

  • Introduce a new Gluon “Fit” API that eliminates the need to code a training loop for simple model use cases, thereby reducing manual errors and friction.
  • Support Fit API handlers, inspired by Keras Callbacks, that let users customize the training loop for tasks such as checkpointing, logging, early stopping and metrics.
  • Maintain backwards compatibility: the existing Gluon way of training a model will remain supported and maintained, since it is needed for complex models and for full imperative control by the user.
  • The new Fit API will cover beginner use cases, including canonical CV and NLP models; the full list is in Appendix B. For advanced users and complex models, the recommended path is to keep using the existing training loop.
  • Test coverage: 100% unit test coverage and 100% integration test coverage for the example models in Appendix B.
  • Educate Gluon users via: (1) a blog post, (2) an example, (3) a tutorial.

Existing user experience

Because of Gluon's imperative programming style, users currently write the entire training loop, which involves multiple steps. See Appendix A for a code example.
Writing the custom loop involves:

  1. Set up the train/validation data
  2. Do a forward pass on batch of data
  3. Calculate loss
  4. Calculate gradients after back-propagation
  5. Update weights
  6. Evaluate metrics
  7. Update logs
  8. Save models

The above steps are generic and remain the same across many models. This is repetitive work, and we can simplify it by providing a simple “fit” API that caters for 80% of modeling use cases, especially for novice users. We prefer that advanced users continue to write their own custom training loops, which give them more flexibility.

Proposed Approach

We propose to add a simple “fit” API to Gluon by offering a new Estimator class that includes a fit method, similar to the scikit-learn classifier API. The Gluon Estimator will hold the details of model training, such as training statistics, the network being trained, and event handlers.
We will provide an EventHandler base class that exposes methods for the stages of the training loop, namely train_begin, train_end, epoch_begin, epoch_end, batch_begin and batch_end, which users can override to customize each stage of training. We will also provide default event handlers for common actions such as logging, metrics, early stopping and checkpointing.
The “fit” method will therefore run the forward pass, calculate the loss and gradients, update the weights, log, checkpoint and update metrics.
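The order in which the six stages fire can be illustrated with a small, framework-free sketch. It is not the proposed implementation: `run_training`, `RecordingHandler`, and the plain Python list standing in for a DataLoader are assumptions made only for illustration, and the actual training step is left as a comment.

```python
class EventHandler:
    """Base class: every stage is a no-op unless a subclass overrides it."""

    def train_begin(self): pass
    def train_end(self): pass
    def epoch_begin(self): pass
    def epoch_end(self): pass
    def batch_begin(self): pass
    def batch_end(self): pass


class RecordingHandler(EventHandler):
    """Hypothetical handler that records only the events it overrides."""

    def __init__(self):
        self.events = []

    def train_begin(self):
        self.events.append("train_begin")

    def epoch_end(self):
        self.events.append("epoch_end")

    def train_end(self):
        self.events.append("train_end")


def run_training(batches, epochs, handlers):
    """Hypothetical stand-in for fit(): only the dispatch order is shown."""
    for h in handlers:
        h.train_begin()
    for _ in range(epochs):
        for h in handlers:
            h.epoch_begin()
        for _batch in batches:
            for h in handlers:
                h.batch_begin()
            # forward pass, loss, backward pass and weight update would go here
            for h in handlers:
                h.batch_end()
        for h in handlers:
            h.epoch_end()
    for h in handlers:
        h.train_end()


recorder = RecordingHandler()
run_training(batches=[0, 1, 2], epochs=2, handlers=[recorder])
print(recorder.events)
```

A handler that overrides only a few stages still receives calls for all six; the inherited no-op methods absorb the rest.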

New user experience

The new API described below reduces the number of lines of code the user has to write. In cases where the user implements logging and checkpointing, the number of lines drops from ~40 to ~6.

Below is an example for the Fit API implementing similar functionality to the one using the existing training loop in Appendix A.

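from mxnet import gluon
from mxnet.gluon import estimator as est  # module proposed by this design
net = get_model()  # get the network
loss = gluon.loss.SoftmaxCrossEntropyLoss()
e = est.Estimator(net, lossfn=loss)
# training
trainers = [gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})]
e.fit(train_data, val_data, epochs, trainers, context)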

Error Handling

We will provide the following checks to make the fit method robust and easy to debug. The aim is to tell users exactly what is wrong through clear error messages.

  1. Check required and optional arguments, and notify users when default values are used.
  2. Check the number of inputs from the DataLoader against the number of losses.
  3. Check the number of outputs from the network against the number of inputs expected by the losses and metrics.

All error handling will be covered in unit tests.
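As a rough illustration of the kind of check listed above, the following sketch validates two of the arguments before training starts. The helper name `check_fit_arguments`, its parameters, the default of one epoch, and the message wording are all assumptions for this example and are not part of the proposed API.

```python
import warnings


def check_fit_arguments(num_net_outputs, losses, epochs=None):
    """Hypothetical pre-flight check; returns the epoch count that will be used."""
    if not losses:
        raise ValueError("fit() needs at least one loss function, but none was given.")
    if num_net_outputs != len(losses):
        raise ValueError(
            "The network returns %d output(s) but %d loss function(s) were given; "
            "pass exactly one loss per output." % (num_net_outputs, len(losses)))
    if epochs is None:
        # Notify the user that a default value is being used.
        warnings.warn("epochs was not given; falling back to the default of 1.")
        return 1
    if not isinstance(epochs, int) or epochs < 1:
        raise ValueError("epochs must be a positive integer, got %r." % (epochs,))
    return epochs


try:
    check_fit_arguments(num_net_outputs=2, losses=["softmax_ce"])
except ValueError as err:
    message = str(err)
    print(message)
```

Raising before the first batch is processed keeps the error close to the user's call site, which is the debugging experience this section aims for.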

Estimator class

Estimator is the new class that will encapsulate training and expose a Fit method.

class Estimator:
    def __init__(self, net, lossfn=gluon.loss.SoftmaxCrossEntropyLoss(),
                 metrics=[mx.metric.TopKAccuracy(), mx.metric.RMSE()]):
        self._train_stats = {"lr": [], "epoch": [], "train_metric1": [], "val_metric1": [],
                             "train_loss1": [], "val_loss1": [], "time": []}  # ...
        self._net = net
        self._loss = lossfn
        self._metrics = metrics
        self._loggingHandler = LoggingHandler(self)
        self._checkpointHandler = CheckpointHandler(self)
        self._metricHandler = MetricHandler(self)
        self._additionalHandlers = []  # can be the custom event handlers

    def fit(self, train_data_loader, val_data_loader, epochs, trainers, context):
        # see "Fit - Proposed implementation" in Appendix A
        pass

    @property
    def train_stats(self):
        return self._train_stats

    @property
    def metrics(self):
        return self._metrics

    @property
    def additionalHandlers(self):
        return self._additionalHandlers

    @property
    def loggingHandler(self):
        return self._loggingHandler

    @property
    def checkpointHandler(self):
        return self._checkpointHandler

    @property
    def metricHandler(self):
        return self._metricHandler

    # Loss fn should take predictions and labels as input and return a scalar loss
    @property
    def loss(self):
        return self._loss

    def plot_loss_graph(self):
        # plan to support MXBoard in next phase
        pass

EventHandler class

EventHandler is a new class that lets users customize the training process by offering callbacks for the different stages of training.
We will implement standard handlers for logging, checkpointing and more. Users can extend EventHandler to implement their own custom handlers.

class EventHandler:
    def __init__(self, estimator):
        self._train_stats = estimator.train_stats

    def train_begin(self): pass
    def train_end(self): pass
    def batch_begin(self): pass
    def batch_end(self): pass
    def epoch_begin(self): pass
    def epoch_end(self): pass

class LoggingHandler(EventHandler):
    def __init__(self, estimator, log_loc='./'):
        super().__init__(estimator)
        # set up logging
    def epoch_end(self):
        # log the train stats to the log location
        pass

class CheckpointHandler(EventHandler):
    def __init__(self, estimator, checkpoint_interval=5, ckpt_loc='./', monitor="val_loss"):
        super().__init__(estimator)
        # train_stats example: {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}
    def epoch_end(self):
        # save the model params to the checkpointing location
        pass

class MetricHandler(EventHandler):
    def __init__(self, estimator):
        super().__init__(estimator)
        # train_stats example: {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}
    def epoch_end(self):
        # calculate and update metrics for the training dataset
        # update_metrics(pred, labels): the default implementation can be overridden for multi-output cases
        # update validation metrics for the validation dataset
        pass

class EarlyStopping(EventHandler):
    def __init__(self, monitor="val_loss", min_delta=0, patience=0, mode="auto", baseline=None, restore_best_params=False):
        # set up early stopping rules based on the monitored metric/loss and the mode
        # e.g. if "acc" use greater mode else use lesser
        pass
    def epoch_end(self):
        # if the metric improved, record the best value
        # else wait n epochs (n=patience) and stop training
        # restore net parameters from the best epoch accordingly
        pass
    def train_end(self):
        # let the user know if early stopping was triggered
        pass
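The early stopping rule outlined in the comments above can be sketched without any framework dependency. This is a minimal illustration under stated assumptions: the class name `EarlyStoppingSketch`, passing the monitored value directly to `epoch_end`, and stopping once the count of consecutive non-improving epochs reaches `patience` are choices made for this example. The `baseline` and `restore_best_params` options are omitted.

```python
class EarlyStoppingSketch:
    """Hypothetical stand-alone version of the EarlyStopping rules."""

    def __init__(self, monitor="val_loss", min_delta=0.0, patience=0, mode="auto"):
        if mode == "auto":
            # Accuracy-like metrics should increase; everything else should decrease.
            mode = "max" if "acc" in monitor else "min"
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode
        self.best = None
        self.wait = 0
        self.stopped_epoch = None

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "max":
            return value > self.best + self.min_delta
        return value < self.best - self.min_delta

    def epoch_end(self, epoch, value):
        """Record the monitored value; return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
            return False
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            return True
        return False


stopper = EarlyStoppingSketch(monitor="val_loss", patience=2)
val_losses = [1.0, 0.8, 0.81, 0.79, 0.80, 0.80, 0.70]
for epoch, val_loss in enumerate(val_losses):
    if stopper.epoch_end(epoch, val_loss):
        break
print(stopper.stopped_epoch, stopper.best)
```

In this run the loss fails to improve on 0.79 for two consecutive epochs, so training stops at epoch 5 and the final value 0.70 is never seen.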

Addition of New APIs

The design adds two new classes, Estimator and EventHandler, which together solve the problem at hand.

Backward compatibility

The design doesn't alter any existing APIs and so the design is backward compatible.

Performance Considerations

The proposed APIs substitute for hand-written training loops in Gluon and should not introduce any performance regression.
We can add tests that train the same model with a hand-written training loop and with the new APIs, and compare the training times to make sure that there is no regression.

Test Plan

We will implement 100% unit test coverage for the new API and event handlers.

  1. A test with all possible parameters
  2. A test with default parameters
  3. A test with missing parameters

We will add integration tests covering all of the models in the release goals (Appendix). See comment section for integration test plan.

Technical Challenges / Open Questions

  • Using a metric update function instead of MetricEventHandler:

MetricEventHandler needs access to the net outputs and the labels, both of which are available in the Estimator class, in order to calculate and update the metrics. This EventHandler is important when dealing with multiple outputs or multiple metrics, because the logic that associates outputs with metrics lives here. As an alternative, we can add a metric_update function to the estimator that users can customize.

  • ##sample estimator class
    class Estimator:
        def metric_updatefn(self):
            ## Metric update fn should take predictions and labels as
            ## input and wrap the logic of how to update metrics in case
            ## of multi-output/ special cases.
            return self._metricupdate_fn

  • The APIs should cover use cases such as multi-task learning and SSD. This can be done in two ways:
    • Make the general fit API flexible enough to accommodate all use cases (multi-output, multi-loss, multi-metric). The disadvantage is that this much flexibility makes the API overwhelming: users would have to override many handlers, such as MetricHandler and the loss function, which hold most of the logic that maps network outputs to metrics and losses.
    • Provide custom metric and loss functions for use cases such as object detection, multi-task learning and neural machine translation that can be used off the shelf. GluonCV already has some task-specific loss functions, but they do not have uniform signatures, so we would end up duplicating those APIs to fit our use case.
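To make the metric update function alternative more concrete, here is a minimal sketch of a default update function and a multi-output override. The tiny `Accuracy` class is a stand-in written for this example (it is not `mx.metric.Accuracy`), and the function names and the one-metric-per-output pairing are assumptions rather than a settled design.

```python
class Accuracy:
    """Tiny stand-in metric used only for this sketch."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, labels, preds):
        for y, p in zip(labels, preds):
            self.correct += int(y == p)
            self.total += 1

    def get(self):
        return self.correct / self.total if self.total else float("nan")


def default_metric_update(metrics, labels, outputs):
    """Default case: every metric sees the single output and the single label."""
    for m in metrics:
        m.update(labels, outputs)


def multitask_metric_update(metrics, labels, outputs):
    """Hypothetical override: pair metric i with label i and output i."""
    if not (len(metrics) == len(labels) == len(outputs)):
        raise ValueError("need exactly one metric per (label, output) pair")
    for m, y, out in zip(metrics, labels, outputs):
        m.update(y, out)


metrics = [Accuracy(), Accuracy()]
labels = [[0, 1, 1], [2, 2, 0]]    # one label list per task
outputs = [[0, 1, 0], [2, 2, 0]]   # one prediction list per task
multitask_metric_update(metrics, labels, outputs)
print([m.get() for m in metrics])
```

With this shape, the estimator would call whichever update function the user supplied, so the output-to-metric mapping stays in one place instead of inside a handler subclass.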


  1. MXNet Module APIs
  2. TensorFlow Estimators
  3. Keras Model API (fit, predict, evaluate)
  4. scikit-learn
  5. Torchbearer is a framework for easy fit, evaluate and predict on PyTorch. It is implemented using state objects that hold the state of all the callback functions. The training parameters (loss, optimizer, metric) are passed as callbacks, and the association between states and callbacks is what makes the fit method work. There are similar frameworks for PyTorch, such as skorch, PyToune, ignite, TorchNetTwo (TNT) and Inferno. Skorch uses scikit-learn internally, and we do not want to add additional dependencies.

APPENDIX A - Gluon Training Loop Example

Current user experience

## Current training loop in Gluon
import logging
import time
from mxnet import autograd as ag, gluon

# Only one epoch
num_epochs = 1

trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
L = gluon.loss.SoftmaxCrossEntropyLoss()
best_val_score = 1

for epoch in range(num_epochs):
    tic = time.time()
    btic = time.time()

    for i, batch in enumerate(train_data):
        data, label = batch_fn(batch, ctx)

        with ag.record():
            outputs = [net(X) for X in data]
            loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
        for l in loss:
            l.backward()
        lr_scheduler.update(i, epoch)
        trainer.step(batch_size)

        train_metric.update(label, outputs)

        if log_interval and not (i+1)%log_interval:
            train_metric_name, train_metric_score = train_metric.get()
            logging.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f'%(
                epoch, i, batch_size*log_interval/(time.time()-btic),
                train_metric_name, train_metric_score, trainer.learning_rate))
            btic = time.time()

    train_metric_name, train_metric_score = train_metric.get()
    throughput = int(batch_size * i /(time.time() - tic))

    err_top1_val, err_top5_val = test(ctx, val_data)

    logging.info('[Epoch %d] training: %s=%f'%(epoch, train_metric_name, train_metric_score))
    logging.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f'%(epoch, throughput, time.time()-tic))
    logging.info('[Epoch %d] validation: err-top1=%f err-top5=%f'%(epoch, err_top1_val, err_top5_val))

    if err_top1_val < best_val_score:
        best_val_score = err_top1_val
        net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))
        trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states'%(save_dir, best_val_score, model_name, epoch))

    if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
        net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, epoch))
        trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, epoch))

if save_frequency and save_dir:
    net.save_parameters('%s/imagenet-%s-%d.params'%(save_dir, model_name, opt.num_epochs-1))
    trainer.save_states('%s/imagenet-%s-%d.states'%(save_dir, model_name, opt.num_epochs-1))

Fit - Proposed implementation

##sample fit function (pseudocode)

def fit(net, train_data_loader, val_data_loader, epochs, loss_fn, trainers, context):
    EventHandlers = [self.loggingHandler, self.checkpointHandler, self.metricHandler]
    EventHandlers = EventHandlers + self.additionalHandlers

    for handler in EventHandlers:
        handler.train_begin()
    while not exit_condition():  ## one iteration per epoch
        for handler in EventHandlers:
            handler.epoch_begin()
        for each batch:
            for handler in EventHandlers:
                handler.batch_begin()
            ##do a split and load for multigpu
            x, y = split_and_load(train_data_loader)
            y_hat = net(x)  ## forward pass
            calculate loss using loss_fn
            backward pass
            update weights using trainers
            for handler in EventHandlers:
                handler.batch_end()
        for handler in EventHandlers:
            handler.epoch_end()

    for handler in EventHandlers:
        handler.train_end()

APPENDIX B - Supported Models

By supporting the following models, we believe we can cover most basic use cases for Gluon users.

Domain | Category | Model | Reference | Feature Required | Note
CV | Image Classification | AlexNet | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | mlp, lenet, vgg are similar; example: train_ch5()
CV | Image Augmentation + Classification | ResNet18 | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | example: train_ch5()
CV | Semantic Segmentation | FCN | Gluon Book | more data_transformation, multi-gpu | example: train()
CV | Object Detection | SSD | Gluon Book | multiple labels, losses, and metrics | training script from Gluon CV
NLP | Text Sentiment Classification | BiRNN | Gluon Book | same as rows 1 and 2 | example: train()
NLP | Text Sentiment Classification | TextCNN | Gluon Book | same as rows 1 and 2 | example: train()
NLP | Neural Machine Translation | encoder-decoder and attention mechanism | Gluon Book | multiple trainers, different inputs for loss |
Various | Various | LR | Kaggle Blog | LR and XGBoost are the most used besides CV and NLP models | XGBoost is not in scope and not supported

APPENDIX C - Tensorflow estimators
TensorFlow Estimators are objects in TF that provide three methods: train, evaluate and predict.
There are two ways to use Estimators: custom Estimators and pre-defined Estimators.
For each Estimator, the model and algorithm are defined using model_fn, and the three methods train, evaluate and predict are defined for using that model. train_hooks and eval_hooks are callbacks, passed to the train/predict methods, that run code within the training loop or prediction code.

Some pre-defined Estimators: DNN, Linear Regressor, etc.
Below is an example of how to define a pre-defined Estimator. There are multiple options available for DNN classifiers.

# estimator using the ProximalAdagradOptimizer optimizer with
# regularization.
estimator = DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    optimizer=tf.train.ProximalAdagradOptimizer(
        learning_rate=0.1,
        l1_regularization_strength=0.001))
For customizing a specific part of pre-defined estimator, we need to re-create a new estimator with the customized module and then use it.

Custom-estimators examples:

To implement a typical model function, you must do the following:

APPENDIX D - Keras fit API

Example usage of the Keras fit API:

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import SGD
import numpy as np
x_train = np.random.random((1000, 20))
y_train = keras.utils.to_categorical(np.random.randint(10, size=(1000, 1)), num_classes=10)
x_test = np.random.random((100, 20))
y_test = keras.utils.to_categorical(np.random.randint(10, size=(100, 1)), num_classes=10)
model = get_model()
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,
              metrics=['accuracy'])
model.fit(x_train, y_train,
          epochs=20, batch_size=128)

The Keras fit API is implemented using Callback objects, which expose methods that are called at:
(i) beginning of training
(ii) end of training
(iii) beginning of epoch
(iv) end of epoch
(v) beginning of batch
(vi) end of batch
Keras ships default callback implementations such as History, BaseLogger, CSVLogger, ModelCheckpoint, EarlyStopping, LearningRateScheduler and TensorBoard. Users can also define their own custom callbacks as required.



  1. Update on test plans:

    Integration Tests:

    Integration tests will be put under tests/nightly/estimator
    We will import from model zoo for vision models (except FCN), and construct RNN models in the integration tests


    On GPU, we choose one model, run it for 5 epochs, and assert that accuracy reaches a threshold.

    1. test_estimator_cnn_gpu
      1. Resnet18
      2. Dataset: MNIST
    2. test_estimator_rnn_gpu
      1. BiRNN
      2. Dataset: IMDB


    Test dummy data of batch size 1 for 1 epoch to make sure all the models we support are runnable.
    We will add more models once we have support for them.

    1. test_estimator_cnn_cpu
      1. Alexnet, Resnet18, FCN
    2. test_estimator_rnn_cpu
      1. BiRNN, CNN

    Unit tests:

    To be updated...

  2. We have received the following additional feedback and created JIRA tickets for each item to keep track. All tasks are under the Gluon Fit API epic.

    1. Improve callback efficiency to avoid calling empty methods
      1. Categorize handlers into 6 lists depending on whether they override the base handler methods. Make sure empty methods from the base EventHandler are not called.
      2. This can be done either by multiple inheritance from 6 classes plus isinstance() checks, or by checking whether each method is overridden under the current design.
    2. Improve batch size logic
      1. We can infer the batch size from the DataLoader, so the user does not need to pass it to the fit method.
      2. We need checks for the last batch and for when a batch cannot be split across contexts.
    3. Provide a History object for easy information access by event handlers
      1. Currently, when constructing event handlers, users need to pass the estimator so that handlers can access any information/data they need at different events (Reference comment).
      2. We need to keep event handlers from accessing estimator information they don't need, for example data loaders, trainer and context.
      3. It will be useful to keep some information as training states/history and to be able to serialize it for easy resuming from checkpoints.
      4. We can even make the history object an event handler (setting information in states/history, with serialization actions controlled by callbacks).
    4. Improve validation logic
      1. Needed for most use cases to prevent overfitting.
      2. When the eval data is large, the user may want to run validation every n epochs.
      3. Make it an event handler that triggers on epoch end and can be customized to run validation every n epochs.
    5. Make metrics an event handler
      1. Metric resets and updates can be controlled by callback methods.
      2. This will allow easy customization.
      3. We can provide a convenience method to register existing metrics as event handlers.
    6. Improve how callbacks are designed (2, 3 and 4 depend on this)
      1. As we want to make 2, 3 and 4 event handlers, we need a new design to pass different information/arguments to different callbacks in the fit loop.
      2. reference:
    7. Support DataLoader, DataIter, or custom DataLoader (Nvidia DALI)
      1. In the estimator, we should adopt the convention that data/label are provided in DataLoader format.
      2. Provide a util function for users to convert DataIters into DataLoader format.
      3. Accept custom util functions to convert custom DataLoaders into DataLoader format.
      4. Note: the common way of using DataLoader is data=batch[0], label=batch[1]; using DataIter it is data=batch.data[0], label=batch.label[0].
    8. New feature - stop based on the total number of batches trained
      1. Need a global counter for the number of batches, and stop once n batches/steps have been trained rather than based on epochs.
    9. New feature - provide a user-defined stopping condition function
      1. Could be a composite condition function (val acc not improving over 3 epochs + n steps reached).
      2. The alternative is custom event handlers: if this custom function needs to be executed multiple times at different events to decide whether to stop, it might as well be an event handler.
    10. Provide default metrics for a given loss
      1. Reduce user input for beginners by automatically inferring metrics from the loss (e.g. softmax CE loss → accuracy).
      2. Maintain a loss-to-metric map.
    11. New features for the checkpoint handler
      1. Save the last n checkpoints, removing old checkpoints once n is reached.
      2. Save some train states in Train History.
      3. Deserialize train history and resume from the last checkpoint automatically if the target epoch has not been reached.
    12. New feature - extensible for distributed training
      1. Horovod: use a custom trainer.
      2. Parameter Server: batch_fn and trainer.step should be the same as for single-node multi-GPU.
      3. Consider the convention of doing mean(loss) and step(1) versus step(batch_size); batch_size in Horovod is per device, in PS it is per worker.
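For feedback item 1 (avoiding calls to empty methods), one possible way to check whether a method is overridden under the current single-base-class design is sketched below. The helper `categorize_handlers` and the two example handlers are hypothetical; the check relies only on the fact that a subclass that does not override a method exposes the very same function object as the base class.

```python
class EventHandler:
    def train_begin(self): pass
    def train_end(self): pass
    def epoch_begin(self): pass
    def epoch_end(self): pass
    def batch_begin(self): pass
    def batch_end(self): pass


EVENTS = ("train_begin", "train_end", "epoch_begin",
          "epoch_end", "batch_begin", "batch_end")


def categorize_handlers(handlers):
    """Map each event name to the handlers whose class overrides that method."""
    by_event = {name: [] for name in EVENTS}
    for handler in handlers:
        for name in EVENTS:
            # An inherited method is the same function object as on the base class.
            if getattr(type(handler), name) is not getattr(EventHandler, name):
                by_event[name].append(handler)
    return by_event


class EpochLogger(EventHandler):
    """Hypothetical handler that only cares about the end of each epoch."""
    def epoch_end(self):
        print("epoch finished")


class Checkpointer(EventHandler):
    """Hypothetical handler that acts at epoch end and at train end."""
    def epoch_end(self): pass
    def train_end(self): pass


logger, checkpointer = EpochLogger(), Checkpointer()
by_event = categorize_handlers([logger, checkpointer])
print({name: len(hs) for name, hs in by_event.items()})
```

The fit loop would then iterate over `by_event["batch_end"]` and the other five lists, so a handler is only called for the stages it implements.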
  3. Additional feedback from Mu:

    1. Do we know the limitation of the current API? E.g. which models that
    this API does not support. So far, it's hard to say that it covers 80% use
    2. Extendibility. So far it has a single gluon.Estimator class do all
    works. Are we considering to allow to extend this class to support future
    use cases?
    3. It's confusing that some variables end with 's' while some don't. For
    has loss, metrics, trainers and context. All of them support list inputs.
    4. Relate to 3, how to map a list of inputs to each other. E.g. if I give a
    list of loss functions, then what's their inputs, and what's are the loss
    weight? Similarly, what does multiple trainers mean.
    5. How to initialize a model with different initializers for different
    layers (pretty common use case)
    6. How to restart for a previous
    7. How about if a model have multiple components, e.g. encode-decode, or
    multi-modality training
    8. It's strange to specify a batch_size in fit() for new users. (I know the
    reason, but need to explain to users)
    9. event_handler is quite powerful, but how can users access internal
    states through the handler? BTW, we usually call it callback in Python.
    10. Programming flavor. The codes should be easy to read. For example, use
    same naming convention as mxnet/pytorch/keras. A function should be short,
    otherwise breaks it into several pieces. Currently fit() has >100 lines of