From eb8dc84eae8492b7e7276af84b045be6fae1e8f0 Mon Sep 17 00:00:00 2001
From: Tian Qi Chen
Date: Thu, 5 Jan 2017 01:43:42 -0800
Subject: [PATCH] add mnist

---
 torchvision/datasets/__init__.py |   4 +-
 torchvision/datasets/mnist.py    | 154 +++++++++++++++++++++++++++++++
 2 files changed, 157 insertions(+), 1 deletion(-)
 create mode 100644 torchvision/datasets/mnist.py

diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index 2eac78c79c0..e9c4b0e7184 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -2,8 +2,10 @@
 from .folder import ImageFolder
 from .coco import CocoCaptions, CocoDetection
 from .cifar import CIFAR10, CIFAR100
+from .mnist import MNIST
 
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder',
            'CocoCaptions', 'CocoDetection',
-           'CIFAR10', 'CIFAR100')
+           'CIFAR10', 'CIFAR100',
+           'MNIST')
diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
new file mode 100644
index 00000000000..04506691924
--- /dev/null
+++ b/torchvision/datasets/mnist.py
@@ -0,0 +1,154 @@
+from __future__ import print_function
+import torch.utils.data as data
+from PIL import Image
+import os
+import os.path
+import errno
+import torch
+import json
+import codecs
+import numpy as np
+
+class MNIST(data.Dataset):
+    urls = [
+        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
+        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
+        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
+        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
+    ]
+    raw_folder = 'raw'
+    processed_folder = 'processed'
+    training_file = 'training.pt'
+    test_file = 'test.pt'
+
+    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
+        self.root = root
+        self.transform = transform
+        self.target_transform = target_transform
+        self.train = train  # training set or test set
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError('Dataset not found.' +
+                               ' You can use download=True to download it')
+
+        if self.train:
+            self.train_data, self.train_labels = torch.load(os.path.join(root, self.processed_folder, self.training_file))
+        else:
+            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
+
+    def __getitem__(self, index):
+        if self.train:
+            img, target = self.train_data[index], self.train_labels[index]
+        else:
+            img, target = self.test_data[index], self.test_labels[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img.numpy(), mode='L')
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        if self.train:
+            return 60000
+        else:
+            return 10000
+
+    def _check_exists(self):
+        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
+            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
+
+    def download(self):
+        from six.moves import urllib
+        import gzip
+
+        if self._check_exists():
+            print('Files already downloaded')
+            return
+
+        # download files
+        try:
+            os.makedirs(os.path.join(self.root, self.raw_folder))
+            os.makedirs(os.path.join(self.root, self.processed_folder))
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        for url in self.urls:
+            print('Downloading ' + url)
+            data = urllib.request.urlopen(url)
+            filename = url.rpartition('/')[2]
+            file_path = os.path.join(self.root, self.raw_folder, filename)
+            with open(file_path, 'wb') as f:
+                f.write(data.read())
+            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
+                    gzip.GzipFile(file_path) as zip_f:
+                out_f.write(zip_f.read())
+            os.unlink(file_path)
+
+        # process and save as torch files
+        print('Processing')
+
+        training_set = (
+            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
+            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
+        )
+        test_set = (
+            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
+            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
+        )
+        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
+            torch.save(training_set, f)
+        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
+            torch.save(test_set, f)
+
+        print('Done!')
+
+def get_int(b):
+    return int(codecs.encode(b, 'hex'), 16)
+
+def parse_byte(b):
+    if isinstance(b, str):
+        return ord(b)
+    return b
+
+def read_label_file(path):
+    with open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2049
+        length = get_int(data[4:8])
+        labels = [parse_byte(b) for b in data[8:]]
+        assert len(labels) == length
+        return torch.LongTensor(labels)
+
+def read_image_file(path):
+    with open(path, 'rb') as f:
+        data = f.read()
+        assert get_int(data[:4]) == 2051
+        length = get_int(data[4:8])
+        num_rows = get_int(data[8:12])
+        num_cols = get_int(data[12:16])
+        images = []
+        idx = 16
+        for l in range(length):
+            img = []
+            images.append(img)
+            for r in range(num_rows):
+                row = []
+                img.append(row)
+                for c in range(num_cols):
+                    row.append(parse_byte(data[idx]))
+                    idx += 1
+        assert len(images) == length
+        return torch.ByteTensor(images).view(-1, 28, 28)
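
Usage sketch (not part of the patch): a minimal example of the MNIST class added above, assuming a torchvision build that includes this change; the root path './data' and the ToTensor() transform are only illustrative choices.

    import torch
    from torchvision import datasets, transforms

    # download=True fetches the four IDX archives, unpacks them into <root>/raw,
    # and caches the parsed tensors as <root>/processed/training.pt and test.pt
    train_set = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transforms.ToTensor())

    img, target = train_set[0]                 # __getitem__ returns (transformed image, label)
    print(len(train_set), img.size(), target)  # 60000, torch.Size([1, 28, 28]), a digit in 0..9

    loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    images, labels = next(iter(loader))
    print(images.size(), labels.size())        # torch.Size([64, 1, 28, 28]), torch.Size([64])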
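
For reference (also not part of the patch), get_int interprets a 4-byte slice of the IDX header as a big-endian integer, which is what the magic-number asserts in read_label_file and read_image_file rely on; a quick standalone check:

    import codecs

    def get_int(b):
        # same hex-encode trick as in mnist.py: 4 bytes -> big-endian integer
        return int(codecs.encode(b, 'hex'), 16)

    assert get_int(b'\x00\x00\x08\x01') == 2049  # magic number of the MNIST label files
    assert get_int(b'\x00\x00\x08\x03') == 2051  # magic number of the MNIST image files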