diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py
index d7fdfbc18b5..18b4ee02f7c 100644
--- a/torchvision/datasets/cifar.py
+++ b/torchvision/datasets/cifar.py
@@ -51,13 +51,6 @@ class CIFAR10(data.Dataset):
         'md5': '5ff9c542aee3614f3951f8cda6e48888',
     }
 
-    @property
-    def targets(self):
-        if self.train:
-            return self.train_labels
-        else:
-            return self.test_labels
-
     def __init__(self, root, train=True,
                  transform=None, target_transform=None,
                  download=False):
@@ -73,44 +66,30 @@ def __init__(self, root, train=True,
             raise RuntimeError('Dataset not found or corrupted.' +
                                ' You can use download=True to download it')
 
-        # now load the picked numpy arrays
         if self.train:
-            self.train_data = []
-            self.train_labels = []
-            for fentry in self.train_list:
-                f = fentry[0]
-                file = os.path.join(self.root, self.base_folder, f)
-                fo = open(file, 'rb')
+            downloaded_list = self.train_list
+        else:
+            downloaded_list = self.test_list
+
+        self.data = []
+        self.targets = []
+
+        # now load the pickled numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, 'rb') as f:
                 if sys.version_info[0] == 2:
-                    entry = pickle.load(fo)
+                    entry = pickle.load(f)
                 else:
-                    entry = pickle.load(fo, encoding='latin1')
-                self.train_data.append(entry['data'])
+                    entry = pickle.load(f, encoding='latin1')
+                self.data.append(entry['data'])
                 if 'labels' in entry:
-                    self.train_labels += entry['labels']
+                    self.targets.extend(entry['labels'])
                 else:
-                    self.train_labels += entry['fine_labels']
-                fo.close()
+                    self.targets.extend(entry['fine_labels'])
 
-            self.train_data = np.concatenate(self.train_data)
-            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
-            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
-        else:
-            f = self.test_list[0][0]
-            file = os.path.join(self.root, self.base_folder, f)
-            fo = open(file, 'rb')
-            if sys.version_info[0] == 2:
-                entry = pickle.load(fo)
-            else:
-                entry = pickle.load(fo, encoding='latin1')
-            self.test_data = entry['data']
-            if 'labels' in entry:
-                self.test_labels = entry['labels']
-            else:
-                self.test_labels = entry['fine_labels']
-            fo.close()
-            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
-            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
 
         self._load_meta()
 
@@ -135,10 +114,7 @@ def __getitem__(self, index):
         Returns:
             tuple: (image, target) where target is index of the target class.
         """
-        if self.train:
-            img, target = self.train_data[index], self.train_labels[index]
-        else:
-            img, target = self.test_data[index], self.test_labels[index]
+        img, target = self.data[index], self.targets[index]
 
         # doing this so that it is consistent with all other datasets
         # to return a PIL Image
@@ -153,10 +129,7 @@ def __getitem__(self, index):
         return img, target
 
     def __len__(self):
-        if self.train:
-            return len(self.train_data)
-        else:
-            return len(self.test_data)
+        return len(self.data)
 
     def _check_integrity(self):
         root = self.root
@@ -174,16 +147,11 @@ def download(self):
             print('Files already downloaded and verified')
             return
 
-        root = self.root
-        download_url(self.url, root, self.filename, self.tgz_md5)
+        download_url(self.url, self.root, self.filename, self.tgz_md5)
 
         # extract file
-        cwd = os.getcwd()
-        tar = tarfile.open(os.path.join(root, self.filename), "r:gz")
-        os.chdir(root)
-        tar.extractall()
-        tar.close()
-        os.chdir(cwd)
+        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
+            tar.extractall(path=self.root)
 
     def __repr__(self):
         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'