diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index ad979b6bd84..f88d9aa8364 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -600,9 +600,15 @@ def generate(
         return num_samples


-# @register_mock
-def coco(info, root, config):
-    return CocoMockData.generate(root, year=config.year, num_samples=5)
+@register_mock(
+    configs=combinations_grid(
+        split=("train", "val"),
+        year=("2017", "2014"),
+        annotations=("instances", "captions", None),
+    )
+)
+def coco(root, config):
+    return CocoMockData.generate(root, year=config["year"], num_samples=5)


 class SBDMockData:
diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py
index 1005c7b3130..75896a8db08 100644
--- a/torchvision/prototype/datasets/_builtin/coco.py
+++ b/torchvision/prototype/datasets/_builtin/coco.py
@@ -1,8 +1,8 @@
-import functools
 import pathlib
 import re
 from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union

 import torch
 from torchdata.datapipes.iter import (
@@ -16,11 +16,10 @@
     UnBatcher,
 )
 from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
     DatasetInfo,
     HttpResource,
     OnlineResource,
+    Dataset2,
 )
 from torchvision.prototype.datasets.utils._internal import (
     MappingIterator,
@@ -32,27 +31,51 @@
     hint_shuffling,
 )
 from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage
-from torchvision.prototype.utils._internal import FrozenMapping
-
-
-class Coco(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        name = "coco"
-        categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
-
-        return DatasetInfo(
-            name,
-            dependencies=("pycocotools",),
-            categories=categories,
-            homepage="https://cocodataset.org/",
-            valid_options=dict(
-                split=("train", "val"),
-                year=("2017", "2014"),
-                annotations=(*self._ANN_DECODERS.keys(), None),
-            ),
-            extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))),
+
+from .._api import register_dataset, register_info
+
+
+NAME = "coco"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+    return dict(categories=categories, super_categories=super_categories)
+
+
+@register_dataset(NAME)
+class Coco(Dataset2):
+    """
+    - **homepage**: https://cocodataset.org/
+    - **dependencies**:
+        - pycocotools
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        year: str = "2017",
+        annotations: Optional[str] = "instances",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "val"})
+        self._year = self._verify_str_arg(year, "year", {"2017", "2014"})
+        self._annotations = (
+            self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys())
+            if annotations is not None
+            else None
         )
+
+        info = _info()
+        categories, super_categories = info["categories"], info["super_categories"]
+        self._categories = categories
+        self._category_to_super_category = dict(zip(categories, super_categories))
+
+        super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check)

     _IMAGE_URL_BASE = "http://images.cocodataset.org/zips"

     _IMAGES_CHECKSUMS = {
@@ -69,14 +92,14 @@ def _make_info(self) -> DatasetInfo:
         "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268",
     }

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         images = HttpResource(
-            f"{self._IMAGE_URL_BASE}/{config.split}{config.year}.zip",
-            sha256=self._IMAGES_CHECKSUMS[(config.year, config.split)],
+            f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip",
+            sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)],
         )
         meta = HttpResource(
-            f"{self._META_URL_BASE}/annotations_trainval{config.year}.zip",
-            sha256=self._META_CHECKSUMS[config.year],
+            f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip",
+            sha256=self._META_CHECKSUMS[self._year],
         )
         return [images, meta]
@@ -110,10 +133,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
                 format="xywh",
                 image_size=image_size,
             ),
-            labels=Label(labels, categories=self.categories),
-            super_categories=[
-                self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels
-            ],
+            labels=Label(labels, categories=self._categories),
+            super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
             ann_ids=[ann["id"] for ann in anns],
         )
@@ -134,9 +155,14 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str,
         fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
     )

-    def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool:
+    def _filter_meta_files(self, data: Tuple[str, Any]) -> bool:
         match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
-        return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations)
+        return bool(
+            match
+            and match["split"] == self._split
+            and match["year"] == self._year
+            and match["annotations"] == self._annotations
+        )

     def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
         key, _ = data
@@ -157,38 +183,26 @@ def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
     def _prepare_sample(
         self,
         data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
-        *,
-        annotations: str,
     ) -> Dict[str, Any]:
         ann_data, image_data = data
         anns, image_meta = ann_data

         sample = self._prepare_image(image_data)
+        # this method is only called if we have annotations
+        annotations = cast(str, self._annotations)
         sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
         return sample

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         images_dp, meta_dp = resource_dps

-        if config.annotations is None:
+        if self._annotations is None:
             dp = hint_shuffling(images_dp)
             dp = hint_sharding(dp)
+            dp = hint_shuffling(dp)
             return Mapper(dp, self._prepare_image)

-        meta_dp = Filter(
-            meta_dp,
-            functools.partial(
-                self._filter_meta_files,
-                split=config.split,
-                year=config.year,
-                annotations=config.annotations,
-            ),
-        )
+        meta_dp = Filter(meta_dp, self._filter_meta_files)
         meta_dp = JsonParser(meta_dp)
         meta_dp = Mapper(meta_dp, getitem(1))
         meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
@@ -216,7 +230,6 @@ def _make_datapipe(
             ref_key_fn=getitem("id"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-
         dp = IterKeyZipper(
             anns_dp,
             images_dp,
@@ -224,18 +237,24 @@ def _make_datapipe(
             ref_key_fn=path_accessor("name"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
+        return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return {
+            ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
+            ("train", "2014"): defaultdict(lambda: 82_783, instances=82_081),
+            ("val", "2017"): defaultdict(lambda: 5_000, instances=4_952),
+            ("val", "2014"): defaultdict(lambda: 40_504, instances=40_137),
+        }[(self._split, self._year)][
+            self._annotations  # type: ignore[index]
+        ]

-        return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations))
-
-    def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
-        config = self.default_config
-        resources = self.resources(config)
+    def _generate_categories(self) -> Tuple[Tuple[str, str]]:
+        self._annotations = "instances"
+        resources = self._resources()

-        dp = resources[1].load(root)
-        dp = Filter(
-            dp,
-            functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
-        )
+        dp = resources[1].load(self._root)
+        dp = Filter(dp, self._filter_meta_files)
         dp = JsonParser(dp)
         _, meta = next(iter(dp))
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index 638878d5ec3..56accca02b4 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -54,7 +54,17 @@ class ImageNetDemux(enum.IntEnum):

 @register_dataset(NAME)
 class ImageNet(Dataset2):
-    def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:
+    """
+    - **homepage**: https://www.image-net.org/
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
         self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})

         info = _info()
@@ -63,7 +73,7 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> N
         self._wnids = wnids
         self._wnid_to_category = dict(zip(wnids, categories))

-        super().__init__(root)
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

     _IMAGES_CHECKSUMS = {
         "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py
index d000bdbe0e7..91b82794e27 100644
--- a/torchvision/prototype/datasets/_builtin/voc.py
+++ b/torchvision/prototype/datasets/_builtin/voc.py
@@ -50,7 +50,7 @@ def __init__(
         split: str = "train",
         year: str = "2012",
         task: str = "detection",
-        **kwargs: Any,
+        skip_integrity_check: bool = False,
     ) -> None:
         self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
         if split == "test" and year != "2007":
@@ -64,7 +64,7 @@ def __init__(

         self._categories = _info()["categories"]

-        super().__init__(root, **kwargs)
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

     _TRAIN_VAL_ARCHIVES = {
         "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py
index ac35eddb28b..6d4e854fe34 100644
--- a/torchvision/prototype/datasets/generate_category_files.py
+++ b/torchvision/prototype/datasets/generate_category_files.py
@@ -51,7 +51,7 @@ def parse_args(argv=None):


 if __name__ == "__main__":
-    args = parse_args(["-f", "imagenet"])
+    args = parse_args()

     try:
         main(*args.names, force=args.force)
diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py
index 69180040194..a6ec05c3ff4 100644
--- a/torchvision/prototype/datasets/utils/_dataset.py
+++ b/torchvision/prototype/datasets/utils/_dataset.py
@@ -196,7 +196,18 @@ def _verify_str_arg(
     ) -> str:
         return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)

-    def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
+    def __init__(
+        self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = ()
+    ) -> None:
+        for dependency in dependencies:
+            try:
+                importlib.import_module(dependency)
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    f"{type(self).__name__}() depends on the third-party package '{dependency}'. "
+                    f"Please install it, for example with `pip install {dependency}`."
+                ) from None
+
         self._root = pathlib.Path(root).expanduser().resolve()
         resources = [
             resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
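
For reference, a minimal usage sketch of the constructor-based API this patch gives `Coco` (mirroring ImageNet and VOC above). The import path and the root directory are assumptions for illustration; the keyword arguments, the eager resource loading, and the `len()` values come from the diff itself.

```python
# Hypothetical usage sketch of the migrated COCO dataset; not part of the patch.
# The module path mirrors the file edited above, the root directory is made up.
from torchvision.prototype.datasets._builtin.coco import Coco

dataset = Coco(
    "~/datasets/coco",  # resources are loaded (and downloaded if missing) in __init__
    split="val",
    year="2017",
    annotations="instances",  # "instances", "captions", or None for images only
    skip_integrity_check=True,  # skip the SHA256 checks listed in _IMAGES_CHECKSUMS
)

print(len(dataset))  # 4_952 for ("val", "2017") with "instances", per __len__ above

for sample in dataset:  # iterating yields the dicts built by _prepare_sample
    ...
```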