Fast MNIST

The PyTorch MNIST dataset is SLOW by default, because it conforms to the usual torchvision interface of returning a PIL image for every sample. This is unnecessary if you just want normalized tensors and are not interested in image transforms (such as rotation or cropping). By folding the normalization into the dataset initialization you take the load off the CPU and speed up training by 2-3x.

When training on MNIST with a GPU and a smallish model, the bottleneck is the CPU. Even with six dataloader workers on a six-core i7, GPU utilization is only ~5-10%. Using FastMNIST raises GPU utilization to ~20-25% and drops CPU utilization to near zero. On my particular model, steps per second at batch size 64 went from ~150 to ~500.

Instead of the default MNIST dataset, use this:

import torch
from torchvision.datasets import MNIST

device = torch.device('cuda')

class FastMNIST(MNIST):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Scale data to [0,1]
        self.data = self.data.unsqueeze(1).float().div(255)

        # Normalize it with the usual MNIST mean and std
        self.data = self.data.sub_(0.1307).div_(0.3081)

        # Put both data and targets on GPU in advance
        self.data, self.targets = self.data.to(device), self.targets.to(device)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        return img, target

Then create the dataloaders like this:

from torch.utils.data import DataLoader

train_dataset = FastMNIST('data/MNIST', train=True, download=True)
test_dataset = FastMNIST('data/MNIST', train=False, download=True)

# num_workers=0 is very important! The data is already preprocessed and sitting on the GPU,
# so worker processes have nothing to do (and forked workers cannot use CUDA tensors anyway).
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=10000, shuffle=False, num_workers=0)

The result is a 2-3x speedup (~500 it/s on a 1080 Ti with a smallish MLP) and near-zero CPU usage, compared to full CPU usage with the default dataset.
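
Since every batch already lives on the GPU, the training loop only needs the model moved to the device; there is no per-batch .to(device) call. A minimal sketch, where the model and hyperparameters are illustrative placeholders rather than part of the original setup:

import torch.nn as nn
import torch.nn.functional as F

# Placeholder model; any small network works the same way
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(5):
    for images, labels in train_dataloader:
        # images and labels are already CUDA tensors, so no .to(device) here
        optimizer.zero_grad()
        loss = F.cross_entropy(model(images), labels)
        loss.backward()
        optimizer.step()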