diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index bc117072df3..6d8ebd4385e 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -1402,10 +1402,10 @@ def generate(cls, root):
         return num_samples_map
 
 
-# @register_mock
-def cub200(info, root, config):
-    num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root)
-    return num_samples_map[config.split]
+@register_mock(configs=combinations_grid(split=("train", "test"), year=("2010", "2011")))
+def cub200(root, config):
+    num_samples_map = (CUB2002011MockData if config["year"] == "2011" else CUB2002010MockData).generate(root)
+    return num_samples_map[config["split"]]
 
 
 @register_mock(configs=[dict()])
diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py
index 1b90b476aa7..073a790092c 100644
--- a/torchvision/prototype/datasets/_builtin/cub200.py
+++ b/torchvision/prototype/datasets/_builtin/cub200.py
@@ -1,7 +1,7 @@
 import csv
 import functools
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable, Union
 
 from torchdata.datapipes.iter import (
     IterDataPipe,
@@ -14,8 +14,7 @@
     CSVDictParser,
 )
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
+    Dataset2,
     DatasetInfo,
     HttpResource,
     OnlineResource,
@@ -28,26 +27,53 @@
     getitem,
     path_comparator,
     path_accessor,
+    BUILTIN_DIR,
 )
 from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage
 
+from .._api import register_dataset, register_info
+
 csv.register_dialect("cub200", delimiter=" ")
 
 
-class CUB200(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "cub200",
-            homepage="http://www.vision.caltech.edu/visipedia/CUB-200-2011.html",
-            dependencies=("scipy",),
-            valid_options=dict(
-                split=("train", "test"),
-                year=("2011", "2010"),
-            ),
+NAME = "cub200"
+
+CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=CATEGORIES)
+
+
+@register_dataset(NAME)
+class CUB200(Dataset2):
+    """
+    - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        year: str = "2011",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "test"))
+        self._year = self._verify_str_arg(year, "year", ("2010", "2011"))
+
+        self._categories = _info()["categories"]
+
+        super().__init__(
+            root,
+            # TODO: this will only be available after https://github.com/pytorch/vision/pull/5473
+            # dependencies=("scipy",),
+            skip_integrity_check=skip_integrity_check,
         )
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        if config.year == "2011":
+    def _resources(self) -> List[OnlineResource]:
+        if self._year == "2011":
             archive = HttpResource(
                 "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz",
                 sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081",
@@ -59,7 +85,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
                 preprocess="decompress",
             )
             return [archive, segmentations]
-        else:  # config.year == "2010"
+        else:  # self._year == "2010"
             split = HttpResource(
                 "http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz",
                 sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428",
@@ -90,12 +116,12 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         else:
             return None
 
-    def _2011_filter_split(self, row: List[str], *, split: str) -> bool:
+    def _2011_filter_split(self, row: List[str]) -> bool:
         _, split_id = row
         return {
             "0": "test",
             "1": "train",
-        }[split_id] == split
+        }[split_id] == self._split
 
     def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str:
         path = pathlib.Path(data[0])
@@ -149,17 +175,12 @@ def _prepare_sample(
         return dict(
             prepare_ann_fn(anns_data, image.image_size),
             image=image,
-            label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self.categories),
+            label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self._categories),
         )
 
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         prepare_ann_fn: Callable
-        if config.year == "2011":
+        if self._year == "2011":
             archive_dp, segmentations_dp = resource_dps
             images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer(
                 archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
@@ -171,7 +192,7 @@ def _make_datapipe(
             )
 
             split_dp = CSVParser(split_dp, dialect="cub200")
-            split_dp = Filter(split_dp, functools.partial(self._2011_filter_split, split=config.split))
+            split_dp = Filter(split_dp, self._2011_filter_split)
             split_dp = Mapper(split_dp, getitem(0))
             split_dp = Mapper(split_dp, image_files_map.get)
 
@@ -188,10 +209,10 @@ def _make_datapipe(
             )
 
             prepare_ann_fn = self._2011_prepare_ann
-        else:  # config.year == "2010"
+        else:  # self._year == "2010"
             split_dp, images_dp, anns_dp = resource_dps
 
-            split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+            split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
             split_dp = LineReader(split_dp, decode=True, return_path=False)
             split_dp = Mapper(split_dp, self._2010_split_key)
 
@@ -217,11 +238,19 @@ def _make_datapipe(
         )
         return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn))
 
-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        config = self.info.make_config(year="2011")
-        resources = self.resources(config)
+    def __len__(self) -> int:
+        return {
+            ("train", "2010"): 3_000,
+            ("test", "2010"): 3_033,
+            ("train", "2011"): 5_994,
+            ("test", "2011"): 5_794,
+        }[(self._split, self._year)]
+
+    def _generate_categories(self) -> List[str]:
+        self._year = "2011"
+        resources = self._resources()
 
-        dp = resources[0].load(root)
+        dp = resources[0].load(self._root)
         dp = Filter(dp, path_comparator("name", "classes.txt"))
         dp = CSVDictParser(dp, fieldnames=("label", "category"), dialect="cub200")