diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py
index ac7a7269af3..d7fdfbc18b5 100644
--- a/torchvision/datasets/cifar.py
+++ b/torchvision/datasets/cifar.py
@@ -45,6 +45,18 @@ class CIFAR10(data.Dataset):
     test_list = [
         ['test_batch', '40351d587109b95175f43aff81a1287e'],
     ]
+    meta = {
+        'filename': 'batches.meta',
+        'key': 'label_names',
+        'md5': '5ff9c542aee3614f3951f8cda6e48888',
+    }
+
+    @property
+    def targets(self):
+        if self.train:
+            return self.train_labels
+        else:
+            return self.test_labels
 
     def __init__(self, root, train=True,
                  transform=None, target_transform=None,
@@ -100,6 +112,21 @@ def __init__(self, root, train=True,
             self.test_data = self.test_data.reshape((10000, 3, 32, 32))
             self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
 
+        self._load_meta()
+
+    def _load_meta(self):
+        path = os.path.join(self.root, self.base_folder, self.meta['filename'])
+        if not check_integrity(path, self.meta['md5']):
+            raise RuntimeError('Dataset metadata file not found or corrupted.' +
+                               ' You can use download=True to download it')
+        with open(path, 'rb') as infile:
+            if sys.version_info[0] == 2:
+                data = pickle.load(infile)
+            else:
+                data = pickle.load(infile, encoding='latin1')
+            self.classes = data[self.meta['key']]
+        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
+
     def __getitem__(self, index):
         """
         Args:
@@ -187,3 +214,8 @@ class CIFAR100(CIFAR10):
     test_list = [
         ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
     ]
+    meta = {
+        'filename': 'meta',
+        'key': 'fine_label_names',
+        'md5': '7973b15100ade9c7d40fb424638fde48',
+    }
diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py
index 1df4bcbf44d..967cac1681a 100644
--- a/torchvision/datasets/folder.py
+++ b/torchvision/datasets/folder.py
@@ -69,6 +69,7 @@ class DatasetFolder(data.Dataset):
         classes (list): List of the class names.
         class_to_idx (dict): Dict with items (class_name, class_index).
         samples (list): List of (sample path, class_index) tuples
+        targets (list): The class_index value for each image in the dataset
     """
 
     def __init__(self, root, loader, extensions, transform=None, target_transform=None):
@@ -85,6 +86,7 @@ def __init__(self, root, loader, extensions, transform=None, target_transform=No
         self.classes = classes
         self.class_to_idx = class_to_idx
         self.samples = samples
+        self.targets = [s[1] for s in samples]
 
         self.transform = transform
         self.target_transform = target_transform
diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index 7f4463eff64..7fa6dcf7666 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -35,6 +35,16 @@ class MNIST(data.Dataset):
     processed_folder = 'processed'
     training_file = 'training.pt'
     test_file = 'test.pt'
+    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
+               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
+    class_to_idx = {_class: i for i, _class in enumerate(classes)}
+
+    @property
+    def targets(self):
+        if self.train:
+            return self.train_labels
+        else:
+            return self.test_labels
 
     def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
         self.root = os.path.expanduser(root)
@@ -174,6 +184,9 @@ class FashionMNIST(MNIST):
         'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
         'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
     ]
+    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
+               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
+    class_to_idx = {_class: i for i, _class in enumerate(classes)}
 
 
 class EMNIST(MNIST):