This is an extended design discussion of callback design for the training fit loop in the Gluon Fit API. Callbacks are a powerful tool: we can use them to deliver useful features to users during training, such as metric updates, validation, logging, and saving a model periodically.
In this document, we will discuss what functionality we should provide, how to manage shared and independent states in callbacks, how to pass arguments to callbacks, and which functions should be implemented as callbacks.
Note: callbacks in the Gluon Fit-API/Estimator design are called event handlers to avoid confusion with the modules under mx.callback (used for MXNet Module only).
Currently supported events: each of the following events is triggered at a different stage of the training loop.
Events that could be added
Let's start with an example to discuss the ways to implement callbacks: a stopping criterion that stops training after a certain number of batches or epochs. It tells the for loop whether to stop training.
stop = StopTrainingHandlerV2(max_batch=100, max_epoch=10)
for epoch in range(20):
    for batch in range(25):
        print('epoch: ', epoch)
        print('batch: ', batch)
        batch_result = stop.batch_end()
        if batch_result:
            break
    epoch_result = stop.epoch_end()
    if epoch_result:
        break
One base class per event type; each base class keeps only the state specific to that event.
Rule of thumb: state common to both base classes should be managed by the subclass that inherits from them.
class BatchEnd(object):
    def __init__(self, max_batch=None):
        self.batch_idx = 0
        self.total_batch = 0
        self.max_batch = max_batch

    def batch_end(self, batch_result=None):
        self.batch_idx += 1
        self.total_batch += 1


class EpochEnd(object):
    def __init__(self, max_epoch=None):
        self.epoch = 0
        self.max_epoch = max_epoch

    def epoch_end(self, epoch_result=None):
        self.epoch += 1


class StopTrainingHandler(BatchEnd, EpochEnd):
    def __init__(self, max_batch, max_epoch):
        # super() initializes BatchEnd; super(BatchEnd, self)
        # skips past BatchEnd in the MRO to initialize EpochEnd
        super().__init__(max_batch)
        super(BatchEnd, self).__init__(max_epoch)
        self.stop_training = False

    def batch_end(self, batch_result=None):
        super(StopTrainingHandler, self).batch_end(batch_result)
        if self.total_batch == self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self, epoch_result=None):
        super(StopTrainingHandler, self).epoch_end(epoch_result)
        # reset batch index at the end of each epoch
        self.batch_idx = 0
        if self.epoch == self.max_epoch:
            self.stop_training = True
        return self.stop_training
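The cooperative `__init__` calls above rely on Python's method resolution order (MRO): a plain `super()` targets the next class after the subclass (here `BatchEnd`), while `super(BatchEnd, self)` skips past `BatchEnd` to the next class in the MRO (`EpochEnd`). A minimal illustration with generic classes:

```python
class A:
    def __init__(self, a):
        self.a = a

class B:
    def __init__(self, b):
        self.b = b

class C(A, B):
    def __init__(self, a, b):
        # MRO of C is [C, A, B, object]
        super().__init__(a)          # runs A.__init__
        super(A, self).__init__(b)   # skips A, runs B.__init__

c = C(1, 2)
```

This is the same pattern `StopTrainingHandler` uses to route `max_batch` and `max_epoch` to the right base class.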
One base class with all event methods
class EventHandler(object):
    """Base class for event handlers.

    :py:class:`EventHandler` can perform user-defined functions at
    different stages of training: train begin, epoch begin, batch begin,
    batch end, epoch end, train end.

    Parameters
    ----------
    estimator : Estimator
        The :py:class:`Estimator` to get training statistics from
    """

    def __init__(self):
        self._estimator = None

    def train_begin(self, *args, **kwargs):
        pass

    def epoch_begin(self, *args, **kwargs):
        pass

    def batch_begin(self, *args, **kwargs):
        pass

    def batch_end(self, batch_id, batch_results=None, *args, **kwargs):
        return False

    def epoch_end(self, epoch, epoch_results=None, *args, **kwargs):
        return False

    def train_end(self, *args, **kwargs):
        pass
class StopTrainingHandlerV2(object):
    def __init__(self, max_batch, max_epoch):
        self.batch_idx = 0
        self.epoch = 0
        self.total_batch = 0
        self.max_epoch = max_epoch
        self.max_batch = max_batch
        self.stop_training = False

    def batch_end(self, batch_result=None, *args, **kwargs):
        self.batch_idx += 1
        self.total_batch += 1
        if self.total_batch == self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self, epoch_result=None, *args, **kwargs):
        self.epoch += 1
        # reset batch index at the end of each epoch
        self.batch_idx = 0
        if self.epoch == self.max_epoch:
            self.stop_training = True
        return self.stop_training
We can check whether a handler actually overrides an event method by comparing it against the base class method:

handler.__class__.train_begin == EventHandler.train_begin
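The check above lets the fit loop skip handlers that only inherit the no-op default. A minimal sketch of how such filtering could work (the handler classes here are simplified stand-ins, not the full design):

```python
class EventHandler:
    """Simplified base class with a no-op default event method."""
    def batch_end(self, *args, **kwargs):
        return False

class StopHandler(EventHandler):
    """Overrides batch_end, so the fit loop should call it."""
    def batch_end(self, *args, **kwargs):
        return True

class NoOpHandler(EventHandler):
    """Does not override batch_end; can be skipped for this event."""
    pass

handlers = [StopHandler(), NoOpHandler()]

# Keep only handlers whose class actually overrides batch_end, so the
# fit loop does not pay the cost of calling every inherited no-op.
batch_end_handlers = [
    h for h in handlers
    if h.__class__.batch_end is not EventHandler.batch_end
]
```

Only `StopHandler` survives the filter; `NoOpHandler.batch_end` is the same function object as `EventHandler.batch_end`.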
class TrainBegin(object):
    def train_begin(self, estimator, *args, **kwargs):
        pass


class TrainEnd(object):
    def train_end(self, estimator, *args, **kwargs):
        pass


class EpochBegin(object):
    def epoch_begin(self, estimator, *args, **kwargs):
        pass


class EpochEnd(object):
    def epoch_end(self, estimator, *args, **kwargs):
        return False


class BatchBegin(object):
    def batch_begin(self, estimator, *args, **kwargs):
        pass


class BatchEnd(object):
    def batch_end(self, estimator, *args, **kwargs):
        return False
from mxnet.metric import Loss


class MetricHandler(EpochBegin, BatchEnd):
    def __init__(self, train_metrics):
        self.train_metrics = train_metrics
        # order to be called among all callbacks:
        # metrics need to be calculated before other callbacks can access them
        self.rank = 1

    def epoch_begin(self, estimator, *args, **kwargs):
        for metric in self.train_metrics:
            metric.reset()

    def batch_end(self, estimator, *args, **kwargs):
        pred = kwargs['pred']
        label = kwargs['label']
        loss = kwargs['loss']
        for metric in self.train_metrics:
            if isinstance(metric, Loss):
                # metric wrapper for loss values
                metric.update(0, loss)
            else:
                metric.update(label, pred)
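To make the kwargs-based dispatch concrete, here is a self-contained sketch of how the fit loop could feed batch results to a metric handler. `Accuracy` below is a hypothetical stand-in for an MXNet metric, not the real `mxnet.metric.Accuracy` class:

```python
class Accuracy:
    """Toy accuracy metric standing in for an mxnet.metric class."""
    def __init__(self):
        self.correct = 0
        self.total = 0

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, labels, preds):
        for label, pred in zip(labels, preds):
            self.correct += int(label == pred)
            self.total += 1

    def get(self):
        return self.correct / self.total


class MetricHandler:
    """Simplified metric handler: resets on epoch begin, updates on batch end."""
    def __init__(self, train_metrics):
        self.train_metrics = train_metrics

    def epoch_begin(self, estimator, *args, **kwargs):
        for metric in self.train_metrics:
            metric.reset()

    def batch_end(self, estimator, *args, **kwargs):
        # the fit loop passes per-batch results through kwargs
        for metric in self.train_metrics:
            metric.update(kwargs['label'], kwargs['pred'])


handler = MetricHandler([Accuracy()])
handler.epoch_begin(None)
handler.batch_end(None, pred=[1, 0, 1], label=[1, 1, 1])
```

After this batch, the toy accuracy is 2/3: two of the three predictions match their labels.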
There are many internal training states we want to keep track of during the training process. The fit loop can bookkeep some of these internal states, and each callback can also bookkeep them. We need a criterion to decide who manages each state.
There are mainly two design considerations:
The fit loop has the following structure:
for epoch in range(max_epochs):
    for i, batch in enumerate(train_data):
        ...
Here the epoch number and batch number are managed naturally: the epoch won't exceed max_epochs, and the batch index automatically resets to 0 after each epoch. They should be managed by the fit loop and passed to callbacks.
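A minimal sketch of the fit loop owning the epoch and batch indices and passing them to handlers, so no handler has to maintain its own (possibly drifting) copy. The `CountingHandler` class and the loop bounds are illustrative assumptions:

```python
class CountingHandler:
    """Records the (epoch, batch_idx) pairs it is called with."""
    def __init__(self):
        self.seen = []

    def batch_end(self, epoch, batch_idx, **kwargs):
        self.seen.append((epoch, batch_idx))


handler = CountingHandler()
max_epochs, batches_per_epoch = 2, 3

for epoch in range(max_epochs):
    # the batch index restarts at 0 automatically each epoch,
    # because the inner range() produces it fresh
    for batch_idx in range(batches_per_epoch):
        handler.batch_end(epoch=epoch, batch_idx=batch_idx)
```

The handler sees six calls, starting at (0, 0) and ending at (1, 2), without tracking any counters itself.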
Which states are OK to keep as copies in each callback, and which states should be managed by the fit loop?
We categorize states into three types, which should be managed differently.
Based on the conclusion above, we need to decide how to pass those states to callbacks.
States managed by each callback can be accessed and updated directly, as in StopTrainingHandlerV2 above. Here we discuss how to pass external arguments (those managed by the fit loop).
There are a few options:
We can pass some of these states through kwargs during callback calls. We can also inject the estimator into each callback so it has access to everything the estimator manages. We need to draw a line on which states are bookkept where. Based on the conclusion on the different types of states:
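A sketch of the "inject the estimator" option: the fit loop hands every handler a reference to the estimator, so handlers read shared state from one place instead of receiving it piecemeal through kwargs. The `Estimator` and `LoggingHandler` classes below are hypothetical minimal stand-ins, not the real API:

```python
class Estimator:
    """Toy estimator holding shared training state."""
    def __init__(self):
        self.max_epoch = 10
        self.train_metrics = []


class LoggingHandler:
    """Receives the estimator once at train begin and reads state from it."""
    def __init__(self):
        self.estimator = None

    def train_begin(self, estimator, *args, **kwargs):
        # shared, read-mostly state comes from the injected estimator,
        # so the handler never keeps its own copy
        self.estimator = estimator


est = Estimator()
handler = LoggingHandler()
handler.train_begin(est)
```

The trade-off: injection gives handlers access to everything (convenient, but a wide coupling surface), while kwargs keep each event's inputs explicit at the cost of verbosity.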
The following are provided as callbacks for now:
Here we discuss whether we should make other features callbacks, and what additional requirements they need to become callbacks.