Superconvergence in PyTorch

In Super-Convergence: Very fast training of neural networks using large learning rates1, Smith and Topin present evidence for a learning rate schedule that can result in a 10x decrease in training time while maintaining similar accuracy. Specifically, they propose the use of a cyclical learning rate, which starts at a small value, increases linearly over time to a large value, then decreases linearly back to the small value again.

PyTorch has an implementation of just such a cyclic learning rate scheduler, CyclicLR. Given an optimizer (RMSprop, for example) and a total number of epochs, you can set up an appropriate learning rate schedule like this:

from torch.optim.lr_scheduler import CyclicLR

scheduler = CyclicLR(my_rmsprop, base_lr=0.01, max_lr=0.1, step_size_up=total_epochs//2, step_size_down=total_epochs//2)
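
With this schedule, the learning rate climbs linearly to max_lr over the first half of training and falls back to base_lr over the second half. Here is a rough sketch of how the scheduler might be driven; my_model, my_loader, and my_loss_fn are placeholder names, and I'm assuming one step() call per epoch (if you step per batch instead, the step sizes should be counted in iterations rather than epochs):

for epoch in range(total_epochs):
    for inputs, targets in my_loader:
        my_rmsprop.zero_grad()
        loss = my_loss_fn(my_model(inputs), targets)
        loss.backward()
        my_rmsprop.step()
    scheduler.step()  # advance the cyclical learning rate once per epoch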

The authors go on to say that for even better accuracy at a fixed computational budget, it helps to reserve one or two epochs at the end for training with very small learning rates. The direct quote from the paper is:

always use one cycle that is smaller than the total number of iterations/epochs and allow the learning rate to decrease several orders of magnitude less than the initial learning rate for the remaining iterations

Unfortunately, the CyclicLR class in PyTorch doesn't have an option for a trailing LR phase like this, and if you set the step_size_up and step_size_down params so that they sum to less than the total number of epochs, a second cycle simply begins and the learning rate heads back up instead of continuing down.
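
To see this, here is a quick check with a toy setup of my own (a single dummy parameter and an SGD optimizer, purely to inspect the schedule): with a six-step cycle but ten epochs, the learning rate returns to base_lr and then starts climbing again.

import torch
from torch.optim.lr_scheduler import CyclicLR

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)
sched = CyclicLR(opt, base_lr=0.01, max_lr=0.1,
                 step_size_up=3, step_size_down=3, cycle_momentum=False)

lrs = []
for _ in range(10):  # ten "epochs", but the cycle is only six steps long
    lrs.append(round(sched.get_last_lr()[0], 3))
    sched.step()
print(lrs)  # [0.01, 0.04, 0.07, 0.1, 0.07, 0.04, 0.01, 0.04, 0.07, 0.1]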

Fortunately, we can code up our own implementation of Smith and Topin's 1cycle policy fairly quickly. Here is a simple implementation that uses the PyTorch CyclicLR for the single cycle, then adds a final linear decrease afterward:

class OneCycleLR:

    def __init__(self, optimizer, base_lr: float = 0.01, max_lr: float = 0.1,
                 fraction_up: float = 0.4, fraction_down: float = 0.4,
                 total_steps: int = 100):
        self.optimizer = optimizer
        # Single triangular cycle handled by the built-in CyclicLR
        self.cyclic = CyclicLR(optimizer, base_lr=base_lr, max_lr=max_lr,
                               step_size_up=int(total_steps * fraction_up),
                               step_size_down=int(total_steps * fraction_down),
                               cycle_momentum=False)
        self.base_lr = base_lr
        self.last_index = -1
        # Step index at which the cycle ends and the trailing decay begins
        self.cutover_index = int((fraction_up + fraction_down) * total_steps)
        # Linear decrement that takes the LR from base_lr down to base_lr / 100
        self.step_size = (base_lr - (base_lr / 100)) / (total_steps - self.cutover_index)

    def step(self):
        self.last_index += 1
        if self.last_index < self.cutover_index:
            # Still inside the up/down cycle: let CyclicLR set the learning rate
            self.cyclic.step()
        else:
            # Past the cycle: decrease the learning rate linearly toward base_lr / 100
            lr = self.base_lr - self.step_size * (self.last_index - self.cutover_index)
            for group in self.optimizer.param_groups:
                group['lr'] = lr
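
Usage then looks much like any other PyTorch scheduler. The sketch below is mine, assuming one step() call at the end of each epoch, with train_one_epoch standing in for whatever training loop you already have:

total_epochs = 100
scheduler = OneCycleLR(my_rmsprop, base_lr=0.01, max_lr=0.1,
                       fraction_up=0.4, fraction_down=0.4,
                       total_steps=total_epochs)

for epoch in range(total_epochs):
    train_one_epoch(my_model, my_loader, my_rmsprop)  # placeholder for your own training code
    scheduler.step()

With these fractions, the first 80 epochs trace out the single up/down cycle and the last 20 decay the learning rate linearly toward base_lr / 100, which is roughly the trailing phase the paper recommends.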