Problem Description

SVRG stands for Stochastic Variance Reduced Gradient, which was first introduced in the 2013 paper Accelerating Stochastic Gradient Descent using Predictive Variance Reduction. It is an optimization technique that complements SGD. SGD is well suited to large-scale optimization, but it suffers from slow asymptotic convergence due to its inherent variance: SGD approximates the full gradient using a small batch of samples, which introduces variance, and in order to converge it often has to start with a small learning rate. SVRG remedies this problem by keeping a snapshot of the estimated weights that stays close to the optimal parameters and maintaining the average of the full gradient over a full pass of the data. This average of full gradients over all data is calculated w.r.t. the parameters snapshotted at the last m-th epoch. SVRG then uses a different update rule: the gradient w.r.t. the current parameters, minus the gradient w.r.t. the snapshot parameters from the last m-th epoch, plus the average of full gradients over all data. SVRG has provable convergence guarantees for smooth, strongly convex functions; a detailed proof can be found in Section 3 of the paper.

Average of full gradient over a full pass of data w.r.t parameters of past m epochs:
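Using the notation of the paper, where $\psi_i$ is the loss on the i-th of $n$ training samples and $\tilde{w}$ is the parameter snapshot taken at the last m-th epoch:

$\tilde{\mu} = \frac{1}{n} \sum_{i=1}^{n} \nabla \psi_i(\tilde{w})$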

SVRG Update Rule:
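With $i_t$ the index of the sample drawn at step $t$ and $\eta$ the learning rate:

$w_t = w_{t-1} - \eta \left( \nabla \psi_{i_t}(w_{t-1}) - \nabla \psi_{i_t}(\tilde{w}) + \tilde{\mu} \right)$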

The initial set of experiments was conducted with a linear regression model on the YearPredictionMSD dataset, which contains more than 40,000 samples. The results of using SVRG optimization showed consistently faster convergence compared to SGD. A more detailed analysis of the experiment results can be found in the Benchmark section.

Key Characteristics of SVRG:

  • Explicit variance reduction
  • Ability to use a relatively large learning rate compared to SGD, which leads to faster convergence.

Expected Deliverables

The goal is to deliver an MXNet Python Module that implements the SVRG optimization technique.

Tenets

  • Minimize the surface footprint on existing code by implementing a complete, self-contained SVRGModule.
  • From a user's perspective, using the SVRGModule should be similar to using the MXNet Python Module API, except that the underlying optimization technique will be SVRG. Minimize the differences between the external APIs of SVRGModule and the Module API.
  • SVRGModule should seamlessly support both dense and sparse data, and run on CPU and GPU instances both on a single machine and in a distributed setting.

Implementation approach

A common question is why SVRG is not implemented as an optimizer: the Optimizer class does not support the cross-key operations that SVRG requires. After evaluating several options (listed in the appendix), we decided to create an SVRGModule that follows the MXNet Module API and implements the SVRG optimization technique under the hood.

In addition to the parameters used in the Module class, the SVRGModule class will have an attribute update_frequency that the user can set to control the frequency at which the full gradients are recalculated. For example, if the user sets update_frequency = 2, the full gradients over a full pass of the data will be updated every 2 epochs. In particular, the full gradients will be averaged over the number of batches.

The following operations will be permitted on SVRGModule (the operations that change relative to the existing Module API are update_full_grads(), get_params(), and set_params()):

  • Initialization: bind(), init_params(), init_optimizer()
  • Computation: forward(), backward(), update(), update_full_grads()
  • Parameters: get_params() return type will be a dict, and set_params() will be overloaded to take a param dict covering both modules as a parameter (see the sketch after this list)
  • High-Level API: fit(), predict(), score()
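For illustration, a hypothetical sketch of the changed parameter APIs; exact names and signatures are not final:

# get_params() returns a single dict covering the parameters of both wrapped modules,
# instead of Module's (arg_params, aux_params) tuple
params = mod.get_params()
# set_params() is overloaded to accept that same dict and apply it to both modules
mod.set_params(params)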

Like the MXNet Module class, SVRGModule can be used through both the high-level and the intermediate-level APIs.

The following snippet demonstrates the usage of the high-level API of SVRGModule:

mod = mx.mod.SVRGModule(symbol=model, update_frequency=2, data_names=['data'], label_names=['lin_reg_label'])

# Interface for fit() will remain the same as BaseModule API
mod.fit(data_iterator, num_epoch=100, ...) 

The following snippet demonstrates the suggested usage of the intermediate-level API of SVRGModule. The steps required to train an SVRG module are the same as those of the Module API, except for the additional step of invoking update_full_grads() on the designated epochs.

mod = mx.mod.SVRGModule(symbol=model, update_frequency=2, data_names=['data'], label_names=['lin_reg_label'])
mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
mod.init_params()
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
for epoch in range(1, num_epoch + 1):
    # These are the only 2 lines that differ from the standard Module training loop:
    # recompute the full gradients at the first epoch and every update_frequency epochs thereafter
    if (epoch - 1) % mod.update_frequency == 0:
        mod.update_full_grads(di)
    di.reset()
    for data_batch in di:
        mod.forward_backward(data_batch)
        mod.update()

Pros: SVRGModule is a new class that encapsulates the implementation logic of SVRG without modifying the existing Module class.
Cons: Other MXNet built-in modules (e.g. BucketingModule) cannot use the SVRG technique without separate implementations.

Implementation Details

Derive from BaseModule vs Module

There are two possible implementation approaches; both of them require overriding functions such as fit() defined in BaseModule to adapt them to the SVRG optimization logic.

  1. Inherit from BaseModule
    1. cons: it wraps two modules and therefore needs to define the methods that are implemented in the Module API but not in BaseModule.
  2. Inherit from Module (a minimal class skeleton for this approach is sketched below)
    1. pros: it inherits all the additional methods defined in the Module API and only requires encapsulating a single additional module, instead of two as in approach 1.
    2. cons: it is not consistent in terms of inheritance hierarchy with the other built-in module classes in MXNet.
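A minimal sketch of approach 2, assuming SVRGModule keeps an auxiliary Module for tracking the full gradients; the attribute names here are illustrative, not final:

import mxnet as mx

class SVRGModule(mx.mod.Module):
    def __init__(self, symbol, data_names, label_names, update_frequency, **kwargs):
        super(SVRGModule, self).__init__(symbol, data_names=data_names,
                                         label_names=label_names, **kwargs)
        # How often (in epochs) the full gradients are recalculated
        self.update_frequency = update_frequency
        # Auxiliary module that accumulates the full gradients over a full pass of data
        self._mod_aux = mx.mod.Module(symbol, data_names=data_names,
                                      label_names=label_names, **kwargs)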

Implementation of SVRG logic

  • SVRGModule will be a container module that maintains two separate modules: one that trains the actual model and a second one that accumulates/tracks the full gradients.
  • Customized optimizer for SVRGModule: the full gradient has to be accumulated over the entire dataset. In a distributed setting, this requires synchronization across all workers to ensure that the final full gradient is representative of the entire dataset. In other words, in distributed mode the SVRGModule on each worker needs to push the average of the full gradients calculated w.r.t. its own data shard to the kvstore, and pull the accumulated full gradients over all workers before applying the SVRG update. To accomplish this, a custom optimizer, SVRGOptimizer, will be implemented and set in the KVStore via its set_optimizer() method (see the sketch after this list).
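A minimal sketch of the distributed accumulation step described above; SVRGOptimizer and the key/gradient lists are hypothetical placeholders for the eventual implementation:

import mxnet as mx

kv = mx.kv.create('dist_sync')
# The custom optimizer controls how the pushed full gradients are aggregated on the server side
kv.set_optimizer(SVRGOptimizer())

# full_grad_keys / worker_full_grads: one entry per parameter, holding this worker's
# average of full gradients over its own data shard
for key, grad in zip(full_grad_keys, worker_full_grads):
    kv.push(key, grad)        # each worker contributes its shard's average
    kv.pull(key, out=grad)    # pull back the aggregate over all workers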

Testing Strategy

Experiments will be performed to compare SGD vs SVRG.

  • Starting with a single-machine/single-worker implementation of SVRG, a test dataset will be prototyped to compare the performance of plain SGD and SVRG. For each feature, the mean, standard deviation and variance of the gradients will be calculated and compared with those of SGD; the variance with SVRG should be smaller than that with SGD (see the sketch after this list).
  • Unit Tests will be written for testing the behaviors of the SVRGModule API calls for both Intermediate API and High Level API.
  • Additional functional tests will be conducted:
    • Dense data vs Sparse data (CSR and Row Sparse)
    • Single machine training with CPUs, single GPU and multiple GPUs
    • Distributed training with multiple machines 
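An illustrative sketch of the variance comparison above; sgd_grads and svrg_grads are hypothetical arrays of per-batch gradient estimates collected for the same weights:

import numpy as np

# One row per batch, one column per feature/weight
print('SGD  gradient variance per feature:', np.var(sgd_grads, axis=0))
print('SVRG gradient variance per feature:', np.var(svrg_grads, axis=0))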

Benchmark

The SVRG implementation using the existing Module API is benchmarked on the YearPredictionMSD dataset with a linear regression model.

Training Loss over 100 Epochs with lr_scheduler

An lr_scheduler returns a new learning rate based on the number of updates that have been performed so far. With an lr_scheduler, the training loss of SVRG is lower than that of SGD over all 100 epochs, as shown in the graph below.
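For reference, a learning-rate scheduler can be attached through the optimizer parameters; the scheduler type and values below are illustrative only:

schedule = mx.lr_scheduler.FactorScheduler(step=1000, factor=0.9)
mod.init_optimizer(optimizer='sgd',
                   optimizer_params=(('learning_rate', 0.01), ('lr_scheduler', schedule)))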

Training Loss Comparisons with SGD, fixed learning rates

One drawback of SGD is that, in order to converge, its learning rate has to decay to zero, so SGD needs to start with a small learning rate. The learning rate does not need to decay to zero for SVRG, therefore a relatively larger learning rate can be used. SGD with learning rates of 0.001 and 0.0025 and SVRG with a learning rate of 0.025 were benchmarked. Even though SVRG starts with a relatively large learning rate, it converges much faster than SGD in both cases. This result aligns with what is stated in Section 5 of the SVRG paper.

Frequently Asked Questions

-- Why not implement the SVRG optimization technique as an Optimizer?
A: The SVRG optimization logic requires calculating full gradients over a full pass of the data every update_frequency epochs. There is currently no notion of an epoch in the Optimizer class. Calculating the full gradients also requires looping through the full dataset in batches and performing cross-key operations, neither of which can be accomplished via the Optimizer class.
-- Is it possible to implement using Gluon API?
A: The Gluon API with Autograd was explored as a possible implementation option, but it does not provide a convenient way to package SVRG for users. The MXNet SVRGModule API, on the other hand, will provide a clean, single point of entry for users. Sample code for the Gluon API can be provided at a later date.

Appendix:

Alternative Design Approach 

Another design approach considered was to modify the current Module API to take an extra parameter use_svrg that indicates whether SVRG optimization will be used in training. However, changing the Module API may introduce backward-compatibility issues.

High-level API to use Module with use_svrg:

mod = mx.mod.Module(symbol=model, use_svrg=True, data_names=['data'], label_names=['lin_reg_label'])

mod.fit(di, num_epoch=100, ...)

Links & Resources