Note: see the extended design for callbacks here

Problem and Goals


Training a model in Gluon requires users to write the training loop. The loop's imperative style is useful, but repeating the same boilerplate code across multiple models quickly becomes tedious. The training loop can also be overwhelming to users who are new to deep learning.
Users have asked for a simple Fit API, similar to the APIs available in SKLearn and Keras (example forum ask), as a way to simplify model training and reduce boilerplate code and complexity.

Target Users

  1. Beginners who are new to deep learning and/or to Gluon.
  2. Applied scientists and deep learning practitioners training models with low complexity and non-custom requirements.

Goals and Deliverables

Existing user experience

Currently, because of Gluon's imperative programming style, users write the entire training loop themselves, which requires multiple steps. For a code example, see Appendix A.
Writing the custom loop involves:

  1. Set up the training/validation data
  2. Do a forward pass on a batch of data
  3. Calculate the loss
  4. Calculate gradients via back-propagation
  5. Update the weights
  6. Evaluate metrics
  7. Update logs
  8. Save models

These steps are generic and remain the same across many models, so rewriting them for each model is repetitive work. We can simplify this by providing a simple “fit” API that we expect to cover about 80% of modeling use cases, especially for novice users. We expect advanced users to continue writing their own custom training loops, which give them more flexibility.
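To make the boilerplate concrete, here is a framework-free sketch of steps 1 to 8 in plain Python. It assumes a hypothetical one-parameter model y_hat = w * x, a mean-squared-error loss and hand-derived gradients in place of Gluon's autograd; `train_loop` and everything in it are illustrative names, not Gluon APIs. Only the model, loss and data would change from project to project, while the loop structure stays the same, which is exactly the part a fit API can absorb.

```python
# Illustrative only: a framework-free version of the generic training loop.
# Model: y_hat = w * x; loss: mean squared error; optimizer: plain SGD.

def train_loop(data, w=0.0, lr=0.5, epochs=20, batch_size=4):
    history = []       # step 7: simple in-memory "log"
    checkpoints = {}   # step 8: stand-in for saving model parameters
    for epoch in range(epochs):
        # step 1: iterate over the training data in mini-batches
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            # step 2: forward pass
            preds = [w * x for x, _ in batch]
            # step 3: loss (mean squared error over the batch)
            loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / len(batch)
            # step 4: gradient of the loss with respect to w
            grad = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / len(batch)
            # step 5: weight update
            w -= lr * grad
        # step 6: evaluate a metric (mean absolute error over the whole dataset)
        mae = sum(abs(w * x - y) for x, y in data) / len(data)
        # step 7: log the loss of the last batch and the metric
        history.append({"epoch": epoch, "loss": loss, "mae": mae})
        # step 8: "save" the parameters of this epoch
        checkpoints[epoch] = w
    return w, history, checkpoints


if __name__ == "__main__":
    # Synthetic data lying on y = 2x.
    data = [(x / 10, 2 * x / 10) for x in range(1, 11)]
    w, history, _ = train_loop(data)
    print(f"learned w = {w:.4f}, final MAE = {history[-1]['mae']:.6f}")
```

With the synthetic data in the `__main__` block (points on y = 2x), w converges close to 2.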

Proposed Approach

We propose to add a simple “fit” API for Gluon by offering a new Estimator class with a fit method, similar to the SKLearn classifier API. The Gluon Estimator will hold the details of model training, such as training statistics, the training network and event handlers.
We will add an EventHandler base class that exposes methods for the stages of the training loop, viz. train_begin, train_end, epoch_begin, epoch_end, batch_begin and batch_end, so that users can override and customize different stages of training. We will also provide default event handlers for common actions such as Logging, Metrics, EarlyStopping and Checkpointing.
The “fit” method will therefore run the forward pass, calculate the loss and gradients, update the weights, and invoke the event handlers that log, checkpoint and update metrics.
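A framework-free sketch of how fit could drive the six hooks. `EventHandler` mirrors the proposed hook names; `MiniEstimator`, `RecordingHandler` and the `train_step` callable are hypothetical stand-ins (handlers here take no estimator argument, for brevity). In the real design the per-batch step would run the Gluon forward and backward passes and `trainer.step`.

```python
class EventHandler:
    """Base class: every hook is a no-op, so subclasses override only what they need."""

    def train_begin(self): pass
    def train_end(self): pass
    def epoch_begin(self): pass
    def epoch_end(self): pass
    def batch_begin(self): pass
    def batch_end(self): pass


class MiniEstimator:
    """Hypothetical, framework-free stand-in for the proposed Estimator."""

    def __init__(self, train_step, handlers=None):
        self.train_step = train_step  # callable run once per batch
        self.handlers = list(handlers or [])

    def _fire(self, hook):
        # Call the named hook on every handler, in registration order.
        for handler in self.handlers:
            getattr(handler, hook)()

    def fit(self, batches, epochs):
        self._fire("train_begin")
        for _ in range(epochs):
            self._fire("epoch_begin")
            for batch in batches:
                self._fire("batch_begin")
                self.train_step(batch)
                self._fire("batch_end")
            self._fire("epoch_end")
        self._fire("train_end")


class RecordingHandler(EventHandler):
    """Example custom handler: records the events it chooses to observe."""

    def __init__(self):
        self.events = []

    def train_begin(self): self.events.append("train_begin")
    def epoch_begin(self): self.events.append("epoch_begin")
    def batch_end(self): self.events.append("batch_end")
    def epoch_end(self): self.events.append("epoch_end")
    def train_end(self): self.events.append("train_end")


if __name__ == "__main__":
    recorder = RecordingHandler()
    MiniEstimator(train_step=lambda batch: None, handlers=[recorder]).fit(batches=[1, 2], epochs=1)
    print(recorder.events)
    # prints ['train_begin', 'epoch_begin', 'batch_end', 'batch_end', 'epoch_end', 'train_end']
```

Because the base-class hooks are no-ops, a custom handler overrides only the stages it cares about, and handlers run in the order they were registered.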

New user experience

The new API described below reduces the number of lines of code the user has to write. When the user implements logging and checkpointing, the code shrinks from ~40 lines to ~6.

The example below uses the Fit API to implement functionality similar to the existing training loop in Appendix A.

from mxnet import gluon
from mxnet.gluon.estimator import Estimator  # proposed module, not yet part of MXNet
net = get_model()  # get the network
loss = gluon.loss.SoftmaxCrossEntropyLoss()
e = Estimator(net, lossfn=loss)
# training
trainers = [gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})]
e.fit(train_data, val_data, epochs, trainers, context)

Error Handling

We will provide the following checks to make the fit method robust and easy to debug, with clear error messages that tell users exactly what went wrong.

  1. Check required and optional arguments, and notify users when default values are used.
  2. Check that the number of inputs from the DataLoader is consistent with the number of losses.
  3. Check that the number of network outputs is consistent with the number of inputs to the losses and metrics.
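A plain-Python sketch of the three checks. The function `check_fit_args`, its parameters and its messages are hypothetical, and it assumes one label per loss and one loss and one metric per network output; the real Estimator would derive these counts from the DataLoader batch, the network outputs and the configured losses and metrics.

```python
import warnings


def check_fit_args(num_labels, num_outputs, num_losses, num_metrics,
                   epochs=None, default_epochs=1):
    """Validate fit() arguments and return the number of epochs to run.

    Assumes one label per loss, and one loss and one metric per network output.
    """
    # Check 1: fall back to a default for an optional argument, but say so.
    if epochs is None:
        warnings.warn(f"epochs not specified, using default value {default_epochs}")
        epochs = default_epochs
    if not isinstance(epochs, int) or epochs < 1:
        raise ValueError(f"epochs must be a positive integer, got {epochs!r}")
    # Check 2: the DataLoader must supply one label for each loss.
    if num_labels != num_losses:
        raise ValueError(f"DataLoader provides {num_labels} label(s) but "
                         f"{num_losses} loss function(s) were given")
    # Check 3: each network output needs its own loss and metric.
    for name, count in (("loss function(s)", num_losses), ("metric(s)", num_metrics)):
        if count != num_outputs:
            raise ValueError(f"network returns {num_outputs} output(s) but "
                             f"{count} {name} were given")
    return epochs
```

For example, `check_fit_args(num_labels=1, num_outputs=1, num_losses=1, num_metrics=1, epochs=5)` returns 5, while passing two labels with a single loss raises a ValueError that names both counts.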

All error handling will be covered in unit tests.

Estimator class

Estimator is the new class that will encapsulate training and expose a Fit method.

import mxnet as mx
from mxnet import gluon

class Estimator:
    def __init__(self, net, lossfn=gluon.loss.SoftmaxCrossEntropyLoss(),
                 metrics=None):
        if metrics is None:  # avoid a shared mutable default list
            metrics = [mx.metric.TopKAccuracy(), mx.metric.RMSE()]
        self._train_stats = {"lr": [], "epoch": [], "train_metric1": [], "val_metric1": [],
                             "train_loss1": [], "val_loss1": [], "time": []}  # ... further stats
        self._net = net
        self._loss = lossfn
        self._metrics = metrics
        self._loggingHandler = LoggingHandler(self)
        self._checkpointHandler = CheckpointHandler(self)
        self._metricHandler = MetricHandler(self)
        self._additionalHandlers = []  # can hold custom event handlers

    def fit(self, train_data_loader, val_data_loader, epochs, trainers, context):
        ...  # see "Fit - Proposed implementation" in Appendix A

    @property
    def train_stats(self):
        return self._train_stats

    @property
    def metrics(self):
        return self._metrics

    @property
    def additionalHandlers(self):
        return self._additionalHandlers

    @property
    def loggingHandler(self):
        return self._loggingHandler

    @property
    def checkpointHandler(self):
        return self._checkpointHandler

    @property
    def metricHandler(self):
        return self._metricHandler

    # the loss function takes predictions and labels as input and returns the loss
    # (Gluon loss blocks return one value per sample)
    @property
    def loss(self):
        return self._loss

    def plot_loss_graph(self):
        # plan to support MXBoard in the next phase
        pass
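One consequence of each handler storing `estimator.train_stats` in its constructor (see the EventHandler class below): the Estimator must update that dictionary in place, by appending to its lists, rather than rebinding it, or handlers keep reading a stale copy. The sketch below shows this in plain Python; `StatsEstimator`, `LastEpochReader`, `record_epoch` and `replace_stats` are hypothetical names used only for illustration.

```python
class StatsEstimator:
    """Hypothetical stand-in that owns train_stats, like the proposed Estimator."""

    def __init__(self):
        self._train_stats = {"epoch": [], "train_loss1": [], "val_loss1": []}

    @property
    def train_stats(self):
        return self._train_stats

    def record_epoch(self, epoch, train_loss, val_loss):
        # Mutate the shared lists in place so every handler sees the update.
        self._train_stats["epoch"].append(epoch)
        self._train_stats["train_loss1"].append(train_loss)
        self._train_stats["val_loss1"].append(val_loss)

    def replace_stats(self):
        # Anti-pattern: rebinding creates a new dict; handlers still hold the old one.
        self._train_stats = {"epoch": [], "train_loss1": [], "val_loss1": []}


class LastEpochReader:
    """Handler-like object that captures train_stats the same way EventHandler does."""

    def __init__(self, estimator):
        self._train_stats = estimator.train_stats

    def last_val_loss(self):
        values = self._train_stats["val_loss1"]
        return values[-1] if values else None
```

After `replace_stats()`, the reader keeps returning the value recorded before the swap, which is the kind of silent bug the in-place convention avoids. Clearing the lists in place (for example with `list.clear()`) keeps the reference valid.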

EventHandler class

EventHandler is a new class that lets users customize the training process through callbacks for the different stages of training.
We will implement standard handlers for logging, checkpointing and more. Users can extend EventHandler to implement their own custom handlers.

class EventHandler:
    def __init__(self, estimator):
        self._train_stats = estimator.train_stats

    # hooks called by fit() at each stage of training; no-ops by default
    def train_begin(self): pass
    def train_end(self): pass
    def batch_begin(self): pass
    def batch_end(self): pass
    def epoch_begin(self): pass
    def epoch_end(self): pass

class LoggingHandler(EventHandler):
    def __init__(self, estimator, log_loc='./'):
        super().__init__(estimator)
        # set up logging to log_loc

    def epoch_end(self):
        # log the train stats to the log location
        pass

class CheckpointHandler(EventHandler):
    def __init__(self, estimator, checkpoint_interval=5, ckpt_loc='./', monitor="val_loss"):
        super().__init__(estimator)
        # e.g. train_stats = {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}

    def epoch_end(self):
        # save the model params to the checkpointing location
        pass

class MetricHandler(EventHandler):
    def __init__(self, estimator):
        super().__init__(estimator)
        # e.g. train_stats = {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}

    def epoch_end(self):
        # calculate and update metrics for the training dataset
        # update_metrics(pred, labels): default implementation can be overridden for multi-output cases
        # update validation metrics for the validation dataset
        pass

class EarlyStopping(EventHandler):
    def __init__(self, estimator, monitor="val_loss", min_delta=0, patience=0, mode="auto",
                 baseline=None, restore_best_params=False):
        super().__init__(estimator)
        # set up early stopping rules based on the monitored metric/loss and the mode,
        # e.g. if the monitor contains "acc" use greater mode, else use lesser

    def epoch_end(self):
        # if the monitored value improved, record the best value
        # else, after `patience` epochs without improvement, stop training
        # and restore net parameters from the best epoch if restore_best_params is set
        pass

    def train_end(self):
        # let the user know whether early stopping was triggered
        pass
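The rules in the EarlyStopping comments can be sketched in plain Python as a hypothetical `EarlyStoppingLogic` helper that receives the monitored value once per epoch and reports whether to stop. It follows the comments above for the "auto" mode (maximize when the monitor name contains "acc", minimize otherwise) and, like Keras's EarlyStopping, stops once `patience` consecutive epochs pass without an improvement larger than `min_delta`. Restoring the best parameters is left out.

```python
import math


class EarlyStoppingLogic:
    def __init__(self, monitor="val_loss", min_delta=0.0, patience=0,
                 mode="auto", baseline=None):
        if mode == "auto":
            # Accuracy-like metrics should increase; losses should decrease.
            mode = "max" if "acc" in monitor else "min"
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'auto', 'min' or 'max', got {mode!r}")
        self.mode = mode
        self.min_delta = abs(min_delta)
        self.patience = patience
        # Start from the baseline if given, else from the worst possible value.
        if baseline is not None:
            self.best = baseline
        else:
            self.best = math.inf if mode == "min" else -math.inf
        self.best_epoch = None
        self.wait = 0  # consecutive epochs without improvement

    def _improved(self, value):
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def epoch_end(self, epoch, value):
        """Record this epoch's monitored value; return True if training should stop."""
        if self._improved(value):
            self.best, self.best_epoch, self.wait = value, epoch, 0
            return False
        self.wait += 1
        return self.wait >= self.patience
```

For example, with `patience=2` and validation losses 1.0, 0.8, 0.85, 0.9, the helper signals a stop after the fourth epoch and keeps epoch 1 (loss 0.8) as the best one.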

Addition of New APIs

The design adds two new classes, Estimator and EventHandler, which solve the problem at hand.

Backward compatibility

The design doesn't alter any existing APIs and so the design is backward compatible.

Performance Considerations

The proposed APIs replace hand-written training loops in Gluon while running the same operations, so they should not introduce performance regressions.
We can add tests that train the same model with a hand-written loop and with the new APIs, and compare the training times to make sure there is no regression.
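A minimal, standard-library-only harness for this comparison: time the hand-written loop and the fit-based run on identical work, repeat several times and compare medians. `median_runtime`, `compare_runtimes` and the 5% tolerance are hypothetical choices; in practice each callable would train the same network on the same data with the same seed, and asynchronous MXNet work would need to finish (for example via `mx.nd.waitall()`) before the timer stops.

```python
import statistics
import time


def median_runtime(fn, repeats=5):
    """Run fn() `repeats` times and return the median wall-clock time in seconds."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def compare_runtimes(loop_fn, fit_fn, repeats=5, tolerance=0.05):
    """Return (loop_time, fit_time, is_regression).

    is_regression is True when fit is more than `tolerance` (5% by default) slower.
    """
    loop_time = median_runtime(loop_fn, repeats)
    fit_time = median_runtime(fit_fn, repeats)
    return loop_time, fit_time, fit_time > loop_time * (1 + tolerance)
```

Using the median rather than the mean keeps one slow run (for example, a cold cache on the first repetition) from dominating the comparison.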

Test Plan

We will implement 100% unit test coverage for the new API and event handlers.

  1. A test with all possible parameters
  2. A test with default parameters
  3. A test with missing parameters

We will add integration tests covering all of the models in the release goals (Appendix B). See the comment section for the integration test plan.

Technical Challenges / Open Questions

The MetricHandler needs access to the network outputs and the labels to calculate and update the metrics; these are available in the Estimator class. This event handler is important when dealing with multiple outputs or multiple metrics, because the logic that associates outputs with metrics lives here. As an alternative, we can add a metric_update function to the Estimator that users can customize.
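To illustrate the alternative, here is a plain-Python sketch of an overridable metric_update: the default pairs the i-th network output with the i-th label and the i-th metric, and a subclass overrides it for a two-headed network that predicts a single label. `Accuracy`, `EstimatorMetrics`, `SharedLabelMetrics` and the `metric_update(outputs, labels)` signature are hypothetical; in Gluon the metrics would be `mx.metric` objects, whose `update(labels, preds)` method takes lists of labels and predictions.

```python
class Accuracy:
    """Tiny stand-in for a metric object with update() and get()."""

    def __init__(self, name="accuracy"):
        self.name, self.correct, self.total = name, 0, 0

    def update(self, labels, preds):
        for y, p in zip(labels, preds):
            self.correct += int(y == p)
            self.total += 1

    def get(self):
        return self.name, (self.correct / self.total if self.total else float("nan"))


class EstimatorMetrics:
    """Holds the metrics and the default output-to-metric association."""

    def __init__(self, metrics):
        self.metrics = metrics

    def metric_update(self, outputs, labels):
        # Default: output i and label i feed metric i.
        if not (len(outputs) == len(labels) == len(self.metrics)):
            raise ValueError("expected one output and one label per metric")
        for metric, out, lab in zip(self.metrics, outputs, labels):
            metric.update(lab, out)


class SharedLabelMetrics(EstimatorMetrics):
    """Custom override: two heads predict the same label, one metric per head."""

    def metric_update(self, outputs, labels):
        (label,) = labels  # exactly one label shared by all heads
        for metric, out in zip(self.metrics, outputs):
            metric.update(label, out)
```

Keeping the association in one overridable method means a multi-output model only has to replace that method, not the whole handler.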


  1. MXNet Module APIs
  2. TensorFlow estimators (see Appendix C)
  3. Keras Model API (fit, predict, evaluate) (see Appendix D)
  4. Scikit-learn
  5. Torchbearer is a framework for easy fit, evaluate and predict in PyTorch. It is implemented using state objects that hold the state of all the callback functions. The training parameters (loss, optimizer, metric) are passed as callbacks, and the fit method works through the association between states and callbacks. Similar frameworks for PyTorch include skorch, PyToune, ignite, TorchNetTwo (TNT) and Inferno. Skorch uses scikit-learn internally, and we don't want to add that dependency.

APPENDIX A - Gluon Training Loop Example

Current user experience

## Current training loop in Gluon
import logging
import time

from mxnet import autograd as ag
from mxnet import gluon

# Assumes net, optimizer, optimizer_params, train_data, val_data, batch_fn, ctx,
# train_metric, lr_scheduler, test(), batch_size, log_interval, save_frequency,
# save_dir and model_name are defined elsewhere in the script.
logger = logging.getLogger(__name__)

# Only one epoch
num_epochs = 1

trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params)
L = gluon.loss.SoftmaxCrossEntropyLoss()
best_val_score = 1

for epoch in range(num_epochs):
    tic = time.time()
    btic = time.time()

    for i, batch in enumerate(train_data):
        data, label = batch_fn(batch, ctx)

        with ag.record():
            outputs = [net(X) for X in data]
            loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
        for l in loss:
            l.backward()
        lr_scheduler.update(i, epoch)
        trainer.step(batch_size)

        train_metric.update(label, outputs)

        if log_interval and not (i+1) % log_interval:
            train_metric_name, train_metric_score = train_metric.get()
            logger.info('Epoch[%d] Batch [%d]\tSpeed: %f samples/sec\t%s=%f\tlr=%f' % (
                epoch, i, batch_size*log_interval/(time.time()-btic),
                train_metric_name, train_metric_score, trainer.learning_rate))
            btic = time.time()

    train_metric_name, train_metric_score = train_metric.get()
    throughput = int(batch_size * i / (time.time() - tic))

    err_top1_val, err_top5_val = test(ctx, val_data)
    logger.info('[Epoch %d] training: %s=%f' % (epoch, train_metric_name, train_metric_score))
    logger.info('[Epoch %d] speed: %d samples/sec\ttime cost: %f' % (epoch, throughput, time.time()-tic))
    logger.info('[Epoch %d] validation: err-top1=%f err-top5=%f' % (epoch, err_top1_val, err_top5_val))

    if err_top1_val < best_val_score:
        best_val_score = err_top1_val
        net.save_parameters('%s/%.4f-imagenet-%s-%d-best.params' % (save_dir, best_val_score, model_name, epoch))
        trainer.save_states('%s/%.4f-imagenet-%s-%d-best.states' % (save_dir, best_val_score, model_name, epoch))

    if save_frequency and save_dir and (epoch + 1) % save_frequency == 0:
        net.save_parameters('%s/imagenet-%s-%d.params' % (save_dir, model_name, epoch))
        trainer.save_states('%s/imagenet-%s-%d.states' % (save_dir, model_name, epoch))

if save_frequency and save_dir:
    net.save_parameters('%s/imagenet-%s-%d.params' % (save_dir, model_name, num_epochs-1))
    trainer.save_states('%s/imagenet-%s-%d.states' % (save_dir, model_name, num_epochs-1))

Fit - Proposed implementation

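## sample fit function (a method of Estimator)
from mxnet import autograd, gluon

def fit(self, train_data_loader, val_data_loader, epochs, trainers, context):
    event_handlers = [self.loggingHandler, self.checkpointHandler, self.metricHandler]
    event_handlers = event_handlers + self.additionalHandlers

    for handler in event_handlers:
        handler.train_begin()
    epoch = 0
    # exit condition: max epochs reached, or a handler such as EarlyStopping
    # set the (hypothetical) stop_training flag
    while epoch < epochs and not self.stop_training:
        for handler in event_handlers:
            handler.epoch_begin()
        for data, label in train_data_loader:
            for handler in event_handlers:
                handler.batch_begin()
            batch_size = data.shape[0]
            # split and load the batch across devices (context is a list) for multi-GPU
            data = gluon.utils.split_and_load(data, context)
            label = gluon.utils.split_and_load(label, context)
            with autograd.record():
                pred = [self._net(x) for x in data]  # forward pass
                losses = [self._loss(y_hat, y) for y_hat, y in zip(pred, label)]
            for l in losses:
                l.backward()  # backward pass
            for trainer in trainers:
                trainer.step(batch_size)  # update weights
            for handler in event_handlers:
                handler.batch_end()
        for handler in event_handlers:
            handler.epoch_end()
        epoch += 1

    for handler in event_handlers:
        handler.train_end()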

APPENDIX B - Supported Models

By supporting the following models, we believe we can cover most basic use cases for Gluon users.

# | Domain | Category | Model | Reference | Features required | Note
1 | CV | Image classification | AlexNet | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | mlp, lenet, vgg are similar; example: train_ch5()
2 | CV | Image augmentation + classification | ResNet18 | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | example: train_ch5()
3 | CV | Semantic segmentation | FCN | Gluon Book | more data_transformation, multi-GPU | example: train()
4 | CV | Object detection | SSD | Gluon Book | multiple labels, losses, and metrics | training script from Gluon CV
5 | NLP | Text sentiment classification | BiRNN | Gluon Book | same as rows 1 and 2 | example: train()
6 | NLP | Text sentiment classification | TextCNN | Gluon Book | same as rows 1 and 2 | example: train()
7 | NLP | Neural machine translation | Encoder-decoder with attention mechanism | Gluon Book | multiple trainers, different inputs for loss | -
8 | Various | Various | LR | Kaggle Blog | LR and XGBoost are the most used models besides CV and NLP models | XGBoost is not in scope and not supported

APPENDIX C - Tensorflow estimators
TensorFlow estimators are objects in TF that provide three methods: train, evaluate and predict.
There are two ways to use estimators: custom estimators and pre-defined estimators.
For each estimator, the model and algorithm are defined in a model_fn, and the three methods train, evaluate and predict use that model. Training and evaluation hooks act as callbacks that run code within the training loop or the evaluation/prediction code.

Some pre-defined estimators: DNNClassifier, LinearRegressor, etc.
An example of how to define a pre-defined estimator follows. There are multiple options available for DNN classifiers.

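import tensorflow as tf  # TensorFlow 1.x API
# estimator using the ProximalAdagradOptimizer optimizer with
# regularization; the two embedding feature columns are defined elsewhere.
estimator = tf.estimator.DNNClassifier(
    feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
    hidden_units=[1024, 512, 256],
    optimizer=tf.train.ProximalAdagradOptimizer(
        learning_rate=0.1, l1_regularization_strength=0.001))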

To customize a specific part of a pre-defined estimator, we need to create a new estimator with the customized module and then use it.

Custom estimator examples:

To implement a typical model function, you must do the following:

Appendix D - Keras fit api

Example usage of the Keras fit API:

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.optimizers import SGD
import numpy as np

x_train = np.random.random((1000, 20))
y_train = keras.utils.to_categorical(np.random.randint(10, size=(1000, 1)), num_classes=10)
x_test = np.random.random((100, 20))
y_test = keras.utils.to_categorical(np.random.randint(10, size=(100, 1)), num_classes=10)
model = get_model()  # builds a Sequential model from the layers above (not shown)
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
model.fit(x_train, y_train, epochs=20, batch_size=128)

The Keras fit API is implemented using Callback objects, which expose methods called at the
(i) beginning of training (on_train_begin)
(ii) end of training (on_train_end)
(iii) beginning of an epoch (on_epoch_begin)
(iv) end of an epoch (on_epoch_end)
(v) beginning of a batch (on_batch_begin)
(vi) end of a batch (on_batch_end)
Keras provides default callback implementations such as History, BaseLogger, CSVLogger, ModelCheckpoint, EarlyStopping, LearningRateScheduler and TensorBoard. Users can also define custom callbacks as required.
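Because Keras callbacks and the proposed EventHandler cover the same six stages, a thin adapter could let Keras-style callback objects run inside the proposed fit loop. The sketch below is hypothetical: `KerasStyleCallbackAdapter` and `RecordingCallback` are illustrative names, and since the proposed hooks take no arguments, the adapter keeps its own epoch and batch counters and passes an empty `logs` dict, whereas Keras fills `logs` with training statistics.

```python
class EventHandler:
    """Hook names from the proposal; every hook is a no-op by default."""
    def train_begin(self): pass
    def train_end(self): pass
    def epoch_begin(self): pass
    def epoch_end(self): pass
    def batch_begin(self): pass
    def batch_end(self): pass


class KerasStyleCallbackAdapter(EventHandler):
    """Forward EventHandler hooks to an object with Keras-style on_* methods."""

    def __init__(self, callback):
        self.callback = callback
        self.epoch = 0
        self.batch = 0

    def train_begin(self):
        self.callback.on_train_begin(logs={})

    def train_end(self):
        self.callback.on_train_end(logs={})

    def epoch_begin(self):
        self.batch = 0  # batch indices restart every epoch
        self.callback.on_epoch_begin(self.epoch, logs={})

    def epoch_end(self):
        self.callback.on_epoch_end(self.epoch, logs={})
        self.epoch += 1

    def batch_begin(self):
        self.callback.on_batch_begin(self.batch, logs={})

    def batch_end(self):
        self.callback.on_batch_end(self.batch, logs={})
        self.batch += 1


class RecordingCallback:
    """Duck-typed Keras-style callback used only for this example."""
    def __init__(self):
        self.calls = []
    def on_train_begin(self, logs=None): self.calls.append("train_begin")
    def on_train_end(self, logs=None): self.calls.append("train_end")
    def on_epoch_begin(self, epoch, logs=None): self.calls.append(f"epoch_begin {epoch}")
    def on_epoch_end(self, epoch, logs=None): self.calls.append(f"epoch_end {epoch}")
    def on_batch_begin(self, batch, logs=None): pass
    def on_batch_end(self, batch, logs=None): self.calls.append(f"batch_end {batch}")
```

The design choice here is to keep the mapping one-to-one, so any gap (such as the missing `logs` contents) shows up in a single place rather than spreading through the fit loop.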