 import torch
 from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
 from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
+from torch.utils.data import DataLoader
 from torch.utils.data.graph import traverse
 from torch.utils.data.graph_settings import get_all_graph_pipes
-from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
+from torchdata.datapipes.iter import Shuffler, ShardingFilter
 from torchvision._utils import sequence_to_str
 from torchvision.prototype import transforms, datasets
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -42,14 +43,24 @@ def test_coverage():
 
 @pytest.mark.filterwarnings("error")
 class TestCommon:
+    @pytest.mark.parametrize("name", datasets.list_datasets())
+    def test_info(self, name):
+        try:
+            info = datasets.info(name)
+        except ValueError:
+            raise AssertionError("No info available.") from None
+
+        if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())):
+            raise AssertionError("Info should be a dictionary with string keys.")
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_smoke(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
-        if not isinstance(dataset, IterDataPipe):
-            raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
+        if not isinstance(dataset, datasets.utils.Dataset):
+            raise AssertionError(f"Loading the dataset should return a Dataset, but got {type(dataset)} instead.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_sample(self, test_home, dataset_mock, config):
@@ -76,24 +87,7 @@ def test_num_samples(self, test_home, dataset_mock, config):
 
         dataset = datasets.load(dataset_mock.name, **config)
 
-        num_samples = 0
-        for _ in dataset:
-            num_samples += 1
-
-        assert num_samples == mock_info["num_samples"]
-
-    @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_decoding(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
-
-        dataset = datasets.load(dataset_mock.name, **config)
-
-        undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
-        if undecoded_features:
-            raise AssertionError(
-                f"The values of key(s) "
-                f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
-            )
+        assert len(list(dataset)) == mock_info["num_samples"]
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
@@ -116,14 +110,36 @@ def test_transformable(self, test_home, dataset_mock, config):
 
         next(iter(dataset.map(transforms.Identity())))
 
+    @pytest.mark.parametrize("only_datapipe", [False, True])
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_serializable(self, test_home, dataset_mock, config):
+    def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
         dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        traverse(dataset, only_datapipe=only_datapipe)
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_serializable(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
         dataset = datasets.load(dataset_mock.name, **config)
 
         pickle.dumps(dataset)
 
+    @pytest.mark.parametrize("num_workers", [0, 1])
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_data_loader(self, test_home, dataset_mock, config, num_workers):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        dl = DataLoader(
+            dataset,
+            batch_size=2,
+            num_workers=num_workers,
+            collate_fn=lambda batch: batch,
+        )
+
+        next(iter(dl))
+
     # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
     # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
     # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@@ -132,7 +148,6 @@ def test_serializable(self, test_home, dataset_mock, config):
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
 
         dataset_mock.prepare(test_home, config)
-
         dataset = datasets.load(dataset_mock.name, **config)
 
         if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
@@ -160,6 +175,13 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
                 # resolved
                 assert dp.buffer_size == INFINITE_BUFFER_SIZE
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_has_length(self, test_home, dataset_mock, config):
+        dataset_mock.prepare(test_home, config)
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        assert len(dataset) > 0
+
 
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
@@ -186,7 +208,7 @@ class TestGTSRB:
     def test_label_matches_path(self, test_home, dataset_mock, config):
         # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
         # This test makes sure that they're both the same
-        if config.split != "train":
+        if config["split"] != "train":
             return
 
         dataset_mock.prepare(test_home, config)
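The TODO in the diff asks for a check that the Shuffler comes before the ShardingFilter. Below is a minimal sketch of such an ordering test, not the custom test from the early commits of pytorch/vision#5680. It assumes the traverse() behavior of the torch version this PR targets, where traverse(dataset, only_datapipe=True) returns a nested dict mapping each datapipe to the sub-graph of its upstream sources; the helper names _iter_nodes and assert_shuffler_before_sharding_filter are made up for illustration.

from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import ShardingFilter, Shuffler


def _iter_nodes(graph):
    # Assumed traverse() format: {datapipe: {source_datapipe: {...}, ...}}, so a
    # recursive walk yields every node together with its upstream sub-graph.
    for dp, upstream in graph.items():
        yield dp, upstream
        yield from _iter_nodes(upstream)


def assert_shuffler_before_sharding_filter(dataset):
    # Hypothetical helper: every ShardingFilter must have a Shuffler somewhere
    # upstream, i.e. samples are shuffled before they are sharded.
    graph = traverse(dataset, only_datapipe=True)
    sharding_filters = [
        (dp, upstream) for dp, upstream in _iter_nodes(graph) if isinstance(dp, ShardingFilter)
    ]
    if not sharding_filters:
        raise AssertionError("No ShardingFilter found in the datapipe graph.")
    for dp, upstream in sharding_filters:
        if not any(isinstance(up, Shuffler) for up, _ in _iter_nodes(upstream)):
            raise AssertionError("Shuffler should come before the ShardingFilter.")

Checking the upstream sub-graph rather than positions in a flattened pipe list avoids depending on the traversal order of get_all_graph_pipes, which is an implementation detail.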