Skip to content

Commit

Permalink
Merge pull request #3 from pytorch/cifar
Browse files Browse the repository at this point in the history
cifar 10 and 100
  • Loading branch information
soumith authored Nov 10, 2016
2 parents e37323d + 754d526 commit 63dabca
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 1 deletion.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
build/
dist/
torchvision.egg-info/
*/**/__pycache__
*/**/*.pyc
*/**/*~
*~
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The following dataset loaders are available:
- [LSUN Classification](#lsun)
- [ImageFolder](#imagefolder)
- [Imagenet-12](#imagenet-12)
- [CIFAR10 and CIFAR100](#cifar)

Datasets have the API:
- `__getitem__`
Expand Down Expand Up @@ -97,6 +98,15 @@ u'A mountain view with a plume of smoke in the background']
- ['bedroom_train', 'church_train', ...] : a list of categories to load


### CIFAR

`dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)`
`dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)`

- `root` : root directory of dataset where there is folder `cifar-10-batches-py`
- `train` : `True` = Training set, `False` = Test set
- `download` : `True` = downloads the dataset from the internet and puts it in root directory. If dataset already downloaded, does not do anything.

### ImageFolder

A generic data loader where the images are arranged in this way:
Expand Down
12 changes: 12 additions & 0 deletions test/cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
import torchvision.datasets as dset

print('\n\nCifar 10')
a = dset.CIFAR10(root="abc/def/ghi", download=True)

print(a[3])

print('\n\nCifar 100')
a = dset.CIFAR100(root="abc/def/ghi", download=True)

print(a[3])
4 changes: 3 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder',
'CocoCaptions', 'CocoDetection')
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100')
159 changes: 159 additions & 0 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle

class CIFAR10(data.Dataset):
base_folder = 'cifar-10-batches-py'
url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_mdf = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]

test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]

def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.'
+ ' You can use download=True to download it')

# now load the picked numpy arrays
self.train_data = []
self.train_labels = []
for fentry in self.train_list:
f = fentry[0]
file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb')
entry = pickle.load(fo)
self.train_data.append(entry['data'])
if 'labels' in entry:
self.train_labels += entry['labels']
else:
self.train_labels += entry['fine_labels']
fo.close()

self.train_data = np.concatenate(self.train_data)

f = self.test_list[0][0]
file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb')
entry = pickle.load(fo)
self.test_data = entry['data']
if 'labels' in entry:
self.test_labels = entry['labels']
else:
self.test_labels = entry['fine_labels']
fo.close()

self.train_data = self.train_data.reshape((50000, 3, 32, 32))
self.test_data = self.test_data.reshape((10000, 3, 32, 32))

def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
if self.train:
return 50000
else:
return 10000

def _check_integrity(self):
import hashlib
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not os.path.isfile(fpath):
return False
md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
if md5c != md5:
return False
return True

def download(self):
from six.moves import urllib
import tarfile
import hashlib

root = self.root
fpath = os.path.join(root, self.filename)

try:
os.makedirs(root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise

if self._check_integrity():
print('Files already downloaded and verified')
return

# downloads file
if os.path.isfile(fpath) and \
hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.tgz_md5:
print('Using downloaded file: ' + fpath)
else:
print('Downloading ' + self.url + ' to ' + fpath)
urllib.request.urlretrieve(self.url, fpath)

# extract file
cwd = os.getcwd()
print('Extracting tar file')
tar = tarfile.open(fpath, "r:gz")
os.chdir(root)
tar.extractall()
tar.close()
os.chdir(cwd)
print('Done!')


class CIFAR100(CIFAR10):
base_folder = 'cifar-100-python'
url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
]

test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]

0 comments on commit 63dabca

Please sign in to comment.