From d1dd88ad4a70fd41ce6e0f1f28442c98016c5d0b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 23 Feb 2022 12:07:14 +0100 Subject: [PATCH 1/8] migrate coco prototype --- test/builtin_dataset_mocks.py | 75 ++++------ torchvision/prototype/datasets/_api.py | 2 +- .../prototype/datasets/_builtin/coco.py | 140 ++++++++++-------- .../prototype/datasets/_builtin/imagenet.py | 4 + .../datasets/generate_category_files.py | 2 +- .../prototype/datasets/utils/_dataset.py | 13 +- 6 files changed, 124 insertions(+), 112 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 1d988196190..29907785017 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -490,9 +490,16 @@ def imagenet(root, config): return num_samples -class CocoMockData: - @classmethod - def _make_images_archive(cls, root, name, *, num_samples): +@register_mock( + configs=combinations_grid( + split=("train", "val"), + year=("2017", "2014"), + annotations=("instances", "captions", None), + ) +) +def coco(root, options): + def make_images_archive(root, name, *, num_samples): + image_paths = create_image_folder( root, name, file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_samples ) @@ -507,15 +514,7 @@ def _make_images_archive(cls, root, name, *, num_samples): return images_meta - @classmethod - def _make_annotations_json( - cls, - root, - name, - *, - images_meta, - fn, - ): + def make_annotations_json(root, name, *, images_meta, fn): num_anns_per_image = torch.randint(1, 5, (len(images_meta),)) num_anns_total = int(num_anns_per_image.sum()) ann_ids_iter = iter(torch.arange(num_anns_total)[torch.randperm(num_anns_total)]) @@ -532,8 +531,7 @@ def _make_annotations_json( return num_anns_per_image - @staticmethod - def _make_instances_data(ann_id, image_meta): + def make_instances_data(ann_id, image_meta): def make_rle_segmentation(): height, width = image_meta["height"], image_meta["width"] numel = height * width @@ -552,52 +550,37 @@ def make_rle_segmentation(): category_id=int(make_scalar(dtype=torch.int64)), ) - @staticmethod - def _make_captions_data(ann_id, image_meta): + def make_captions_data(ann_id, image_meta): return dict(caption=f"Caption {ann_id} describing image {image_meta['id']}.") - @classmethod - def _make_annotations(cls, root, name, *, images_meta): + def make_annotations(root, name, *, images_meta): num_anns_per_image = torch.zeros((len(images_meta),), dtype=torch.int64) for annotations, fn in ( - ("instances", cls._make_instances_data), - ("captions", cls._make_captions_data), + ("instances", make_instances_data), + ("captions", make_captions_data), ): - num_anns_per_image += cls._make_annotations_json( + num_anns_per_image += make_annotations_json( root, f"{annotations}_{name}.json", images_meta=images_meta, fn=fn ) return int(num_anns_per_image.sum()) - @classmethod - def generate( - cls, - root, - *, - year, - num_samples, - ): - annotations_dir = root / "annotations" - annotations_dir.mkdir() + annotations_dir = root / "annotations" + annotations_dir.mkdir() - for split in ("train", "val"): - config_name = f"{split}{year}" + for split in ("train", "val"): + config_name = f"{split}{options['year']}" - images_meta = cls._make_images_archive(root, config_name, num_samples=num_samples) - cls._make_annotations( - annotations_dir, - config_name, - images_meta=images_meta, - ) - - make_zip(root, f"annotations_trainval{year}.zip", annotations_dir) - - return num_samples + images_meta = make_images_archive(root, config_name, num_samples=5) + make_annotations( + annotations_dir, + config_name, + images_meta=images_meta, + ) + make_zip(root, f"annotations_trainval{options['year']}.zip", annotations_dir) -# @register_mock -def coco(info, root, config): - return CocoMockData.generate(root, year=config.year, num_samples=5) + return len(images_meta) class SBDMockData: diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 8f8bb53deb4..3c1b966f52a 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -49,7 +49,7 @@ def find(dct: Dict[str, T], name: str) -> T: "You can use torchvision.datasets.list_datasets() to get a list of all available datasets." ), ) - ) from error + ) from None def info(name: str) -> Dict[str, Any]: diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 74232eb714d..b5f14308753 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -1,8 +1,8 @@ -import functools import pathlib import re from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union import torch from torchdata.datapipes.iter import ( @@ -16,11 +16,10 @@ UnBatcher, ) from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, DatasetInfo, HttpResource, OnlineResource, + Dataset2, ) from torchvision.prototype.datasets.utils._internal import ( MappingIterator, @@ -32,27 +31,47 @@ hint_shuffling, ) from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage -from torchvision.prototype.utils._internal import FrozenMapping - - -class Coco(Dataset): - def _make_info(self) -> DatasetInfo: - name = "coco" - categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories")) - - return DatasetInfo( - name, - dependencies=("pycocotools",), - categories=categories, - homepage="https://cocodataset.org/", - valid_options=dict( - split=("train", "val"), - year=("2017", "2014"), - annotations=(*self._ANN_DECODERS.keys(), None), - ), - extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))), + +from .._api import register_dataset, register_info + + +NAME = "coco" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + return dict(categories=categories, super_categories=super_categories) + + +@register_dataset(NAME) +class Coco(Dataset2): + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + year: str = "2017", + annotations: Optional[str] = "instances", + ) -> None: + """ + - **homepage**: https://cocodataset.org/ + """ + self._split = self._verify_str_arg(split, "split", {"train", "val"}) + self._year = self._verify_str_arg(year, "year", {"2017", "2014"}) + self._annotations = ( + self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys()) + if annotations is not None + else None ) + info = _info() + categories, super_categories = info["categories"], info["super_categories"] + self._categories: List[str] = categories + self._category_to_super_category = dict(zip(categories, super_categories)) + + super().__init__(root, dependencies=("pycocotools",)) + _IMAGE_URL_BASE = "http://images.cocodataset.org/zips" _IMAGES_CHECKSUMS = { @@ -69,14 +88,14 @@ def _make_info(self) -> DatasetInfo: "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: images = HttpResource( - f"{self._IMAGE_URL_BASE}/{config.split}{config.year}.zip", - sha256=self._IMAGES_CHECKSUMS[(config.year, config.split)], + f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip", + sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)], ) meta = HttpResource( - f"{self._META_URL_BASE}/annotations_trainval{config.year}.zip", - sha256=self._META_CHECKSUMS[config.year], + f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip", + sha256=self._META_CHECKSUMS[self._year], ) return [images, meta] @@ -110,10 +129,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st format="xywh", image_size=image_size, ), - labels=Label(labels, categories=self.categories), - super_categories=[ - self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels - ], + labels=Label(labels, categories=self._categories), + super_categories=[self._category_to_super_category[self._categories[label]] for label in labels], ann_ids=[ann["id"] for ann in anns], ) @@ -134,9 +151,14 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, fr"(?P({'|'.join(_ANN_DECODERS.keys())}))_(?P[a-zA-Z]+)(?P\d+)[.]json" ) - def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool: + def _filter_meta_files(self, data: Tuple[str, Any]) -> bool: match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name) - return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations) + return bool( + match + and match["split"] == self._split + and match["year"] == self._year + and match["annotations"] == self._annotations + ) def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]: key, _ = data @@ -157,38 +179,25 @@ def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: def _prepare_sample( self, data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]], - *, - annotations: str, ) -> Dict[str, Any]: ann_data, image_data = data anns, image_meta = ann_data sample = self._prepare_image(image_data) + # this method is only called if we have annotations + annotations = cast(str, self._annotations) sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta)) return sample - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, meta_dp = resource_dps - if config.annotations is None: + if self._annotations is None: dp = hint_sharding(images_dp) dp = hint_shuffling(dp) return Mapper(dp, self._prepare_image) - meta_dp = Filter( - meta_dp, - functools.partial( - self._filter_meta_files, - split=config.split, - year=config.year, - annotations=config.annotations, - ), - ) + meta_dp = Filter(meta_dp, self._filter_meta_files) meta_dp = JsonParser(meta_dp) meta_dp = Mapper(meta_dp, getitem(1)) meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp) @@ -216,7 +225,6 @@ def _make_datapipe( ref_key_fn=getitem("id"), buffer_size=INFINITE_BUFFER_SIZE, ) - dp = IterKeyZipper( anns_dp, images_dp, @@ -224,18 +232,24 @@ def _make_datapipe( ref_key_fn=path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE, ) + return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266), + ("train", "2014"): defaultdict(lambda: 82_783, instances=82_081), + ("val", "2017"): defaultdict(lambda: 5_000, instances=4_952), + ("val", "2014"): defaultdict(lambda: 40_504, instances=40_137), + }[(self._split, self._year)][ + self._annotations # type: ignore[index] + ] - return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations)) - - def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: - config = self.default_config - resources = self.resources(config) + def _generate_categories(self) -> Tuple[Tuple[str, str]]: + self._annotations = "instances" + resources = self._resources() - dp = resources[1].load(root) - dp = Filter( - dp, - functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"), - ) + dp = resources[1].load(self._root) + dp = Filter(dp, self._filter_meta_files) dp = JsonParser(dp) _, meta = next(iter(dp)) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 6f91d4c4a8d..22e8cd4eb63 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -42,6 +42,10 @@ def __init__(self, **kwargs: Any) -> None: @register_dataset(NAME) class ImageNet(Dataset2): + """ + - **homepage**: https://www.image-net.org/ + """ + def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None: self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py index ac35eddb28b..6d4e854fe34 100644 --- a/torchvision/prototype/datasets/generate_category_files.py +++ b/torchvision/prototype/datasets/generate_category_files.py @@ -51,7 +51,7 @@ def parse_args(argv=None): if __name__ == "__main__": - args = parse_args(["-f", "imagenet"]) + args = parse_args() try: main(*args.names, force=args.force) diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 7200f00fd02..70ce14f5332 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -195,7 +195,18 @@ def _verify_str_arg( ) -> str: return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg) - def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: + def __init__( + self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = () + ) -> None: + for dependency in dependencies: + try: + importlib.import_module(dependency) + except ModuleNotFoundError as error: + raise ModuleNotFoundError( + f"{type(self).__name__}() depends on the third-party package '{dependency}'. " + f"Please install it, for example with `pip install {dependency}`." + ) from error + self._root = pathlib.Path(root).expanduser().resolve() resources = [ resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources() From a6af500196de20ad7b1f4abbb7185294399ac28f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 24 Feb 2022 17:37:01 +0100 Subject: [PATCH 2/8] revert unrelated change --- torchvision/prototype/datasets/_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 3c1b966f52a..8f8bb53deb4 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -49,7 +49,7 @@ def find(dct: Dict[str, T], name: str) -> T: "You can use torchvision.datasets.list_datasets() to get a list of all available datasets." ), ) - ) from None + ) from error def info(name: str) -> Dict[str, Any]: From 6c98b82884919a8b3223ca38b965da5e07f35f06 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 24 Feb 2022 17:39:14 +0100 Subject: [PATCH 3/8] add kwargs to super constructor call --- torchvision/prototype/datasets/_builtin/coco.py | 3 ++- torchvision/prototype/datasets/_builtin/imagenet.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index b5f14308753..7bc00e6313b 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -53,6 +53,7 @@ def __init__( split: str = "train", year: str = "2017", annotations: Optional[str] = "instances", + **kwargs: Any, ) -> None: """ - **homepage**: https://cocodataset.org/ @@ -70,7 +71,7 @@ def __init__( self._categories: List[str] = categories self._category_to_super_category = dict(zip(categories, super_categories)) - super().__init__(root, dependencies=("pycocotools",)) + super().__init__(root, dependencies=("pycocotools",), **kwargs) _IMAGE_URL_BASE = "http://images.cocodataset.org/zips" diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 22e8cd4eb63..f3e656b9208 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -46,7 +46,7 @@ class ImageNet(Dataset2): - **homepage**: https://www.image-net.org/ """ - def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None: + def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train", **kwargs: Any) -> None: self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) info = _info() @@ -55,7 +55,7 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> N self._wnids: List[str] = wnids self._wnid_to_category = dict(zip(wnids, categories)) - super().__init__(root) + super().__init__(root, **kwargs) _IMAGES_CHECKSUMS = { "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb", From d7bccfd43ba210bdaf63ab7565d06849abc095ec Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 16:32:35 +0200 Subject: [PATCH 4/8] remove unneeded changes --- test/builtin_dataset_mocks.py | 81 ++++++++++++------- .../prototype/datasets/_builtin/coco.py | 2 +- 2 files changed, 53 insertions(+), 30 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index e5e0ff45471..5c6304292f1 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -495,16 +495,9 @@ def imagenet(root, config): return num_samples -@register_mock( - configs=combinations_grid( - split=("train", "val"), - year=("2017", "2014"), - annotations=("instances", "captions", None), - ) -) -def coco(root, options): - def make_images_archive(root, name, *, num_samples): - +class CocoMockData: + @classmethod + def _make_images_archive(cls, root, name, *, num_samples): image_paths = create_image_folder( root, name, file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_samples ) @@ -519,7 +512,15 @@ def make_images_archive(root, name, *, num_samples): return images_meta - def make_annotations_json(root, name, *, images_meta, fn): + @classmethod + def _make_annotations_json( + cls, + root, + name, + *, + images_meta, + fn, + ): num_anns_per_image = torch.randint(1, 5, (len(images_meta),)) num_anns_total = int(num_anns_per_image.sum()) ann_ids_iter = iter(torch.arange(num_anns_total)[torch.randperm(num_anns_total)]) @@ -536,7 +537,8 @@ def make_annotations_json(root, name, *, images_meta, fn): return num_anns_per_image - def make_instances_data(ann_id, image_meta): + @staticmethod + def _make_instances_data(ann_id, image_meta): def make_rle_segmentation(): height, width = image_meta["height"], image_meta["width"] numel = height * width @@ -555,37 +557,58 @@ def make_rle_segmentation(): category_id=int(make_scalar(dtype=torch.int64)), ) - def make_captions_data(ann_id, image_meta): + @staticmethod + def _make_captions_data(ann_id, image_meta): return dict(caption=f"Caption {ann_id} describing image {image_meta['id']}.") - def make_annotations(root, name, *, images_meta): + @classmethod + def _make_annotations(cls, root, name, *, images_meta): num_anns_per_image = torch.zeros((len(images_meta),), dtype=torch.int64) for annotations, fn in ( - ("instances", make_instances_data), - ("captions", make_captions_data), + ("instances", cls._make_instances_data), + ("captions", cls._make_captions_data), ): - num_anns_per_image += make_annotations_json( + num_anns_per_image += cls._make_annotations_json( root, f"{annotations}_{name}.json", images_meta=images_meta, fn=fn ) return int(num_anns_per_image.sum()) - annotations_dir = root / "annotations" - annotations_dir.mkdir() + @classmethod + def generate( + cls, + root, + *, + year, + num_samples, + ): + annotations_dir = root / "annotations" + annotations_dir.mkdir() - for split in ("train", "val"): - config_name = f"{split}{options['year']}" + for split in ("train", "val"): + config_name = f"{split}{year}" - images_meta = make_images_archive(root, config_name, num_samples=5) - make_annotations( - annotations_dir, - config_name, - images_meta=images_meta, - ) + images_meta = cls._make_images_archive(root, config_name, num_samples=num_samples) + cls._make_annotations( + annotations_dir, + config_name, + images_meta=images_meta, + ) + + make_zip(root, f"annotations_trainval{year}.zip", annotations_dir) + + return num_samples - make_zip(root, f"annotations_trainval{options['year']}.zip", annotations_dir) - return len(images_meta) +@register_mock( + configs=combinations_grid( + split=("train", "val"), + year=("2017", "2014"), + annotations=("instances", "captions", None), + ) +) +def coco(root, config): + return CocoMockData.generate(root, year=config["year"], num_samples=5) class SBDMockData: diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 6d6485d8dad..a2f3510af45 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -68,7 +68,7 @@ def __init__( info = _info() categories, super_categories = info["categories"], info["super_categories"] - self._categories: List[str] = categories + self._categories = categories self._category_to_super_category = dict(zip(categories, super_categories)) super().__init__(root, dependencies=("pycocotools",), **kwargs) From 25d999ff6d7a10642b9991f889507034f9fc8ae4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 16:45:37 +0200 Subject: [PATCH 5/8] fix docstring position --- torchvision/prototype/datasets/_builtin/coco.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index a2f3510af45..fdcc033a902 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -46,6 +46,10 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) class Coco(Dataset2): + """ + - **homepage**: https://cocodataset.org/ + """ + def __init__( self, root: Union[str, pathlib.Path], @@ -55,9 +59,6 @@ def __init__( annotations: Optional[str] = "instances", **kwargs: Any, ) -> None: - """ - - **homepage**: https://cocodataset.org/ - """ self._split = self._verify_str_arg(split, "split", {"train", "val"}) self._year = self._verify_str_arg(year, "year", {"2017", "2014"}) self._annotations = ( From 453f8db1839063a7d4b05fbb9ea5fbf7f9686348 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 19:33:13 +0200 Subject: [PATCH 6/8] make kwargs explicit --- torchvision/prototype/datasets/_builtin/coco.py | 4 ++-- torchvision/prototype/datasets/_builtin/imagenet.py | 10 ++++++++-- torchvision/prototype/datasets/_builtin/voc.py | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index fdcc033a902..024a09beabc 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -57,7 +57,7 @@ def __init__( split: str = "train", year: str = "2017", annotations: Optional[str] = "instances", - **kwargs: Any, + skip_integrity_check: bool = False, ) -> None: self._split = self._verify_str_arg(split, "split", {"train", "val"}) self._year = self._verify_str_arg(year, "year", {"2017", "2014"}) @@ -72,7 +72,7 @@ def __init__( self._categories = categories self._category_to_super_category = dict(zip(categories, super_categories)) - super().__init__(root, dependencies=("pycocotools",), **kwargs) + super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check) _IMAGE_URL_BASE = "http://images.cocodataset.org/zips" diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index a4cf98a0d0a..56accca02b4 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -58,7 +58,13 @@ class ImageNet(Dataset2): - **homepage**: https://www.image-net.org/ """ - def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train", **kwargs: Any) -> None: + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) info = _info() @@ -67,7 +73,7 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train", **kw self._wnids = wnids self._wnid_to_category = dict(zip(wnids, categories)) - super().__init__(root, **kwargs) + super().__init__(root, skip_integrity_check=skip_integrity_check) _IMAGES_CHECKSUMS = { "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb", diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index d000bdbe0e7..91b82794e27 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -50,7 +50,7 @@ def __init__( split: str = "train", year: str = "2012", task: str = "detection", - **kwargs: Any, + skip_integrity_check: bool = False, ) -> None: self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012")) if split == "test" and year != "2007": @@ -64,7 +64,7 @@ def __init__( self._categories = _info()["categories"] - super().__init__(root, **kwargs) + super().__init__(root, skip_integrity_check=skip_integrity_check) _TRAIN_VAL_ARCHIVES = { "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"), From 3bc839ed6af28f2b77a5539685e25bb797f9b16d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 12:24:41 +0200 Subject: [PATCH 7/8] add dependencies to docstring --- torchvision/prototype/datasets/_builtin/coco.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 024a09beabc..75896a8db08 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -48,6 +48,8 @@ def _info() -> Dict[str, Any]: class Coco(Dataset2): """ - **homepage**: https://cocodataset.org/ + - **dependencies**: + - _ """ def __init__( From 4ba09d6707d2b10deabe4f939d23a95a1d28186e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 14:09:20 +0200 Subject: [PATCH 8/8] fix missing dependency message --- torchvision/prototype/datasets/utils/_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 6b6fa30f6c6..a6ec05c3ff4 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -202,11 +202,11 @@ def __init__( for dependency in dependencies: try: importlib.import_module(dependency) - except ModuleNotFoundError as error: + except ModuleNotFoundError: raise ModuleNotFoundError( f"{type(self).__name__}() depends on the third-party package '{dependency}'. " f"Please install it, for example with `pip install {dependency}`." - ) from error + ) from None self._root = pathlib.Path(root).expanduser().resolve() resources = [