Skip to content

Add metadata to some datasets #501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions torchvision/datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ class CIFAR10(data.Dataset):
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
meta = {
'filename': 'batches.meta',
'key': 'label_names',
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}

@property
def targets(self):
if self.train:
return self.train_labels
else:
return self.test_labels

def __init__(self, root, train=True,
transform=None, target_transform=None,
Expand Down Expand Up @@ -100,6 +112,21 @@ def __init__(self, root, train=True,
self.test_data = self.test_data.reshape((10000, 3, 32, 32))
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC

self._load_meta()

def _load_meta(self):
path = os.path.join(self.root, self.base_folder, self.meta['filename'])
if not check_integrity(path, self.meta['md5']):
raise RuntimeError('Dataset metadata file not found or corrupted.' +
' You can use download=True to download it')
with open(path, 'rb') as infile:
if sys.version_info[0] == 2:
data = pickle.load(infile)
else:
data = pickle.load(infile, encoding='latin1')
self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}

def __getitem__(self, index):
"""
Args:
Expand Down Expand Up @@ -187,3 +214,8 @@ class CIFAR100(CIFAR10):
test_list = [
['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
]
meta = {
'filename': 'meta',
'key': 'fine_label_names',
'md5': '7973b15100ade9c7d40fb424638fde48',
}
2 changes: 2 additions & 0 deletions torchvision/datasets/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class DatasetFolder(data.Dataset):
classes (list): List of the class names.
class_to_idx (dict): Dict with items (class_name, class_index).
samples (list): List of (sample path, class_index) tuples
targets (list): The class_index value for each image in the dataset
"""

def __init__(self, root, loader, extensions, transform=None, target_transform=None):
Expand All @@ -85,6 +86,7 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]

self.transform = transform
self.target_transform = target_transform
Expand Down
13 changes: 13 additions & 0 deletions torchvision/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ class MNIST(data.Dataset):
processed_folder = 'processed'
training_file = 'training.pt'
test_file = 'test.pt'
classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
'5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
class_to_idx = {_class: i for i, _class in enumerate(classes)}

@property
def targets(self):
if self.train:
return self.train_labels
else:
return self.test_labels

def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root)
Expand Down Expand Up @@ -174,6 +184,9 @@ class FashionMNIST(MNIST):
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
]
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
class_to_idx = {_class: i for i, _class in enumerate(classes)}


class EMNIST(MNIST):
Expand Down