Superconvergence in PyTorch
In Super-Convergence: Very fast training of neural networks using large learning rates1, Smith and Topin present evidence for a learning rate schedule that can result in a 10x decrease in training time while maintaining similar accuracy. Specifically, they propose the use of a cyclical learning rate, which starts at a small value, increases linearly over time to a large value, then decreases back to the small value again.
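To make the shape concrete, here is a minimal sketch of one such triangular schedule written as a pure function of the step index; the function name and arguments are mine, not from the paper:

def triangular_lr(step: int, base_lr: float, max_lr: float, total_steps: int) -> float:
    """Learning rate for a single linear ramp up to max_lr and back down to base_lr."""
    half = total_steps // 2
    if step <= half:
        # First half: increase linearly from base_lr to max_lr.
        return base_lr + (max_lr - base_lr) * step / half
    # Second half: decrease linearly from max_lr back to base_lr.
    return max_lr - (max_lr - base_lr) * (step - half) / half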
PyTorch has an implementation of just such a cyclic learning rate scheduler called CyclicLR. Given an optimizer (like RMSProp, for example) and a total number of epochs, you can set up an appropriate learning rate schedule like this:
from torch.optim.lr_scheduler import CyclicLR
scheduler = CyclicLR(my_rmsprop, base_lr=0.01, max_lr=0.1, step_size_up=total_epochs//2, step_size_down=total_epochs//2)
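With the step sizes expressed in epochs as above, the scheduler would be stepped once per epoch. A rough sketch of the surrounding loop, where train_one_epoch, model, and train_loader are placeholders for your own training code:

for epoch in range(total_epochs):
    train_one_epoch(model, train_loader, my_rmsprop)  # placeholder for your usual training pass
    scheduler.step()                                  # advance the cyclic schedule by one epoch
    print(epoch, my_rmsprop.param_groups[0]["lr"])    # watch the rate ramp up, then back down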
The authors go on to say that, for a fixed computational budget, accuracy improves further if you reserve one or two epochs at the end for training with very small learning rates. The direct quote from the paper is:
always use one cycle that is smaller than the total number of iterations/epochs and allow the learning rate to decrease several orders of magnitude less than the initial learning rate for the remaining iterations
Unfortunately, the CyclicLR class in PyTorch doesn't have an option for a trailing LR policy, and indeed if you try to set the step_size_up and step_size_down params to be less than the total number of epochs, a second cycle will begin and the learning rate will go back up instead of continuing down.
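You can see this for yourself by stepping a CyclicLR past the end of its single cycle and printing the learning rate; the toy SGD optimizer below exists only to give the scheduler something to drive:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CyclicLR

params = [torch.zeros(1, requires_grad=True)]
opt = SGD(params, lr=0.01)
sched = CyclicLR(opt, base_lr=0.01, max_lr=0.1, step_size_up=5, step_size_down=5, cycle_momentum=False)

for step in range(15):
    sched.step()
    # After ten steps the rate bottoms out and starts climbing again: a second cycle has begun.
    print(step, round(opt.param_groups[0]["lr"], 4))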
Fortunately, we can code up our own implementation of Smith and Topin's 1cycle policy fairly quickly. Here is a simple implementation that uses PyTorch's CyclicLR for the single cycle, then adds a final linear decrease afterward:
from torch.optim.lr_scheduler import CyclicLR

class OneCycleLR:
    def __init__(self, optimizer, base_lr: float = 0.01, max_lr: float = 0.1,
                 fraction_up: float = 0.4, fraction_down: float = 0.4, total_steps: int = 100):
        # Phase 1: a single triangular cycle handled by PyTorch's CyclicLR.
        self.cyclic = CyclicLR(optimizer, base_lr=base_lr, max_lr=max_lr,
                               step_size_up=int(total_steps * fraction_up),
                               step_size_down=int(total_steps * fraction_down),
                               cycle_momentum=False)
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.last_index = -1
        self.cutover_index = int((fraction_up + fraction_down) * total_steps)  # where the cycle ends
        # Phase 2: decay linearly from base_lr down to base_lr / 100 over the remaining steps.
        self.step_size = (base_lr - base_lr / 100) / (total_steps - self.cutover_index)

    def step(self):
        self.last_index += 1
        if self.last_index < self.cutover_index:
            self.cyclic.step()  # still inside the single cycle: let CyclicLR set the rate
        else:
            # Trailing phase: set a learning rate that keeps dropping below base_lr.
            lr = self.base_lr - self.step_size * (self.last_index - self.cutover_index)
            for group in self.optimizer.param_groups:
                group["lr"] = lr
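As a quick sanity check, you can attach the scheduler to an optimizer and print the learning rate at every step; the tiny linear model and the parameter values below are just stand-ins:

import torch

net = torch.nn.Linear(10, 1)  # stand-in model
opt = torch.optim.RMSprop(net.parameters(), lr=0.01)
sched = OneCycleLR(opt, base_lr=0.01, max_lr=0.1, fraction_up=0.4, fraction_down=0.4, total_steps=100)

for step in range(100):
    # forward/backward and opt.step() would go here in a real training loop
    sched.step()
    print(step, round(opt.param_groups[0]["lr"], 5))

The printed rates should ramp up to max_lr over the first 40 steps, come back down to base_lr over the next 40, and then keep decreasing toward base_lr / 100 for the remaining 20, which is the trailing phase the paper asks for.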