Virtual epochs for PyTorch

A common problem when training neural networks is the size of the data [1]. There are several strategies for storing and querying large amounts of data, or for increasing model throughput to speed up training, but scale also causes problems in much more mundane ways.

For example, a common logical structure for a training loop is to specify a number of epochs, then to pass a dataset -- chunk by chunk -- through your model for each epoch. It might look something like this:

for epoch in range(epochs):
    loader = DataLoader(dataset)
    for data, labels in loader:
        prediction = model(data)
        ...
        loss.backward()
        scheduler.step()

After an epoch has completed, you may want to calculate your validation error, or checkpoint your model. With very large data, however, it may take several days to get through a single pass of your dataset. Waiting days before the first validation error is available can dramatically slow down the iteration time for experimenting with different model architectures, and checkpointing that infrequently means a lot of computational work will be lost if something unfortunate happens. Ideally, if our model is going to fail, we want it to fail fast.

One solution to this problem is to use a "virtual epoch" -- an epoch with a fixed number of steps, where each step is of a given batch size, and the total number of training examples seen during an epoch is batch size * steps per epoch. This still lets you iterate over the entire dataset (given enough epochs), but decouples the rate at which your model gets checkpointed or validated from the size of the dataset you have.
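
With a batch size of 32 and 1,000 steps per epoch, for example, a virtual epoch covers 32 * 1,000 = 32,000 training examples, whether the full dataset holds a million rows or a billion.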

In Keras, this can be done by using the steps_per_epoch argument in your fit call.
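
As a rough sketch, assuming model is a compiled tf.keras model and train_ds is a repeating tf.data.Dataset (both names are placeholders):

# `model` is assumed to be a compiled tf.keras model and `train_ds` a
# repeating tf.data.Dataset -- both are placeholders for illustration
model.fit(train_ds, epochs=20, steps_per_epoch=1000)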

In PyTorch, it's a little more complicated. The idiomatic training loop in PyTorch implicitly ties the length of an epoch to the total length of the dataset. To get around this, we can wrap the standard PyTorch DataLoader in a class which:

  1. tracks the number of steps taken; and,
  2. automatically re-initializes the DataLoader's iterator upon exhaustion.

We can call this our "virtual" data loader, and it will look like this:

class VirtualDataLoader:
    """Wraps a DataLoader so that an epoch lasts a fixed number of steps."""

    def __init__(self, data_loader, steps_per_epoch: int = 1000):
        self.data_loader = data_loader
        self.iterator = iter(self.data_loader)
        self.steps_per_epoch = steps_per_epoch
        self.current_step = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.current_step < self.steps_per_epoch:
            self.current_step += 1
            try:
                return next(self.iterator)
            except StopIteration:
                # the underlying DataLoader is exhausted: restart it
                self.iterator = iter(self.data_loader)
                return next(self.iterator)
        else:
            # end of the virtual epoch: reset the step counter, but keep
            # the iterator where it is so the next epoch resumes here
            self.current_step = 0
            raise StopIteration

It accepts a PyTorch DataLoader instance and a number of steps per epoch as input. It initializes an iterator from the DataLoader, and a step counter starting at 0.

At every iteration in the training loop, it first checks its counter to see if the virtual epoch has ended. If it has, it resets the step counter and raises StopIteration to stop the inner loop, but it does not reset the iterator, so the next epoch will resume where the previous one ended.

Otherwise, it increments its own counter, and tries to return the next set of outputs from the DataLoader. If the DataLoader's iterator has already gone through all of the training examples, we restart the iterator at the beginning.

The __iter__ method that simply returns self is there to make this work properly in Python's for-loop structure, which is implicitly doing something similar to:

iterator = iter(loader)
while True:
    try:
        item = next(iterator)
    except StopIteration:
        break
    # the body of the for loop runs here with `item`

Let's see what this looks like in practice. We'll need to start by mocking out a simple dataset to work with:

from torch.utils.data import DataLoader, Dataset


class SimpleDataset(Dataset):

    def __len__(self):
        return 2

    def __getitem__(self, index):
        return index, index

This is a dataset with just two training samples in it, (0, 0) and (1, 1).

Now, we have two cases that we're going to care about:

  1. an epoch should finish at the given number of steps; and,
  2. one epoch should start where the last one ended.

To see case 1, we can use a batch size of one and a single step per epoch:

loader = DataLoader(SimpleDataset(), batch_size=1)
virtual_loader = VirtualDataLoader(loader, steps_per_epoch=1)
for x, y in virtual_loader:
    print(x, y)

and ensure that only sample (0, 0) gets returned:

tensor([0]) tensor([0])

To see case 2, we can set steps_per_epoch to a larger number, like 4:

loader = DataLoader(SimpleDataset(), batch_size=1)
virtual_loader = VirtualDataLoader(loader, steps_per_epoch=4)
for x, y in virtual_loader:
    print(x, y)

and we should see that we iterate over the dataset twice in a single epoch:

tensor([0]) tensor([0])
tensor([1]) tensor([1])
tensor([0]) tensor([0])
tensor([1]) tensor([1])
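
To make the resume-across-epochs behaviour of case 2 explicit, here is one more quick sketch: if we reuse the same virtual_loader for two consecutive one-step epochs, the second epoch should pick up where the first left off, because the underlying iterator is never reset between epochs.

loader = DataLoader(SimpleDataset(), batch_size=1)
virtual_loader = VirtualDataLoader(loader, steps_per_epoch=1)

for epoch in range(2):
    for x, y in virtual_loader:
        print(epoch, x, y)

# expected output:
# 0 tensor([0]) tensor([0])
# 1 tensor([1]) tensor([1])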

One key difference between the DataLoader and the VirtualDataLoader is that the virtual one should only be initialized once, before the epoch loop begins, rather than re-created for every epoch. The new loop structure should look something like this:

loader = DataLoader(dataset, batch_size=batch_size)
virtual_loader = VirtualDataLoader(loader, steps_per_epoch=steps_per_epoch)
for epoch in range(epochs):
    for data, labels in virtual_loader:
        prediction = model(data)
        ...
        loss.backward()
        scheduler.step()
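
To tie everything together, here is a small self-contained sketch that wires the toy SimpleDataset and VirtualDataLoader from above into a runnable loop; the linear model, SGD optimizer, loss, and checkpoint path are arbitrary placeholders rather than recommendations.

import torch
from torch import nn
from torch.utils.data import DataLoader

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

loader = DataLoader(SimpleDataset(), batch_size=1)
virtual_loader = VirtualDataLoader(loader, steps_per_epoch=4)

for epoch in range(3):
    for x, y in virtual_loader:
        prediction = model(x.float().unsqueeze(-1))
        loss = nn.functional.mse_loss(prediction, y.float().unsqueeze(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # per-virtual-epoch housekeeping: validation or checkpointing can now
    # run every `steps_per_epoch` batches rather than once per full pass
    torch.save(model.state_dict(), f"checkpoint_{epoch}.pt")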

[1] This post was inspired by a chapter in Machine Learning Design Patterns.