
Note: extended design for callbacks here

Problem and Goals

Background

...

class EventHandler:
    def __init__(self, estimator):
        # shared training statistics maintained by the estimator
        self._train_stats = estimator.train_stats

    def train_begin(self):
        pass

    def train_end(self):
        pass

    def batch_begin(self):
        pass

    def batch_end(self):
        pass

    def epoch_begin(self):
        pass

    def epoch_end(self):
        pass


class LoggingHandler(EventHandler):
    def __init__(self, estimator, log_loc='./'):
        super().__init__(estimator)
        # setup logging
        self.log_loc = log_loc

    def epoch_end(self):
        # log the train stats to the log location
        pass

class CheckpointHandler(EventHandler):
    def __init__(self, estimator, checkpoint_interval=5, ckpt_loc='./', monitor="val_loss"):
        super().__init__(estimator)
        self.checkpoint_interval = checkpoint_interval
        self.ckpt_loc = ckpt_loc
        self.monitor = monitor
        # train_stats looks like {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}

    def epoch_end(self):
        # save the model params to the checkpointing location
        pass

class MetricHandler(EventHandler):
    def __init__(self, estimator):
        super().__init__(estimator)
        # train_stats looks like {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}

    def epoch_end(self):
        # calculate and update metrics for the training dataset
        # update_metrics(pred, labels): the default implementation can be overridden for multi-output cases
        # update validation metrics for the validation dataset
        pass

class EarlyStopping(EventHandler):
    def __init__(self, estimator, monitor="val_loss", min_delta=0, patience=0, mode="auto", baseline=None, restore_best_params=False):
        # estimator is passed through because the base class requires it
        super().__init__(estimator)
        # setup early stopping rules based on the metric/loss monitor and the mode
        # e.g. if "acc" use greater mode else use lesser

    def epoch_end(self):
        # if metric improved, record the best value
        # else wait n epochs (n=patience) and stop training
        # restore net parameters from the best epoch accordingly
        pass

    def train_end(self):
        # let user know if early stopping is triggered
        pass
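
The following is a minimal, runnable sketch of how these handlers might be driven and how the early stopping rule could work. It is not the proposed implementation. ToyEstimator, its stop_training flag, the replayed metric values, and passing the estimator into EarlyStopping are all assumptions made for illustration; baseline and restore_best_params are left out, and only the epoch-level hooks are exercised.

```python
class EventHandler:
    """Base handler; every hook is a no-op unless a subclass overrides it."""
    def __init__(self, estimator):
        self._estimator = estimator  # kept so a handler can signal the loop (assumption)
        self._train_stats = estimator.train_stats

    def train_begin(self): pass
    def train_end(self): pass
    def epoch_begin(self): pass
    def epoch_end(self): pass


class EarlyStopping(EventHandler):
    def __init__(self, estimator, monitor="val_loss", min_delta=0, patience=0, mode="auto"):
        super().__init__(estimator)
        self.monitor, self.min_delta, self.patience = monitor, min_delta, patience
        if mode == "auto":
            # "acc"-style metrics should increase; everything else should decrease
            mode = "max" if "acc" in monitor else "min"
        self.mode = mode
        self.best = None
        self.wait = 0
        self.stopped_epoch = None

    def _improved(self, current):
        if self.best is None:
            return True
        if self.mode == "max":
            return current > self.best + self.min_delta
        return current < self.best - self.min_delta

    def epoch_end(self):
        current = self._train_stats[self.monitor][-1]
        if self._improved(current):
            self.best, self.wait = current, 0
            return
        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = len(self._train_stats[self.monitor])
            self._estimator.stop_training = True  # hypothetical flag read by the loop

    def train_end(self):
        if self.stopped_epoch is not None:
            print("early stopping triggered after epoch", self.stopped_epoch)


class ToyEstimator:
    """Hypothetical stand-in for the estimator: replays recorded metric values."""
    def __init__(self, monitor, values):
        self._monitor, self._values = monitor, values
        self.train_stats = {monitor: []}
        self.stop_training = False

    def fit(self, handlers):
        for h in handlers:
            h.train_begin()
        for value in self._values:
            for h in handlers:
                h.epoch_begin()
            self.train_stats[self._monitor].append(value)  # stands in for one real epoch
            for h in handlers:
                h.epoch_end()
            if self.stop_training:
                break
        for h in handlers:
            h.train_end()
        return len(self.train_stats[self._monitor])  # number of epochs actually run


est = ToyEstimator("val_loss", [1.0, 0.8, 0.9, 0.85, 0.7, 0.6])
stopper = EarlyStopping(est, monitor="val_loss", patience=2)
epochs_run = est.fit([stopper])
```

In this sketch the loop only checks the flag after all handlers have run for the epoch, so a checkpoint or logging handler registered alongside EarlyStopping would still see the final epoch before training stops.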

...

By supporting the following models, we believe we can cover most basic use cases for Gluon users:

| Domain | Category | Model | Reference | Feature Required | Note |
|---|---|---|---|---|---|
| CV | Image Classification | AlexNet | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | mlp, lenet, vgg are similar, example: train_ch5() |
| CV | Image Augmentation + Classification | ResNet18 | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | example: train_ch5() |
| CV | Semantic Segmentation | FCN | Gluon Book | more data_transformation, multi-gpu | example: train() |
| CV | Object Detection | SSD | Gluon Book | multiple labels, losses, and metrics | training script from Gluon CV |
| NLP | Text Sentiment Classification | BiRNN | Gluon Book | same as rows 1 and 2 | example: train() |
| NLP | Text Sentiment Classification | TextCNN | Gluon Book | same as rows 1 and 2 | example: train() |
| NLP | Neural Machine Translation | encoder-decoder and attention mechanism | Gluon Book | multiple trainers, different inputs for loss | |
| Various | Various | LR | Kaggle Blog | LR and XGBoost are the most used besides CV and NLP models | XGBoost is not in scope and not supported |


APPENDIX C - TensorFlow estimators

...