How to combine variable length sequences in PyTorch DataLoaders
If you're getting started with PyTorch for text, you've probably encountered an error that looks something like:
Sizes of tensors must match except in dimension 0.
The short explanation for this error is that sequences are often different lengths, but tensors are required to be rectangular. The fix for this is to make every sequence the same shape.
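You can reproduce the underlying problem without a DataLoader at all: stacking two tensors of different lengths fails, while padding the shorter one first works. A minimal sketch (using 0 as an arbitrary padding value):
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])

# torch.stack([a, b]) raises a RuntimeError, because 3 != 2
b_padded = torch.cat([b, torch.zeros(1, dtype=torch.long)])  # pad b out to length 3
batch = torch.stack([a, b_padded])                           # now this works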
To see what this looks like in more detail, let's imagine that we have a few sequences of text like the following that we would like to classify:
text = [
"So when he bleeds, I bleed, the sacred riddle",
"Aint no money I can make to make the cops get little",
"We've seen'em murder the indigenous, the passage middle",
"The constitution, a life for a bag of Skittles"
]
labels = [0, 1, 0, 1]
We can construct an in-memory dataset with these, since the data are small enough to fit in memory:
from niacin.text.compat.pytorch import MemoryTextDataset
dataset = MemoryTextDataset(text, labels)
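If you don't have niacin installed, a rough stand-in is easy to write by hand. The ListDataset below is a hypothetical sketch, not niacin's implementation: it whitespace-tokenizes the text and builds a plain vocabulary dict that reserves index 1 for "<pad>" (the torchtext-style vocabulary used later in this post exposes a stoi mapping instead).
import torch
from torch.utils.data import Dataset

class ListDataset(Dataset):
    """Hypothetical in-memory dataset: whitespace-tokenizes and numericalizes text."""
    def __init__(self, text, labels):
        tokens = [t.lower().split() for t in text]
        # reserve 0 for "<unk>" and 1 for "<pad>"
        self._vocab = {"<unk>": 0, "<pad>": 1}
        for sentence in tokens:
            for word in sentence:
                self._vocab.setdefault(word, len(self._vocab))
        self._data = [torch.tensor([self._vocab[w] for w in s], dtype=torch.long)
                      for s in tokens]
        self._labels = labels

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        return self._data[idx], self._labels[idx]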
Then construct a PyTorch DataLoader with that dataset, and iterate over the sequences.
from torch.utils.data import DataLoader
loader = DataLoader(dataset)
for data, labels in loader:
    print(data)
By printing out the inputs one at a time, we can see what these tensors look like:
tensor([[10, 37, 22, 15, 2, 5, 14, 2, 3, 33, 32]])
tensor([[ 8, 29, 27, 5, 16, 7, 35, 7, 3, 18, 21, 25]])
tensor([[12, 4, 36, 34, 4, 19, 28, 3, 23, 2, 3, 31, 26]])
tensor([[11, 17, 2, 6, 24, 20, 6, 13, 30, 9]])
So far this is working well, but we are unlikely to use a batch size of 1 in real life. If we increase this to a batch size of just 2:
loader = DataLoader(dataset, batch_size=2)
for data, labels in loader:
    print(data)
We will immediately encounter a RuntimeError letting us know that PyTorch had trouble putting the batch together.
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 11 and 12 in dimension 1 at /Users/distiller/project/conda/conda-bld/pytorch_1565272526878/work/aten/src/TH/generic/THTensor.cpp:689
This happens because all of our text entries are different lengths -- they range from 10 to 13 tokens -- so the attempt to make a rectangular data structure looks like this:
| 0  | 1  | ... | n-2 | n-1 | n  |
|----|----|-----|-----|-----|----|
| 10 | 37 | ... | 33  | 32  | ?? |
| 8  | 29 | ... | 18  | 21  | 25 |
The way we solve this problem for sequences in general is to "pad" the shorter sequences with nonsense values. In PyTorch specifically, the torchtext Vocabulary class has a special token ("<pad>") whose purpose is exactly this:
vocab = dataset._vocab
vocab.stoi['<pad>']
1
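(Depending on your torchtext version, the vocabulary may expose a stoi dictionary or support indexing directly; a defensive lookup might look like this sketch:)
pad_index = vocab.stoi['<pad>'] if hasattr(vocab, 'stoi') else vocab['<pad>']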
This means we can pad our shorter tensors with 1 until they are as long as our longest tensor.¹ In PyTorch, we do this by using the collate_fn argument of the DataLoader class, which lets us pass in an arbitrary function for combining samples into a batch.
collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
This function should accept a single input, which will be a list of length batch_size, where each element is whatever gets returned by the Dataset class. In our case, the in-memory dataset returns a tuple of data and label, so we'll get a list of tuples.
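One way to convince yourself of this is to pass a throwaway collate function that just prints what it receives before handing it back unchanged (inspect_batch here is only a debugging sketch, not part of any library):
def inspect_batch(batch):
    # batch is a list of (data, label) tuples, one per sample
    for data, label in batch:
        print(data.shape, label)
    return batch

loader = DataLoader(dataset, batch_size=2, collate_fn=inspect_batch)
_ = next(iter(loader))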
Inside our collate function, we create a new tensor filled with our padding token (1), which is as wide as our widest sequence. Then, we replace the padding token wherever we have real data, and return that as the data batch from the function:
import torch

def variable_width(batch):
    # batch is a list of (data, label) pairs with length batch_size
    text = [i[0] for i in batch]
    # our outputs will be (batch_size, sequence_length)
    batch_length = len(batch)
    max_sequence_length = max([len(t) for t in text])
    # initialize array of "<pad>" integers
    data = torch.full((batch_length, max_sequence_length), fill_value=1, dtype=torch.long)
    # fill with non-padded values
    for i, t in enumerate(text):
        data[i, 0:len(t)] = t
    # don't forget the labels!
    labels = torch.tensor([i[1] for i in batch])
    return data, labels
If we initialize a new DataLoader with this function:
loader = DataLoader(dataset, batch_size=2, collate_fn=variable_width)
for data, labels in loader:
    print(data)
we'll see that we correctly generate rectangular batches, where the shorter sequences have an appropriate number of 1s padded onto the end.
tensor([[10, 37, 22, 15, 2, 5, 14, 2, 3, 33, 32, 1],
[ 8, 29, 27, 5, 16, 7, 35, 7, 3, 18, 21, 25]])
tensor([[12, 4, 36, 34, 4, 19, 28, 3, 23, 2, 3, 31, 26],
[11, 17, 2, 6, 24, 20, 6, 13, 30, 9, 1, 1, 1]])
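As an aside, PyTorch ships a helper that does the padding loop for us, torch.nn.utils.rnn.pad_sequence. A sketch of an equivalent collate function, assuming the same pad index of 1:
from torch.nn.utils.rnn import pad_sequence

def variable_width_builtin(batch):
    text = [i[0] for i in batch]
    labels = torch.tensor([i[1] for i in batch])
    # stacks the 1-D tensors into (batch_size, max_sequence_length),
    # filling the shorter rows with our "<pad>" index
    data = pad_sequence(text, batch_first=True, padding_value=1)
    return data, labels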
One more thing -- some of the neural network classes in PyTorch expect to receive batches of shape (sequence_length, batch_size). The recurrent neural networks (RNNs), for example, follow this approach. You have two options for using these components:

- set the module's batch_first init argument to True (a brief sketch of this option follows below); or,
- transpose the batch before returning.
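For the first option, feeding these (batch_size, sequence_length) batches into a batch-first model might look something like the sketch below; the embedding and hidden sizes are arbitrary, and it assumes the vocabulary built above supports len():
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=16, padding_idx=1)
rnn = nn.RNN(input_size=16, hidden_size=32, batch_first=True)

for data, labels in loader:         # data is (batch_size, sequence_length)
    embedded = embedding(data)      # (batch_size, sequence_length, 16)
    output, hidden = rnn(embedded)  # output is (batch_size, sequence_length, 32)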
To see what that second option entails, we can revisit our variable width collation function, and swap the order of the dimensions corresponding to batch length and sequence length:
def variable_width_T(batch):
    text = [i[0] for i in batch]
    batch_length = len(batch)
    max_sequence_length = max([len(t) for t in text])
    # we're going to switch the order of sequence and batch size in the constructor
    data = torch.full((max_sequence_length, batch_length), fill_value=1, dtype=torch.long)
    # then fill from top-to-bottom, instead of left-to-right
    for i, t in enumerate(text):
        data[0:len(t), i] = t
    labels = torch.tensor([i[1] for i in batch])
    return data, labels
then recreate our DataLoader instance with the new, transposed collate function:
loader = DataLoader(dataset, batch_size=2, collate_fn=variable_width_T)
for data, labels in loader:
    print(data)
and see that the batched outputs now have the sequence as the first dimension:
tensor([[10, 8],
[37, 29],
[22, 27],
[15, 5],
[ 2, 16],
[ 5, 7],
[14, 35],
[ 2, 7],
[ 3, 3],
[33, 18],
[32, 21],
[ 1, 25]])
tensor([[12, 11],
[ 4, 17],
[36, 2],
[34, 6],
[ 4, 24],
[19, 20],
[28, 6],
[ 3, 13],
[23, 30],
[ 2, 9],
[ 3, 1],
[31, 1],
[26, 1]])
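Equivalently, instead of filling the tensor column-by-column, you could reuse the original variable_width function and transpose its output; a brief sketch:
def variable_width_T_alt(batch):
    data, labels = variable_width(batch)
    # .t() swaps the two dimensions, giving (sequence_length, batch_size)
    return data.t(), labels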
1. different instances of a vocabulary class might have a different index number (i.e. not 1) for the "<pad>" token