99import pathlib
1010import pickle
1111import random
12+ import unittest .mock
1213import xml .etree .ElementTree as ET
1314from collections import defaultdict , Counter
1415
1516import numpy as np
1617import PIL .Image
1718import pytest
1819import torch
19- from datasets_utils import make_zip , make_tar , create_image_folder , create_image_file
20+ from datasets_utils import make_zip , make_tar , create_image_folder , create_image_file , combinations_grid
2021from torch .nn .functional import one_hot
2122from torch .testing import make_tensor as _make_tensor
22- from torchvision .prototype . datasets . _api import find
23+ from torchvision .prototype import datasets
2324from torchvision .prototype .utils ._internal import sequence_to_str
2425
2526make_tensor = functools .partial (_make_tensor , device = "cpu" )
3031
3132
3233class DatasetMock :
33- def __init__ (self , name , mock_data_fn ):
34- self .dataset = find (name )
35- self .info = self .dataset .info
36- self .name = self .info .name
37-
34+ def __init__ (self , name , * , mock_data_fn , configs ):
35+ # FIXME: error handling for unknown names
36+ self .name = name
3837 self .mock_data_fn = mock_data_fn
39- self .configs = self . info . _configs
38+ self .configs = configs
4039
4140 def _parse_mock_info (self , mock_info ):
4241 if mock_info is None :
@@ -65,10 +64,13 @@ def prepare(self, home, config):
6564 root = home / self .name
6665 root .mkdir (exist_ok = True )
6766
68- mock_info = self ._parse_mock_info (self .mock_data_fn (self . info , root , config ))
67+ mock_info = self ._parse_mock_info (self .mock_data_fn (root , config ))
6968
69+ with unittest .mock .patch .object (datasets .utils .Dataset2 , "__init__" ):
70+ required_file_names = {
71+ resource .file_name for resource in datasets .load (self .name , root = root , ** config )._resources ()
72+ }
7073 available_file_names = {path .name for path in root .glob ("*" )}
71- required_file_names = {resource .file_name for resource in self .dataset .resources (config )}
7274 missing_file_names = required_file_names - available_file_names
7375 if missing_file_names :
7476 raise pytest .UsageError (
@@ -123,10 +125,16 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
123125DATASET_MOCKS = {}
124126
125127
126- def register_mock (fn ):
127- name = fn .__name__ .replace ("_" , "-" )
128- DATASET_MOCKS [name ] = DatasetMock (name , fn )
129- return fn
128+ def register_mock (name = None , * , configs ):
129+ def wrapper (mock_data_fn ):
130+ nonlocal name
131+ if name is None :
132+ name = mock_data_fn .__name__
133+ DATASET_MOCKS [name ] = DatasetMock (name , mock_data_fn = mock_data_fn , configs = configs )
134+
135+ return mock_data_fn
136+
137+ return wrapper
130138
131139
132140class MNISTMockData :
@@ -204,7 +212,7 @@ def generate(
204212 return num_samples
205213
206214
207- @register_mock
215+ # @register_mock
208216def mnist (info , root , config ):
209217 train = config .split == "train"
210218 images_file = f"{ 'train' if train else 't10k' } -images-idx3-ubyte.gz"
@@ -217,10 +225,10 @@ def mnist(info, root, config):
217225 )
218226
219227
220- DATASET_MOCKS .update ({name : DatasetMock (name , mnist ) for name in ["fashionmnist" , "kmnist" ]})
228+ # DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
221229
222230
223- @register_mock
231+ # @register_mock
224232def emnist (info , root , config ):
225233 # The image sets that merge some lower case letters in their respective upper case variant, still use dense
226234 # labels in the data files. Thus, num_categories != len(categories) there.
@@ -247,7 +255,7 @@ def emnist(info, root, config):
247255 return num_samples_map [config ]
248256
249257
250- @register_mock
258+ # @register_mock
251259def qmnist (info , root , config ):
252260 num_categories = len (info .categories )
253261 if config .split == "train" :
@@ -324,7 +332,7 @@ def generate(
324332 make_tar (root , name , folder , compression = "gz" )
325333
326334
327- @register_mock
335+ # @register_mock
328336def cifar10 (info , root , config ):
329337 train_files = [f"data_batch_{ idx } " for idx in range (1 , 6 )]
330338 test_files = ["test_batch" ]
@@ -342,7 +350,7 @@ def cifar10(info, root, config):
342350 return len (train_files if config .split == "train" else test_files )
343351
344352
345- @register_mock
353+ # @register_mock
346354def cifar100 (info , root , config ):
347355 train_files = ["train" ]
348356 test_files = ["test" ]
@@ -360,7 +368,7 @@ def cifar100(info, root, config):
360368 return len (train_files if config .split == "train" else test_files )
361369
362370
363- @register_mock
371+ # @register_mock
364372def caltech101 (info , root , config ):
365373 def create_ann_file (root , name ):
366374 import scipy .io
@@ -410,7 +418,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
410418 return num_images_per_category * len (info .categories )
411419
412420
413- @register_mock
421+ # @register_mock
414422def caltech256 (info , root , config ):
415423 dir = root / "256_ObjectCategories"
416424 num_images_per_category = 2
@@ -430,26 +438,26 @@ def caltech256(info, root, config):
430438 return num_images_per_category * len (info .categories )
431439
432440
433- @register_mock
434- def imagenet (info , root , config ):
441+ @register_mock ( configs = combinations_grid ( split = ( "train" , "val" , "test" )))
442+ def imagenet (root , config ):
435443 from scipy .io import savemat
436444
437- categories = info . categories
438- wnids = [ info . extra . category_to_wnid [ category ] for category in categories ]
439- if config . split == "train" :
440- num_samples = len (wnids )
445+ info = datasets . info ( "imagenet" )
446+
447+ if config [ " split" ] == "train" :
448+ num_samples = len (info [ " wnids" ] )
441449 archive_name = "ILSVRC2012_img_train.tar"
442450
443451 files = []
444- for wnid in wnids :
452+ for wnid in info [ " wnids" ] :
445453 create_image_folder (
446454 root = root ,
447455 name = wnid ,
448456 file_name_fn = lambda image_idx : f"{ wnid } _{ image_idx :04d} .JPEG" ,
449457 num_examples = 1 ,
450458 )
451459 files .append (make_tar (root , f"{ wnid } .tar" ))
452- elif config . split == "val" :
460+ elif config [ " split" ] == "val" :
453461 num_samples = 3
454462 archive_name = "ILSVRC2012_img_val.tar"
455463 files = [create_image_file (root , f"ILSVRC2012_val_{ idx + 1 :08d} .JPEG" ) for idx in range (num_samples )]
@@ -459,20 +467,20 @@ def imagenet(info, root, config):
459467 data_root .mkdir (parents = True )
460468
461469 with open (data_root / "ILSVRC2012_validation_ground_truth.txt" , "w" ) as file :
462- for label in torch .randint (0 , len (wnids ), (num_samples ,)).tolist ():
470+ for label in torch .randint (0 , len (info [ " wnids" ] ), (num_samples ,)).tolist ():
463471 file .write (f"{ label } \n " )
464472
465473 num_children = 0
466474 synsets = [
467475 (idx , wnid , category , "" , num_children , [], 0 , 0 )
468- for idx , (category , wnid ) in enumerate (zip (categories , wnids ), 1 )
476+ for idx , (category , wnid ) in enumerate (zip (info [ " categories" ], info [ " wnids" ] ), 1 )
469477 ]
470478 num_children = 1
471479 synsets .extend ((0 , "" , "" , "" , num_children , [], 0 , 0 ) for _ in range (5 ))
472480 savemat (data_root / "meta.mat" , dict (synsets = synsets ))
473481
474482 make_tar (root , devkit_root .with_suffix (".tar.gz" ).name , compression = "gz" )
475- else : # config. split == "test"
483+ else : # config[" split"] == "test"
476484 num_samples = 5
477485 archive_name = "ILSVRC2012_img_test_v10102019.tar"
478486 files = [create_image_file (root , f"ILSVRC2012_test_{ idx + 1 :08d} .JPEG" ) for idx in range (num_samples )]
@@ -587,7 +595,7 @@ def generate(
587595 return num_samples
588596
589597
590- @register_mock
598+ # @register_mock
591599def coco (info , root , config ):
592600 return CocoMockData .generate (root , year = config .year , num_samples = 5 )
593601
@@ -661,12 +669,12 @@ def generate(cls, root):
661669 return num_samples_map
662670
663671
664- @register_mock
672+ # @register_mock
665673def sbd (info , root , config ):
666674 return SBDMockData .generate (root )[config .split ]
667675
668676
669- @register_mock
677+ # @register_mock
670678def semeion (info , root , config ):
671679 num_samples = 3
672680 num_categories = len (info .categories )
@@ -779,7 +787,7 @@ def generate(cls, root, *, year, trainval):
779787 return num_samples_map
780788
781789
782- @register_mock
790+ # @register_mock
783791def voc (info , root , config ):
784792 trainval = config .split != "test"
785793 return VOCMockData .generate (root , year = config .year , trainval = trainval )[config .split ]
@@ -873,12 +881,12 @@ def generate(cls, root):
873881 return num_samples_map
874882
875883
876- @register_mock
884+ # @register_mock
877885def celeba (info , root , config ):
878886 return CelebAMockData .generate (root )[config .split ]
879887
880888
881- @register_mock
889+ # @register_mock
882890def dtd (info , root , config ):
883891 data_folder = root / "dtd"
884892
@@ -926,7 +934,7 @@ def dtd(info, root, config):
926934 return num_samples_map [config ]
927935
928936
929- @register_mock
937+ # @register_mock
930938def fer2013 (info , root , config ):
931939 num_samples = 5 if config .split == "train" else 3
932940
@@ -951,7 +959,7 @@ def fer2013(info, root, config):
951959 return num_samples
952960
953961
954- @register_mock
962+ # @register_mock
955963def gtsrb (info , root , config ):
956964 num_examples_per_class = 5 if config .split == "train" else 3
957965 classes = ("00000" , "00042" , "00012" )
@@ -1021,7 +1029,7 @@ def _make_ann_file(path, num_examples, class_idx):
10211029 return num_examples
10221030
10231031
1024- @register_mock
1032+ # @register_mock
10251033def clevr (info , root , config ):
10261034 data_folder = root / "CLEVR_v1.0"
10271035
@@ -1127,7 +1135,7 @@ def generate(self, root):
11271135 return num_samples_map
11281136
11291137
1130- @register_mock
1138+ # @register_mock
11311139def oxford_iiit_pet (info , root , config ):
11321140 return OxfordIIITPetMockData .generate (root )[config .split ]
11331141
@@ -1293,13 +1301,13 @@ def generate(cls, root):
12931301 return num_samples_map
12941302
12951303
1296- @register_mock
1304+ # @register_mock
12971305def cub200 (info , root , config ):
12981306 num_samples_map = (CUB2002011MockData if config .year == "2011" else CUB2002010MockData ).generate (root )
12991307 return num_samples_map [config .split ]
13001308
13011309
1302- @register_mock
1310+ # @register_mock
13031311def svhn (info , root , config ):
13041312 import scipy .io as sio
13051313
@@ -1319,7 +1327,7 @@ def svhn(info, root, config):
13191327 return num_samples
13201328
13211329
1322- @register_mock
1330+ # @register_mock
13231331def pcam (info , root , config ):
13241332 import h5py
13251333
0 commit comments