diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index bc117072df3..b33dc1450e3 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -1473,18 +1473,19 @@ def pcam(root, config):
     return num_images
 
 
-# @register_mock
-def stanford_cars(info, root, config):
+@register_mock(name="stanford-cars", configs=combinations_grid(split=("train", "test")))
+def stanford_cars(root, config):
     import scipy.io as io
     from numpy.core.records import fromarrays
 
-    num_samples = {"train": 5, "test": 7}[config["split"]]
+    split = config["split"]
+    num_samples = {"train": 5, "test": 7}[split]
     num_categories = 3
 
     devkit = root / "devkit"
     devkit.mkdir(parents=True)
 
-    if config["split"] == "train":
+    if split == "train":
         images_folder_name = "cars_train"
         annotations_mat_path = devkit / "cars_train_annos.mat"
     else:
@@ -1498,7 +1499,7 @@ def stanford_cars(info, root, config):
         num_examples=num_samples,
     )
 
-    make_tar(root, f"cars_{config.split}.tgz", images_folder_name)
+    make_tar(root, f"cars_{split}.tgz", images_folder_name)
     bbox = np.random.randint(1, 200, num_samples, dtype=np.uint8)
     classes = np.random.randint(1, num_categories + 1, num_samples, dtype=np.uint8)
     fnames = [f"{i:5d}.jpg" for i in range(num_samples)]
@@ -1508,7 +1509,7 @@ def stanford_cars(info, root, config):
     )
     io.savemat(annotations_mat_path, {"annotations": rec_array})
 
-    if config.split == "train":
+    if split == "train":
         make_tar(root, "car_devkit.tgz", devkit, compression="gz")
 
     return num_samples
diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py
index a5de1359e4e..dcec6d0e716 100644
--- a/torchvision/prototype/datasets/_builtin/dtd.py
+++ b/torchvision/prototype/datasets/_builtin/dtd.py
@@ -41,8 +41,9 @@ def _info() -> Dict[str, Any]:
 @register_dataset(NAME)
 class DTD(Dataset2):
     """DTD Dataset.
-    homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
+    homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/",
     """
+
     def __init__(
         self,
         root: Union[str, pathlib.Path],
diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py
index 51c0b6152e6..85098eb34e5 100644
--- a/torchvision/prototype/datasets/_builtin/stanford_cars.py
+++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py
@@ -1,11 +1,19 @@
 import pathlib
-from typing import Any, Dict, List, Tuple, Iterator, BinaryIO
+from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union
 
 from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper
-from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
-from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, path_comparator, read_mat
+from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils._internal import (
+    hint_sharding,
+    hint_shuffling,
+    path_comparator,
+    read_mat,
+    BUILTIN_DIR,
+)
 from torchvision.prototype.features import BoundingBox, EncodedImage, Label
 
+from .._api import register_dataset, register_info
+
 
 class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]]):
     def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None:
@@ -18,16 +26,33 @@ def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]:
                 yield tuple(ann)  # type: ignore[misc]
 
 
-class StanfordCars(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            name="stanford-cars",
-            homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
-            dependencies=("scipy",),
-            valid_options=dict(
-                split=("test", "train"),
-            ),
-        )
+NAME = "stanford-cars"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")
+    categories = [c[0] for c in categories]
+    return dict(categories=categories)
+
+
+@register_dataset(NAME)
+class StanfordCars(Dataset2):
+    """Stanford Cars dataset.
+    homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html",
+    dependencies=scipy
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "test"})
+        self._categories = _info()["categories"]
+        super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",))
 
     _URL_ROOT = "https://ai.stanford.edu/~jkrause/"
     _URLS = {
@@ -44,9 +69,9 @@ def _make_info(self) -> DatasetInfo:
         "car_devkit": "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288",
     }
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
-        resources: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUM[config.split])]
-        if config.split == "train":
+    def _resources(self) -> List[OnlineResource]:
+        resources: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUM[self._split])]
+        if self._split == "train":
             resources.append(HttpResource(url=self._URLS["car_devkit"], sha256=self._CHECKSUM["car_devkit"]))
 
         else:
@@ -65,19 +90,14 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int,
         return dict(
             path=path,
             image=image,
-            label=Label(target[4] - 1, categories=self.categories),
+            label=Label(target[4] - 1, categories=self._categories),
             bounding_box=BoundingBox(target[:4], format="xyxy", image_size=image.image_size),
         )
 
-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         images_dp, targets_dp = resource_dps
 
-        if config.split == "train":
+        if self._split == "train":
             targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat"))
         targets_dp = StanfordCarsLabelReader(targets_dp)
         dp = Zipper(images_dp, targets_dp)
@@ -85,12 +105,14 @@ def _make_datapipe(
         dp = hint_sharding(dp)
         return Mapper(dp, self._prepare_sample)
 
-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        config = self.info.make_config(split="train")
-        resources = self.resources(config)
+    def _generate_categories(self) -> List[str]:
+        resources = self._resources()
 
-        devkit_dp = resources[1].load(root)
+        devkit_dp = resources[1].load(self._root)
         meta_dp = Filter(devkit_dp, path_comparator("name", "cars_meta.mat"))
         _, meta_file = next(iter(meta_dp))
 
         return list(read_mat(meta_file, squeeze_me=True)["class_names"])
+
+    def __len__(self) -> int:
+        return 8_144 if self._split == "train" else 8_041
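For reviewers who want to exercise the ported dataset on this branch, a minimal usage sketch follows (not part of the diff). It assumes scipy is installed, that the root path below is a placeholder for a directory that already holds the Stanford Cars archives, and that Dataset2 subclasses are iterable as elsewhere in the prototype API; the sample keys come from _prepare_sample above.

import pathlib

from torchvision.prototype.datasets._builtin.stanford_cars import StanfordCars

# Hypothetical local path; cars_train.tgz / car_devkit.tgz are assumed to be
# present here already, otherwise the resources are fetched on first use.
root = pathlib.Path("~/datasets/stanford-cars").expanduser()

dataset = StanfordCars(root, split="train", skip_integrity_check=True)
print(len(dataset))  # 8_144 for the train split, per __len__ in the diff

# Each sample is the dict built in _prepare_sample: path, image, label, bounding_box.
sample = next(iter(dataset))
print(sample["path"], sample["label"], sample["bounding_box"])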