Problem Description

SVRG stands for Stochastic Variance Reduced Gradient, which was first introduced in the 2013 paper Accelerating Stochastic Gradient Descent using Predictive Variance Reduction. It is an optimization technique that complements SGD. SGD is known for large-scale optimization, but it suffers from slow asymptotic convergence due to its inherent variance: SGD approximates the full gradient using a small batch of samples, which introduces variance, and to ensure convergence it needs a decaying learning rate and therefore often has to start with a small one. SVRG remedies this problem by keeping a snapshot of the estimated weights that is close to the optimal parameters and maintaining the average of the full gradient over a full pass of the data. This average of the full gradients over all data is calculated w.r.t. the snapshot parameters from the last m-th epoch. SVRG has provable convergence guarantees for strongly convex smooth functions; a more detailed proof can be found in Section 3 of the paper. SVRG uses a different update rule: the gradient w.r.t. the current parameters, minus the gradient w.r.t. the snapshot parameters from the last m-th epoch, plus the average of the gradients over all data.

Average of full gradient over a full pass of data w.r.t parameters of past m epochs:
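In the notation of the original paper, with $\psi_i$ the loss on sample $i$, $n$ the total number of samples, and $\tilde{w}$ the snapshot parameters kept from the last m-th epoch:

$$\tilde{\mu} = \frac{1}{n} \sum_{i=1}^{n} \nabla \psi_i(\tilde{w})$$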

SVRG Update Rule:
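In the same notation, with $\eta_t$ the learning rate and $i_t$ the sample drawn at step $t$:

$$w^{(t)} = w^{(t-1)} - \eta_t \left( \nabla \psi_{i_t}(w^{(t-1)}) - \nabla \psi_{i_t}(\tilde{w}) + \tilde{\mu} \right)$$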

The initial set of experiments was conducted with a linear regression model on the YearPredictionMSD dataset, which contains more than 40,000 samples. The results of using SVRG optimization showed consistently faster convergence compared to SGD. A more detailed analysis of the experiment results can be found in the Benchmark section.

Key Characteristics of SVRG:

Expected Deliverables

The goal is to implement an MXNet Python Module that supports the SVRG optimization technique.

Tenets

Implementation approach

A common question is why not implement SVRG as an Optimizer: the cross-key operations that SVRG requires are not supported in the Optimizer class. After evaluating several options (listed in the appendix), we decided to create an SVRGModule that follows the MXNet Module API and implements the SVRG optimization technique under the hood.

In addition to the parameters used in the Module class, the SVRGModule class will have an attribute update_frequency that the user can set to control how often the full gradients are recomputed. For example, if the user sets update_frequency = 2, the full gradients over a full pass of the data will be updated every 2 epochs. In particular, the full gradients will be averaged over the number of batches.
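A small sketch to make the schedule concrete (assuming epochs are counted from 1, as in the intermediate-level example below):

update_frequency = 2
full_grad_epochs = [epoch for epoch in range(1, 11)
                    if (epoch - 1) % update_frequency == 0]
# full_grad_epochs == [1, 3, 5, 7, 9]: the full gradients are refreshed every 2 epochs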

The following operations will be permitted on SVRGModule (the operations that are changes to the existing Module API are highlighted):

As with the MXNet Module class, a user will be able to use SVRGModule through both the high-level and intermediate-level APIs.

The following snippet demonstrates the usage of the high-level API of SVRGModule:

mod = mx.mod.SVRGModule(symbol=model, update_frequency=2, data_names=['data'], label_names=['lin_reg_label'])

# Interface for fit() will remain the same as BaseModule API
mod.fit(data_iterator, num_epoch=100, ...) 
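For illustration, a complete fit() call with the standard BaseModule keyword arguments (none of these arguments are specific to SVRGModule) could look like the following:

mod.fit(data_iterator,
        eval_metric='mse',
        optimizer='sgd',
        optimizer_params=(('learning_rate', 0.025),),
        num_epoch=100)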

The following snippet demonstrates the suggested usage of the intermediate-level API of SVRGModule. The steps required to train an SVRGModule are the same as those of the Module API, except for the additional call to update_full_grads() on the epochs where the full gradients need to be recomputed.

mod = mx.mod.SVRGModule(symbol=model, update_frequency=2, data_names=['data'], label_names=['lin_reg_label'])
mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
mod.init_params()
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
for epoch in range(1, num_epoch + 1):
    # The following two lines are the only difference from the standard Module training loop:
    # recompute the full gradients every update_frequency epochs
    if (epoch - 1) % mod.update_frequency == 0:
        mod.update_full_grads(di)
    di.reset()
    for data_batch in di:
        mod.forward_backward(data_batch)
        mod.update()
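The standard Module bookkeeping, such as updating an evaluation metric, still applies inside the batch loop; a minimal sketch, assuming an MSE metric:

metric = mx.metric.create('mse')
# inside the batch loop, after forward_backward():
mod.update_metric(metric, data_batch.label)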

Pros: SVRGModule is a new class that encapsulates the implementation logic of SVRG without modifying the existing Module class.
Cons: Other MXNet built-in modules (e.g. BucketingModule) cannot use the SVRG technique without a separate implementation.

Implementation Details

Derive from BaseModule vs Module

There are two possible implementation approaches; both of them require overriding functions such as fit() defined in BaseModule to adapt them to the SVRG optimization logic.

  1. Inherit from BaseModule
    1. cons: it wraps two modules, and therefore needs to define the methods that are implemented in the Module API but not in BaseModule.
  2. Inherit from Module (a skeleton of this approach is sketched after this list)
    1. pros: it inherits all the additional methods defined in the Module API and only requires encapsulating a single module instead of two, as in approach 1.
    2. cons: it is not consistent, in terms of inheritance hierarchy, with the other built-in module classes in MXNet.
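For illustration only, a minimal skeleton of approach 2 (inheriting from Module and encapsulating a single auxiliary module that holds the snapshot parameters); the internal attribute names are assumptions for this sketch, not a final API:

import mxnet as mx

class SVRGModule(mx.mod.Module):
    """Sketch only: inherits from Module and keeps one auxiliary Module
    to hold the snapshot parameters and the averaged full gradients."""
    def __init__(self, symbol, update_frequency, data_names=('data',),
                 label_names=('softmax_label',), **kwargs):
        super(SVRGModule, self).__init__(symbol, data_names=data_names,
                                         label_names=label_names, **kwargs)
        self.update_frequency = update_frequency
        # Auxiliary module holding the snapshot parameters (w_tilde);
        # its gradient arrays accumulate the averaged full gradient.
        self._mod_aux = mx.mod.Module(symbol, data_names=data_names,
                                      label_names=label_names, **kwargs)

    def update_full_grads(self, train_data):
        # Compute the full gradient over train_data w.r.t. the snapshot
        # parameters and average it over the number of batches (omitted here).
        raise NotImplementedError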

Implementation of SVRG logic
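The core of the SVRG logic is the update rule described in the Problem Description, applied per parameter key. A minimal sketch of the gradient correction, assuming the current gradients, the snapshot gradients and the averaged full gradients are each kept as dicts of NDArrays keyed by parameter name (the helper name is hypothetical):

def svrg_corrected_grads(cur_grads, snapshot_grads, full_grads):
    # SVRG correction per key: grad_i(w) - grad_i(w_snapshot) + mu,
    # where mu is the full gradient averaged over all batches.
    return {key: cur_grads[key] - snapshot_grads[key] + full_grads[key]
            for key in cur_grads}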

Testing Strategy

Experiments will be performed to compare SGD vs SVRG.

Benchmark

The SVRG implementation using the existing Module API is benchmarked on the YearPredictionMSD dataset with a linear regression model.

Training Loss over 100 Epochs with lr_scheduler

An lr_scheduler returns a new learning rate based on the number of updates that have been performed. With an lr_scheduler, the training loss of SVRG is lower than that of SGD over all of the 100 epochs, as shown in the graph below.

Training Loss Comparisons with SGD, fixed learning rates

One drawback of SGD is that, in order to converge, its learning rate has to decay to zero, so SGD needs to start with a small learning rate and converges slowly. The learning rate does not need to decay to zero for SVRG, so a relatively larger learning rate can be used. SGD with learning rates of 0.001 and 0.0025 and SVRG with a learning rate of 0.025 are benchmarked. Even though SVRG starts with a relatively large learning rate, it converges much faster than SGD in both cases. This experimental result aligns with what is stated in Section 5 of the SVRG paper.

Frequently Asked Questions

-- Why not implement the SVRG optimization technique as an Optimizer?
A: The SVRG optimization logic requires calculating full gradients over a full pass of the data every update_frequency epochs. There is currently no notion of an epoch in the Optimizer class. Computing the full gradients also requires looping through the full dataset in batches, as well as cross-key operations, neither of which can be accomplished through the Optimizer class.
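For context, the Optimizer interface operates on one parameter (key) at a time and has no access to the data iterator; its per-parameter update hook in mxnet.optimizer.Optimizer looks roughly like:

# called once per parameter per batch; there is no notion of epochs,
# no access to the dataset, and no visibility across keys
def update(self, index, weight, grad, state):
    ...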
-- Is it possible to implement this using the Gluon API?
A: The Gluon API with Autograd was explored as a possible implementation option, but it would not provide a convenient way to package SVRG for users. The MXNet SVRGModule API, on the other hand, provides a clean, single point of entry for users. Sample code for the Gluon API can be provided at a later date.

Appendix:

Alternative Design Approach 

Another design approach considered was to modify the current Module API to take an extra parameter, use_svrg, indicating whether SVRG optimization will be used in training. However, changing the Module API may introduce backward compatibility issues.

High-level API to use Module with use_svrg:

mod = mx.mod.Module(symbol=model, use_svrg=True, data_names=['data'], label_names=['lin_reg_label'])

mod.fit(di, num_epoch=100, ...)

Links & Resources