11import pathlib
2- from typing import Any , Dict , List , Optional , Tuple , BinaryIO
2+ from typing import Any , Dict , List , Optional , Tuple , BinaryIO , Union
33
44from torchdata .datapipes .iter import IterDataPipe , Mapper , Filter , IterKeyZipper , Demultiplexer , JsonParser , UnBatcher
5- from torchvision .prototype .datasets .utils import (
6- Dataset ,
7- DatasetConfig ,
8- DatasetInfo ,
9- HttpResource ,
10- OnlineResource ,
11- )
5+ from torchvision .prototype .datasets .utils import Dataset2 , HttpResource , OnlineResource
126from torchvision .prototype .datasets .utils ._internal import (
137 INFINITE_BUFFER_SIZE ,
148 hint_sharding ,
1913)
2014from torchvision .prototype .features import Label , EncodedImage
2115
16+ from .._api import register_dataset , register_info
17+
18+ NAME = "clevr"
19+
20+
@register_info(NAME)
def _info() -> Dict[str, Any]:
    """Return static metadata for the CLEVR dataset.

    CLEVR currently registers no extra info (e.g. no category list), so
    this is an empty mapping; it exists so the dataset participates in
    the ``register_info`` registry like every other prototype dataset.
    """
    # Idiom fix: prefer a dict literal over the dict() call (ruff C408).
    return {}
2224
23- class CLEVR (Dataset ):
24- def _make_info (self ) -> DatasetInfo :
25- return DatasetInfo (
26- "clevr" ,
27- homepage = "https://cs.stanford.edu/people/jcjohns/clevr/" ,
28- valid_options = dict (split = ("train" , "val" , "test" )),
29- )
3025
31- def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
26+ @register_dataset (NAME )
27+ class CLEVR (Dataset2 ):
28+ """
29+ - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
30+ """
31+
32+ def __init__ (
33+ self , root : Union [str , pathlib .Path ], * , split : str = "train" , skip_integrity_check : bool = False
34+ ) -> None :
35+ self ._split = self ._verify_str_arg (split , "split" , ("train" , "val" , "test" ))
36+
37+ super ().__init__ (root , skip_integrity_check = skip_integrity_check )
38+
39+ def _resources (self ) -> List [OnlineResource ]:
3240 archive = HttpResource (
3341 "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip" ,
3442 sha256 = "5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1" ,
@@ -61,12 +69,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, A
6169 label = Label (len (scenes_data ["objects" ])) if scenes_data else None ,
6270 )
6371
64- def _make_datapipe (
65- self ,
66- resource_dps : List [IterDataPipe ],
67- * ,
68- config : DatasetConfig ,
69- ) -> IterDataPipe [Dict [str , Any ]]:
72+ def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
7073 archive_dp = resource_dps [0 ]
7174 images_dp , scenes_dp = Demultiplexer (
7275 archive_dp ,
@@ -76,12 +79,12 @@ def _make_datapipe(
7679 buffer_size = INFINITE_BUFFER_SIZE ,
7780 )
7881
79- images_dp = Filter (images_dp , path_comparator ("parent.name" , config . split ))
82+ images_dp = Filter (images_dp , path_comparator ("parent.name" , self . _split ))
8083 images_dp = hint_shuffling (images_dp )
8184 images_dp = hint_sharding (images_dp )
8285
83- if config . split != "test" :
84- scenes_dp = Filter (scenes_dp , path_comparator ("name" , f"CLEVR_{ config . split } _scenes.json" ))
86+ if self . _split != "test" :
87+ scenes_dp = Filter (scenes_dp , path_comparator ("name" , f"CLEVR_{ self . _split } _scenes.json" ))
8588 scenes_dp = JsonParser (scenes_dp )
8689 scenes_dp = Mapper (scenes_dp , getitem (1 , "scenes" ))
8790 scenes_dp = UnBatcher (scenes_dp )
@@ -97,3 +100,6 @@ def _make_datapipe(
97100 dp = Mapper (images_dp , self ._add_empty_anns )
98101
99102 return Mapper (dp , self ._prepare_sample )
103+
104+ def __len__ (self ) -> int :
105+ return 70_000 if self ._split == "train" else 15_000
0 commit comments