Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cifar 10 and 100 #3

Merged
merged 1 commit into from
Nov 10, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
build/
dist/
torchvision.egg-info/
*/**/__pycache__
*/**/*.pyc
*/**/*~
*~
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The following dataset loaders are available:
- [LSUN Classification](#lsun)
- [ImageFolder](#imagefolder)
- [Imagenet-12](#imagenet-12)
- [CIFAR10 and CIFAR100](#cifar)

Datasets have the API:
- `__getitem__`
Expand Down Expand Up @@ -97,6 +98,15 @@ u'A mountain view with a plume of smoke in the background']
- ['bedroom_train', 'church_train', ...] : a list of categories to load


### CIFAR

`dset.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)`
`dset.CIFAR100(root, train=True, transform=None, target_transform=None, download=False)`

- `root` : root directory of dataset where there is folder `cifar-10-batches-py`
- `train` : `True` = Training set, `False` = Test set
- `download` : `True` = downloads the dataset from the internet and puts it in root directory. If dataset already downloaded, does not do anything.

### ImageFolder

A generic data loader where the images are arranged in this way:
Expand Down
12 changes: 12 additions & 0 deletions test/cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
import torchvision.datasets as dset

print('\n\nCifar 10')
a = dset.CIFAR10(root="abc/def/ghi", download=True)

print(a[3])

print('\n\nCifar 100')
a = dset.CIFAR100(root="abc/def/ghi", download=True)

print(a[3])
4 changes: 3 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder',
'CocoCaptions', 'CocoDetection')
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100')
159 changes: 159 additions & 0 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle

class CIFAR10(data.Dataset):
base_folder = 'cifar-10-batches-py'
url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filename = "cifar-10-python.tar.gz"
tgz_mdf = 'c58f30108f718f92721af3b95e74349a'
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]

test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]

def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = root
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set

if download:
self.download()

if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.'
+ ' You can use download=True to download it')

# now load the picked numpy arrays
self.train_data = []
self.train_labels = []
for fentry in self.train_list:
f = fentry[0]
file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb')
entry = pickle.load(fo)
self.train_data.append(entry['data'])
if 'labels' in entry:
self.train_labels += entry['labels']
else:
self.train_labels += entry['fine_labels']
fo.close()

self.train_data = np.concatenate(self.train_data)

f = self.test_list[0][0]
file = os.path.join(root, self.base_folder, f)
fo = open(file, 'rb')
entry = pickle.load(fo)
self.test_data = entry['data']
if 'labels' in entry:
self.test_labels = entry['labels']
else:
self.test_labels = entry['fine_labels']
fo.close()

self.train_data = self.train_data.reshape((50000, 3, 32, 32))
self.test_data = self.test_data.reshape((10000, 3, 32, 32))

def __getitem__(self, index):
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self):
if self.train:
return 50000
else:
return 10000

def _check_integrity(self):
import hashlib
root = self.root
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = os.path.join(root, self.base_folder, filename)
if not os.path.isfile(fpath):
return False
md5c = hashlib.md5(open(fpath, 'rb').read()).hexdigest()
if md5c != md5:
return False
return True

def download(self):
from six.moves import urllib
import tarfile
import hashlib

root = self.root
fpath = os.path.join(root, self.filename)

try:
os.makedirs(root)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise

if self._check_integrity():
print('Files already downloaded and verified')
return

# downloads file
if os.path.isfile(fpath) and \
hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.tgz_md5:
print('Using downloaded file: ' + fpath)
else:
print('Downloading ' + self.url + ' to ' + fpath)
urllib.request.urlretrieve(self.url, fpath)

# extract file
cwd = os.getcwd()
print('Extracting tar file')
tar = tarfile.open(fpath, "r:gz")
os.chdir(root)
tar.extractall()
tar.close()
os.chdir(cwd)
print('Done!')


class CIFAR100(CIFAR10):
base_folder = 'cifar-100-python'
url = "http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filename = "cifar-100-python.tar.gz"
tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
train_list = [
['train', '16019d7e3df5f24257cddd939b257f8d'],
]

test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]