This Confluence has been LDAP enabled, if you are an ASF Committer, please use your LDAP Credentials to login. Any problems file an INFRA jira ticket please.

Page tree

Versions Compared


  • This line was added.
  • This line was removed.
  • Formatting was changed.

Table of Contents

Note: extended design for callbcaks here

Problem and Goals



class EventHandler:
def __init__(self,estimator):
self._train_stats= estimator.train_stats

def train_begin(self):
def train_end(self):
def batch_begin(self):
def batch_end(self):
def epoch_begin(self):
def epoch_end(self):
class LoggingHandler(EventHandler):
def __init__(self, estimator, log_loc = './'):
# setup logging
def epoch_end:
## log the train stats to log location

class CheckpointHandler(EventHandler):
def __init__(self, estimator, checkpoint_interval=5 , ckpt_loc='./', monitor= "val_loss"):
train_stats = {"lr" = [0.1], "train_acc" = [0.85], "val_acc" = [0.99], ... }
def epoch_end:
## save the model params to the checkpointing location

class MetricHandler(EventHandler):
def __init__(self, estimator):
train_stats = {"lr" = [0.1], "train_acc" = [0.85], "val_acc" = [0.99], ... }
def epoch_end:
## calculate and update metrics for thr training dataset
## update_metrics(pred, labels)- default implementation can be overriden in case of multi-output cases
## update validation metrics for validation dataset

class EarlyStopping(EventHandler):
def __init__(self, monitor= "val_loss", min_delta=0, patience=0, mode="auto", baseline=None, restore_best_params=False):
# setup early stopping rules based on the metric/loss monitor and the mode
# e.g. if "acc" use greater mode else use lesser
def on_epoch_end:
# if metric improved, record the best value
# else wait n epochs(n=patience) and stop trainning
# restore net parameters from the best epoch accordingly
def on_train_end:
# let user know if early stopping is triggered


By supporting the following models, we believe we can cover most basic use cases for Gluon users

DomainCategoryModelReferenceFeature RequiredNote
CVImage ClassificationAlexNetGluon Booknet, dataloader, batch_size, trainer, ctx, num_epochsmlp, lenet, vgg are similar, example: train_ch5()
CVImage Augmentation + ClassificationResNet18Gluon Booknet, dataloader, batch_size, trainer, ctx, num_epochsexample: train_ch5()
CVSemantic SegmentationFCNGluon Bookmore data_transformation, multi-gpuexample: train()
CVObject DetectionSSDGluon Bookmultiple lables, losses, and metricstraining script from Gluon CV
NLPText Sentiment ClassificationBiRNNGluon Booksame as 1 &2example: train()
NLPText Sentiment classificationTextCNNGluon Booksame as 1 &2example: train()
NLPNeural Machine Translationencoder-decoder and attention mechanism.Gluon Bookmultiple trainer, different inputs for loss
VariousVariousLRKaggle BlogLR and XGBoost is most used besides CV and NLP modelsXGBoost is not in scope and not supported

APPENDIX C - Tensorflow estimators