@@ -46,6 +46,10 @@ class CIFAR10(data.Dataset):
4646 ['test_batch' , '40351d587109b95175f43aff81a1287e' ],
4747 ]
4848
49+ meta_list = [
50+ ['batches.meta' , '5ff9c542aee3614f3951f8cda6e48888' ],
51+ ]
52+
4953 def __init__ (self , root , train = True ,
5054 transform = None , target_transform = None ,
5155 download = False ):
@@ -100,6 +104,16 @@ def __init__(self, root, train=True,
100104 self .test_data = self .test_data .reshape ((10000 , 3 , 32 , 32 ))
101105 self .test_data = self .test_data .transpose ((0 , 2 , 3 , 1 )) # convert to HWC
102106
107+ f = self .meta_list [0 ][0 ]
108+ file = os .path .join (self .root , self .base_folder , f )
109+ fo = open (file , 'rb' )
110+ if sys .version_info [0 ] == 2 :
111+ entry = pickle .load (fo )
112+ else :
113+ entry = pickle .load (fo , encoding = 'latin1' )
114+ fo .close ()
115+ self .meta = entry
116+
103117 def __getitem__ (self , index ):
104118 """
105119 Args:
@@ -133,7 +147,7 @@ def __len__(self):
133147
134148 def _check_integrity (self ):
135149 root = self .root
136- for fentry in (self .train_list + self .test_list ):
150+ for fentry in (self .train_list + self .test_list + self . meta_list ):
137151 filename , md5 = fentry [0 ], fentry [1 ]
138152 fpath = os .path .join (root , self .base_folder , filename )
139153 if not check_integrity (fpath , md5 ):
@@ -187,3 +201,7 @@ class CIFAR100(CIFAR10):
187201 test_list = [
188202 ['test' , 'f0ef6b0ae62326f3e7ffdfab6717acfc' ],
189203 ]
204+
205+ meta_list = [
206+ ['meta' , '7973b15100ade9c7d40fb424638fde48' ],
207+ ]
0 commit comments