Table of Contents
Note: extended design for callbacks here
Problem and Goals
Background
...
class EventHandler:
    def __init__(self, estimator):
        self._train_stats = estimator.train_stats

    def train_begin(self):
        pass

    def train_end(self):
        pass

    def batch_begin(self):
        pass

    def batch_end(self):
        pass

    def epoch_begin(self):
        pass

    def epoch_end(self):
        pass

class LoggingHandler(EventHandler):
    def __init__(self, estimator, log_loc='./'):
        # set up logging
        pass

    def epoch_end(self):
        # log the train stats to the log location
        pass
class CheckpointHandler(EventHandler):
    def __init__(self, estimator, checkpoint_interval=5, ckpt_loc='./', monitor="val_loss"):
        super().__init__(estimator)
        # self._train_stats example: {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}

    def epoch_end(self):
        # save the model params to the checkpointing location
        pass
class MetricHandler(EventHandler):
    def __init__(self, estimator):
        super().__init__(estimator)
        # self._train_stats example: {"lr": [0.1], "train_acc": [0.85], "val_acc": [0.99], ...}

    def epoch_end(self):
        # calculate and update metrics for the training dataset;
        # update_metrics(pred, labels) - the default implementation can be
        # overridden for multi-output cases
        # update validation metrics for the validation dataset
        pass
class EarlyStopping(EventHandler):
    def __init__(self, monitor="val_loss", min_delta=0, patience=0,
                 mode="auto", baseline=None, restore_best_params=False):
        # set up early stopping rules based on the monitored metric/loss and the mode,
        # e.g. use "greater" mode for "acc", otherwise use "lesser"
        pass

    def epoch_end(self):
        # if the metric improved, record the best value;
        # otherwise wait n epochs (n=patience) and stop training,
        # restoring net parameters from the best epoch accordingly
        pass

    def train_end(self):
        # let the user know if early stopping was triggered
        pass
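The handler sketches above assume an Estimator that drives the event loop. The following is a minimal dispatch sketch of that idea; the `Estimator.fit` signature, the `stop_training` flag, and the `Recorder` handler are illustrative assumptions, not the proposed API:

```python
class EventHandler:
    """Minimal stand-in for the base class above (estimator argument
    omitted for brevity); subclasses override only the events they need."""
    def train_begin(self): pass
    def train_end(self): pass
    def epoch_begin(self): pass
    def epoch_end(self): pass
    def batch_begin(self): pass
    def batch_end(self): pass


class Estimator:
    """Illustrative fit loop showing when each event fires."""
    def __init__(self):
        self.train_stats = {}        # handlers read/update shared stats here
        self.stop_training = False   # e.g. set by an EarlyStopping handler

    def fit(self, num_epochs, event_handlers):
        for h in event_handlers:
            h.train_begin()
        for epoch in range(num_epochs):
            for h in event_handlers:
                h.epoch_begin()
            # forward/backward/step over batches would go here, wrapped
            # in batch_begin()/batch_end() calls around each batch
            for h in event_handlers:
                h.epoch_end()
            if self.stop_training:
                break
        for h in event_handlers:
            h.train_end()


class Recorder(EventHandler):
    """Toy handler that records the order in which events fire."""
    def __init__(self):
        self.events = []
    def train_begin(self): self.events.append("train_begin")
    def train_end(self): self.events.append("train_end")
    def epoch_begin(self): self.events.append("epoch_begin")
    def epoch_end(self): self.events.append("epoch_end")


rec = Recorder()
Estimator().fit(num_epochs=2, event_handlers=[rec])
print(rec.events)
# ['train_begin', 'epoch_begin', 'epoch_end', 'epoch_begin', 'epoch_end', 'train_end']
```

Keeping the dispatch in the Estimator and the behavior in handlers means features like checkpointing and early stopping compose freely: users pass any list of handlers to `fit` without modifying the training loop.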
...
By supporting the following models, we believe we can cover most basic use cases for Gluon users:
| Domain | Category | Model | Reference | Feature Required | Note |
|---|---|---|---|---|---|
| CV | Image Classification | AlexNet | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | mlp, lenet, vgg are similar; example: train_ch5() |
| CV | Image Augmentation + Classification | ResNet18 | Gluon Book | net, dataloader, batch_size, trainer, ctx, num_epochs | example: train_ch5() |
| CV | Semantic Segmentation | FCN | Gluon Book | more data transformation, multi-GPU | example: train() |
| CV | Object Detection | SSD | Gluon Book | multiple labels, losses, and metrics | training script from Gluon CV |
| NLP | Text Sentiment Classification | BiRNN | Gluon Book | same as 1 & 2 | example: train() |
| NLP | Text Sentiment Classification | TextCNN | Gluon Book | same as 1 & 2 | example: train() |
| NLP | Neural Machine Translation | Encoder-decoder with attention mechanism | Gluon Book | multiple trainers, different inputs for loss | |
| Various | Various | LR | Kaggle Blog | LR and XGBoost are the most used besides CV and NLP models | XGBoost is not in scope and not supported |
APPENDIX C - TensorFlow Estimators
...