Adding data augmentation to torchtext datasets

It is widely accepted that artificially augmented datasets lead to models that are both more accurate and more generalizable. Augmentation achieves this in two ways: by introducing variability that is likely to be encountered in ecologically valid settings but is not present in the training data, and by providing negative examples of spurious associations in the data. Estimates of the accuracy improvement from augmentation range from 0 to ~10 percentage points, and the change in model error depends strongly on the model, task, and dataset [1].

Data augmentation involves applying an augmenting transform to training examples, where an augmenting transform is one that modifies the low-level information in an example but not its high-level information. Common examples from machine vision include rotating, cropping, or flipping an image, shifting its hue (e.g. making every pixel a little bit more blue), and decreasing its resolution. In each case the pixel values given to the model change substantially, but the semantic content of the image is unlikely to change unless the transform itself is extreme (for example, setting every pixel to black). For this reason, augmenting functions are typically parameterized by intensity, such as the angle of rotation or the fraction of the image that remains after cropping.
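To make "parameterized by intensity" concrete, here is a minimal pure-Python sketch of a random crop, assuming a grayscale image stored as a list of pixel rows. The function random_crop and its keep argument are hypothetical names chosen for illustration, not torchvision's API (torchvision's own cropping transforms operate on PIL images and tensors).

```python
import random
from typing import List, Optional

Image = List[List[int]]  # assumption: a grayscale image as rows of pixel values


def random_crop(img: Image, keep: float = 0.8, rng: Optional[random.Random] = None) -> Image:
    """Return a random crop keeping roughly `keep` of the height and width.

    `keep` is the intensity parameter: 1.0 leaves the image unchanged,
    and smaller values discard more of the image.
    """
    if not 0.0 < keep <= 1.0:
        raise ValueError("keep must be in (0, 1]")
    rng = rng or random.Random()
    height, width = len(img), len(img[0])
    new_h = max(1, round(height * keep))
    new_w = max(1, round(width * keep))
    # pick a random top-left corner so that the crop fits inside the image
    top = rng.randint(0, height - new_h)
    left = rng.randint(0, width - new_w)
    return [row[left:left + new_w] for row in img[top:top + new_h]]


image = [[10 * r + c for c in range(10)] for r in range(10)]
crop = random_crop(image, keep=0.6, rng=random.Random(0))
print(len(crop), len(crop[0]))  # 6 6
```

Lowering keep makes the transform more aggressive; keep=1.0 returns an unchanged copy of the image.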

In most cases, augmenting transforms are easiest to apply to the raw data, before any other preprocessing. For this reason, torchvision applies augmenting transforms to data as they are read in, e.g. in the __getitem__ method of its CIFAR10 dataset:

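from PIL import Image

def __getitem__(self, index):
    img, target = self.data[index], self.targets[index]

    # convert the stored array to a PIL image so the transforms see the raw image
    img = Image.fromarray(img)

    # augmenting transforms (e.g. random crops or flips) run on every read
    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target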
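The benefit of applying transforms on read is that a random transform produces a different variant of an example every time it is fetched, i.e. on every epoch. The sketch below shows that pattern in plain Python. TransformOnReadDataset and the jitter transform are hypothetical; the class only mirrors the transform/target_transform convention of the excerpt above and does not subclass anything from torchvision.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple


class TransformOnReadDataset:
    """Minimal map-style dataset that applies transforms when an item is read.

    It implements only __len__ and __getitem__, which is what
    torch.utils.data.DataLoader needs from a map-style dataset.
    """

    def __init__(self, data: Sequence, targets: Sequence,
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None):
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple:
        item, target = self.data[index], self.targets[index]
        # the transform runs on every access, so a random transform yields
        # a different augmented view of the same stored example each time
        if self.transform is not None:
            item = self.transform(item)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return item, target


rng = random.Random(0)


def jitter(values: List[float]) -> List[float]:
    # hypothetical augmenting transform: add small uniform noise to each value
    return [v + rng.uniform(-0.1, 0.1) for v in values]


ds = TransformOnReadDataset([[1.0, 2.0]], [0], transform=jitter)
first, second = ds[0][0], ds[0][0]
print(first != second)  # True: two reads of the same example differ
```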

Text data differ from images in that they need to be vectorized (converted from Unicode characters to a numeric format) before being given to a model. In the built-in torchtext datasets (e.g. all of the text classification datasets), this is done before the dataset instance is constructed:

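import logging

import torch
from torchtext.vocab import Vocab
from tqdm import tqdm


def _create_data_from_iterator(vocab, iterator, include_unk):
    data = []
    labels = []
    with tqdm(unit_scale=0, unit='lines') as t:
        for cls, tokens in iterator:
            # tokens become a tensor of vocabulary ids here, so the raw
            # text is not kept anywhere in the resulting dataset
            if include_unk:
                tokens = torch.tensor([vocab[token] for token in tokens])
            else:
                token_ids = list(filter(lambda x: x is not Vocab.UNK, [vocab[token]
                                        for token in tokens]))
                tokens = torch.tensor(token_ids)
            if len(tokens) == 0:
                logging.info('Row contains no tokens.')
            data.append((cls, tokens))
            labels.append(cls)
            t.update(1)
    return data, set(labels)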
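To make the vectorization step concrete, here is a minimal pure-Python sketch of what happens to each row, assuming a dict-based vocabulary with id 0 reserved for an "<unk>" token. build_vocab and vectorize are hypothetical helpers, not torchtext's Vocab API. Once the raw strings have been replaced by ids like these, the original text, including any out-of-vocabulary words, is no longer stored, so word-level augmentations have nothing to operate on.

```python
from typing import Dict, Iterable, List

UNK = "<unk>"


def build_vocab(corpus: Iterable[List[str]]) -> Dict[str, int]:
    """Map each token to an integer id; id 0 is reserved for unknown tokens."""
    vocab = {UNK: 0}
    for tokens in corpus:
        for token in tokens:
            vocab.setdefault(token, len(vocab))
    return vocab


def vectorize(tokens: List[str], vocab: Dict[str, int], include_unk: bool = True) -> List[int]:
    """Convert tokens to ids, optionally dropping out-of-vocabulary tokens."""
    ids = [vocab.get(token, vocab[UNK]) for token in tokens]
    if not include_unk:
        # compare ids with the unknown id to drop unknown tokens
        ids = [i for i in ids if i != vocab[UNK]]
    return ids


vocab = build_vocab([["the", "cat", "sat"], ["the", "dog"]])
print(vectorize(["the", "bird", "sat"], vocab))                     # [1, 0, 3]
print(vectorize(["the", "bird", "sat"], vocab, include_unk=False))  # [1, 3]
```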

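Because only the vectorized tensors are kept, the dataset's __getitem__ can do nothing more than return them, which precludes applying augmenting transforms when data are fetched: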

def __getitem__(self, i):
    # returns the stored (label, token-id tensor) pair; the raw text is gone
    return self._data[i]

To use augmenting transforms inside a torchtext-like dataset, the __getitem__ method needs access to the raw data, which means either storing the raw text in memory:

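import pandas as pd

# read the whole file into memory (pandas treats the first row as a header by default)
df = pd.read_table(datafile, sep=sep)
# the first column holds the label; the remaining columns are joined into one raw string
self._data = df.iloc[:, 1:].apply(lambda s: s.str.cat(sep=" "), axis=1).tolist()
self._labels = df.iloc[:, 0].tolist()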
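If pandas is not available, the same in-memory approach can be sketched with the standard library's csv module. load_raw_text is a hypothetical helper, and unlike the pd.read_table defaults it does not treat the first row as a header.

```python
import csv
import io
from typing import List, TextIO, Tuple


def load_raw_text(handle: TextIO, delimiter: str = ",") -> Tuple[List[str], List[str]]:
    """Read label,text,... rows and keep the raw text in memory.

    The first column is treated as the label; the remaining columns are
    joined with single spaces into one raw string per example.
    """
    data, labels = [], []
    for row in csv.reader(handle, delimiter=delimiter):
        if not row:
            continue  # skip blank lines
        labels.append(row[0])
        data.append(" ".join(row[1:]))
    return data, labels


sample = io.StringIO("1,great movie,would watch again\n2,dull plot,fell asleep\n")
texts, labels = load_raw_text(sample)
print(labels)  # ['1', '2']
print(texts)   # ['great movie would watch again', 'dull plot fell asleep']
```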

or storing the locations of the files from which the text can be read:

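import os
from pathlib import Path

# keep only file names in memory; sorting both listings pairs each text
# file with its label file (assumes one file per example, matching names)
self._data_dir = Path(data_dir)
self._data = sorted(os.listdir(data_dir))
self._labels_dir = Path(labels_dir)
self._labels = sorted(os.listdir(labels_dir))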
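A sketch of the file-backed variant follows, assuming one UTF-8 text file per example and one matching label file per example, so that the sorted listings line up. The class name LazyTextFiles is hypothetical; its __getitem__ returns the raw text and label read from disk, which is exactly the input an on-the-fly transform needs.

```python
import os
import tempfile
from pathlib import Path
from typing import Tuple, Union

PathLike = Union[str, os.PathLike]


class LazyTextFiles:
    """Keep only file names in memory and read each example on access."""

    def __init__(self, data_dir: PathLike, labels_dir: PathLike):
        self._data_dir = Path(data_dir)
        self._data = sorted(os.listdir(data_dir))
        self._labels_dir = Path(labels_dir)
        self._labels = sorted(os.listdir(labels_dir))
        if len(self._data) != len(self._labels):
            raise ValueError("expected one label file per data file")

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, index: int) -> Tuple[str, str]:
        # the raw text is read from disk only when this example is requested
        text = (self._data_dir / self._data[index]).read_text(encoding="utf-8")
        label = (self._labels_dir / self._labels[index]).read_text(encoding="utf-8").strip()
        return text, label


with tempfile.TemporaryDirectory() as root:
    data_dir, labels_dir = Path(root, "data"), Path(root, "labels")
    data_dir.mkdir()
    labels_dir.mkdir()
    (data_dir / "0001.txt").write_text("a short example", encoding="utf-8")
    (labels_dir / "0001.txt").write_text("positive\n", encoding="utf-8")
    ds = LazyTextFiles(data_dir, labels_dir)
    n_items, first_item = len(ds), ds[0]

print(n_items, first_item)  # 1 ('a short example', 'positive')
```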

Then, when __getitem__ is called, the class can apply the transforms on the fly:

def __getitem__(self, index):
    label = self._labels[index]
    data = self._data[index]
    data = self._transform(data)      # augment the raw text first
    tokens = self._tokenize(data)     # then split it into tokens
    vector = self._vectorize(tokens)  # and only now convert tokens to ids
    return vector, label
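Putting the pieces together, here is a self-contained sketch of such a dataset. The augmenting transform is a hypothetical random word dropout (not one of niacin's functions), tokenization is lowercased whitespace splitting, and vectorization uses a dict vocabulary with an "<unk>" entry; all of these are assumptions standing in for whatever _transform, _tokenize, and _vectorize do in a real implementation.

```python
import random
from typing import Callable, Dict, List, Optional, Sequence, Tuple


def make_word_dropout(p: float, seed: Optional[int] = None) -> Callable[[str], str]:
    """Hypothetical augmenting transform: delete each word with probability p."""
    rng = random.Random(seed)

    def transform(text: str) -> str:
        words = text.split()
        kept = [w for w in words if rng.random() >= p]
        # never return an empty string; fall back to the original text
        return " ".join(kept) if kept else text

    return transform


class AugmentedTextDataset:
    """Store raw text and run transform -> tokenize -> vectorize per item."""

    def __init__(self, data: Sequence[str], labels: Sequence[int],
                 vocab: Dict[str, int],
                 transform: Callable[[str], str] = lambda s: s):
        self._data = list(data)
        self._labels = list(labels)
        self._vocab = vocab
        self._transform = transform

    def __len__(self) -> int:
        return len(self._data)

    def _tokenize(self, text: str) -> List[str]:
        return text.lower().split()

    def _vectorize(self, tokens: List[str]) -> List[int]:
        unk = self._vocab["<unk>"]
        return [self._vocab.get(t, unk) for t in tokens]

    def __getitem__(self, index: int) -> Tuple[List[int], int]:
        label = self._labels[index]
        data = self._transform(self._data[index])  # augment the raw text first
        tokens = self._tokenize(data)
        vector = self._vectorize(tokens)
        return vector, label


vocab = {"<unk>": 0, "the": 1, "movie": 2, "was": 3, "great": 4}
ds = AugmentedTextDataset(["The movie was great"], [1], vocab,
                          transform=make_word_dropout(p=0.3, seed=0))
print(ds[0])  # ids of the words that survived dropout, paired with label 1
```

The vectors here are plain lists of varying length; before batching with torch.utils.data.DataLoader they would typically be converted with torch.tensor and padded in a custom collate_fn.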

To learn more about using text augmentation with PyTorch data loaders, see this brief video or take a look at the documentation for niacin, a Python library for text augmentation.


[1] Chen, S., Dobriban, E., & Lee, J. H. (2019). Invariance reduces variance: Understanding data augmentation in deep learning and beyond. http://arxiv.org/abs/1907.10905