diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index b467a813e49..d7f88882028 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -42,35 +42,35 @@ def __init__(self, root, train=True, transform=None, target_transform=None, down + ' You can use download=True to download it') # now load the picked numpy arrays - self.train_data = [] - self.train_labels = [] - for fentry in self.train_list: - f = fentry[0] + if self.train: + self.train_data = [] + self.train_labels = [] + for fentry in self.train_list: + f = fentry[0] + file = os.path.join(root, self.base_folder, f) + fo = open(file, 'rb') + entry = pickle.load(fo) + self.train_data.append(entry['data']) + if 'labels' in entry: + self.train_labels += entry['labels'] + else: + self.train_labels += entry['fine_labels'] + fo.close() + + self.train_data = np.concatenate(self.train_data) + self.train_data = self.train_data.reshape((50000, 3, 32, 32)) + else: + f = self.test_list[0][0] file = os.path.join(root, self.base_folder, f) fo = open(file, 'rb') entry = pickle.load(fo) - self.train_data.append(entry['data']) + self.test_data = entry['data'] if 'labels' in entry: - self.train_labels += entry['labels'] + self.test_labels = entry['labels'] else: - self.train_labels += entry['fine_labels'] + self.test_labels = entry['fine_labels'] fo.close() - - self.train_data = np.concatenate(self.train_data) - - f = self.test_list[0][0] - file = os.path.join(root, self.base_folder, f) - fo = open(file, 'rb') - entry = pickle.load(fo) - self.test_data = entry['data'] - if 'labels' in entry: - self.test_labels = entry['labels'] - else: - self.test_labels = entry['fine_labels'] - fo.close() - - self.train_data = self.train_data.reshape((50000, 3, 32, 32)) - self.test_data = self.test_data.reshape((10000, 3, 32, 32)) + self.test_data = self.test_data.reshape((10000, 3, 32, 32)) def __getitem__(self, index): if self.train: