From e5a4b8d948c6e3af1e7e39a7efda23283e36bc46 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 09:22:42 +0200 Subject: [PATCH 1/6] migrate VOC prototype dataset --- test/builtin_dataset_mocks.py | 74 +++---- .../prototype/datasets/_builtin/voc.py | 191 ++++++++++-------- 2 files changed, 150 insertions(+), 115 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index c4f51463e34..83e0d70d29c 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -695,8 +695,22 @@ def semeion(info, root, config): return num_samples -class VOCMockData: - _TRAIN_VAL_FILE_NAMES = { +@register_mock( + configs=[ + *combinations_grid( + split=("train", "val", "trainval"), + year=("2007", "2008", "2009", "2010", "2011", "2012"), + task=("detection", "segmentation"), + ), + *combinations_grid( + split=("test",), + year=("2007",), + task=("detection", "segmentation"), + ), + ], +) +def voc(root, options): + TRAIN_VAL_FILE_NAMES = { "2007": "VOCtrainval_06-Nov-2007.tar", "2008": "VOCtrainval_14-Jul-2008.tar", "2009": "VOCtrainval_11-May-2009.tar", @@ -704,12 +718,11 @@ class VOCMockData: "2011": "VOCtrainval_25-May-2011.tar", "2012": "VOCtrainval_11-May-2012.tar", } - _TEST_FILE_NAMES = { + TEST_FILE_NAMES = { "2007": "VOCtest_06-Nov-2007.tar", } - @classmethod - def _make_split_files(cls, root, *, year, trainval): + def make_split_files(root, *, year, trainval): split_folder = root / "ImageSets" if trainval: @@ -733,16 +746,14 @@ def _make_split_files(cls, root, *, year, trainval): return sorted(set(itertools.chain(*ids_map.values()))), {split: len(ids) for split, ids in ids_map.items()} - @classmethod - def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples): + def make_detection_anns_folder(root, name, *, file_name_fn, num_examples): folder = root / name folder.mkdir(parents=True, exist_ok=True) for idx in range(num_examples): - cls._make_detection_ann_file(folder, file_name_fn(idx)) + make_detection_ann_file(folder, file_name_fn(idx)) - @classmethod - def _make_detection_ann_file(cls, root, name): + def make_detection_ann_file(root, name): def add_child(parent, name, text=None): child = ET.SubElement(parent, name) child.text = str(text) @@ -772,30 +783,25 @@ def add_bndbox(obj): with open(root / name, "wb") as fh: fh.write(ET.tostring(annotation)) - @classmethod - def generate(cls, root, *, year, trainval): - archive_folder = root - if year == "2011": - archive_folder /= "TrainVal" - data_folder = archive_folder / "VOCdevkit" / f"VOC{year}" - data_folder.mkdir(parents=True, exist_ok=True) - - ids, num_samples_map = cls._make_split_files(data_folder, year=year, trainval=trainval) - for make_folder_fn, name, suffix in [ - (create_image_folder, "JPEGImages", ".jpg"), - (create_image_folder, "SegmentationClass", ".png"), - (cls._make_detection_anns_folder, "Annotations", ".xml"), - ]: - make_folder_fn(data_folder, name, file_name_fn=lambda idx: ids[idx] + suffix, num_examples=len(ids)) - make_tar(root, (cls._TRAIN_VAL_FILE_NAMES if trainval else cls._TEST_FILE_NAMES)[year], data_folder) - - return num_samples_map - - -# @register_mock -def voc(info, root, config): - trainval = config.split != "test" - return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split] + year = options["year"] + trainval = options["split"] != "test" + + archive_folder = root + if year == "2011": + archive_folder /= "TrainVal" + data_folder = archive_folder / "VOCdevkit" / f"VOC{year}" + data_folder.mkdir(parents=True, exist_ok=True) + + ids, num_samples_map = make_split_files(data_folder, year=year, trainval=trainval) + for make_folder_fn, name, suffix in [ + (create_image_folder, "JPEGImages", ".jpg"), + (create_image_folder, "SegmentationClass", ".png"), + (make_detection_anns_folder, "Annotations", ".xml"), + ]: + make_folder_fn(data_folder, name, file_name_fn=lambda idx: ids[idx] + suffix, num_examples=len(ids)) + make_tar(root, (TRAIN_VAL_FILE_NAMES if trainval else TEST_FILE_NAMES)[year], data_folder) + + return num_samples_map[options["split"]] class CelebAMockData: diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 5c1d3f8c3a3..5ea680ff76f 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -1,6 +1,7 @@ +import enum import functools import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Union from xml.etree import ElementTree from torchdata.datapipes.iter import ( @@ -12,13 +13,7 @@ LineReader, ) from torchvision.datasets import VOCDetection -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) +from torchvision.prototype.datasets.utils import DatasetInfo, OnlineResource, HttpResource, Dataset2 from torchvision.prototype.datasets.utils._internal import ( path_accessor, getitem, @@ -26,34 +21,41 @@ path_comparator, hint_sharding, hint_shuffling, + BUILTIN_DIR, ) from torchvision.prototype.features import BoundingBox, Label, EncodedImage +from .._api import register_dataset, register_info -class VOCDatasetInfo(DatasetInfo): - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007") +NAME = "voc" - def make_config(self, **options: Any) -> DatasetConfig: - config = super().make_config(**options) - if config.split == "test" and config.year != "2007": - raise ValueError("`split='test'` is only available for `year='2007'`") +CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) - return config +@register_info(NAME) +def info() -> Dict[str, Any]: + return dict(categories=CATEGORIES) -class VOC(Dataset): - def _make_info(self) -> DatasetInfo: - return VOCDatasetInfo( - "voc", - homepage="http://host.robots.ox.ac.uk/pascal/VOC/", - valid_options=dict( - split=("train", "val", "trainval", "test"), - year=("2012", "2007", "2008", "2009", "2010", "2011"), - task=("detection", "segmentation"), - ), - ) + +@register_dataset(NAME) +class VOC(Dataset2): + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + year: str = "2012", + task: str = "detection", + **kwargs: Any, + ) -> None: + self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012")) + if split == "test" and year != "2007": + raise ValueError("`split='test'` is only available for `year='2007'`") + else: + self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test")) + self._task = self._verify_str_arg(task, "task", ("detection", "segmentation")) + + super().__init__(root, **kwargs) _TRAIN_VAL_ARCHIVES = { "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"), @@ -67,31 +69,34 @@ def _make_info(self) -> DatasetInfo: "2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892") } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - file_name, sha256 = (self._TEST_ARCHIVES if config.split == "test" else self._TRAIN_VAL_ARCHIVES)[config.year] - archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{config.year}/{file_name}", sha256=sha256) - return [archive] + def _resources(self) -> List[OnlineResource]: + file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year] + return [HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256)] + + @property + def _anns_folder(self) -> str: + return "Annotations" if self._task == "detection" else "SegmentationClass" - _ANNS_FOLDER = dict( - detection="Annotations", - segmentation="SegmentationClass", - ) - _SPLIT_FOLDER = dict( - detection="Main", - segmentation="Segmentation", - ) + @property + def _split_folder(self) -> str: + return "Main" if self._task == "detection" else "Segmentation" def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool: path = pathlib.Path(data[0]) return name in path.parent.parts[-depth:] - def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]: + class _Demux(enum.IntEnum): + SPLIT = 0 + IMAGES = 1 + ANNS = 2 + + def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: if self._is_in_folder(data, name="ImageSets", depth=2): - return 0 + return self._Demux.SPLIT elif self._is_in_folder(data, name="JPEGImages"): - return 1 - elif self._is_in_folder(data, name=self._ANNS_FOLDER[config.task]): - return 2 + return self._Demux.IMAGES + elif self._is_in_folder(data, name=self._anns_folder): + return self._Demux.ANNS else: return None @@ -110,9 +115,7 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: format="xyxy", image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), ), - labels=Label( - [self.categories.index(instance["name"]) for instance in instances], categories=self.categories - ), + labels=Label([CATEGORIES.index(instance["name"]) for instance in instances], categories=CATEGORIES), ) def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]: @@ -121,8 +124,6 @@ def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]: def _prepare_sample( self, data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]], - *, - prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]], ) -> Dict[str, Any]: split_and_image_data, ann_data = data _, image_data = split_and_image_data @@ -130,29 +131,23 @@ def _prepare_sample( ann_path, ann_buffer = ann_data return dict( - prepare_ann_fn(ann_buffer), + (self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer), image_path=image_path, image=EncodedImage.from_file(image_buffer), ann_path=ann_path, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: - archive_dp = resource_dps[0] + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: split_dp, images_dp, anns_dp = Demultiplexer( - archive_dp, + resource_dps[0], 3, - functools.partial(self._classify_archive, config=config), + self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE, ) - split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task])) - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder)) + split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) split_dp = LineReader(split_dp, decode=True) split_dp = hint_shuffling(split_dp) split_dp = hint_sharding(split_dp) @@ -166,25 +161,59 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper( - dp, - functools.partial( - self._prepare_sample, - prepare_ann_fn=self._prepare_detection_ann - if config.task == "detection" - else self._prepare_segmentation_ann, - ), - ) - - def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool: - return self._classify_archive(data, config=config) == 2 - - def _generate_categories(self, root: pathlib.Path) -> List[str]: - config = self.info.make_config(task="detection") - - resource = self.resources(config)[0] - dp = resource.load(pathlib.Path(root) / self.name) - dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config)) + return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + ("train", "2007", "detection"): 2_501, + ("train", "2007", "segmentation"): 209, + ("train", "2008", "detection"): 2_111, + ("train", "2008", "segmentation"): 511, + ("train", "2009", "detection"): 3_473, + ("train", "2009", "segmentation"): 749, + ("train", "2010", "detection"): 4_998, + ("train", "2010", "segmentation"): 964, + ("train", "2011", "detection"): 5_717, + ("train", "2011", "segmentation"): 1_112, + ("train", "2012", "detection"): 5_717, + ("train", "2012", "segmentation"): 1_464, + ("val", "2007", "detection"): 2_510, + ("val", "2007", "segmentation"): 213, + ("val", "2008", "detection"): 2_221, + ("val", "2008", "segmentation"): 512, + ("val", "2009", "detection"): 3_581, + ("val", "2009", "segmentation"): 750, + ("val", "2010", "detection"): 5_105, + ("val", "2010", "segmentation"): 964, + ("val", "2011", "detection"): 5_823, + ("val", "2011", "segmentation"): 1_111, + ("val", "2012", "detection"): 5_823, + ("val", "2012", "segmentation"): 1_449, + ("trainval", "2007", "detection"): 5_011, + ("trainval", "2007", "segmentation"): 422, + ("trainval", "2008", "detection"): 4_332, + ("trainval", "2008", "segmentation"): 1_023, + ("trainval", "2009", "detection"): 7_054, + ("trainval", "2009", "segmentation"): 1_499, + ("trainval", "2010", "detection"): 10_103, + ("trainval", "2010", "segmentation"): 1_928, + ("trainval", "2011", "detection"): 11_540, + ("trainval", "2011", "segmentation"): 2_223, + ("trainval", "2012", "detection"): 11_540, + ("trainval", "2012", "segmentation"): 2_913, + ("test", "2007", "detection"): 4_952, + ("test", "2007", "segmentation"): 210, + }[(self._split, self._year, self._task)] + + def _filter_anns(self, data: Tuple[str, Any]) -> bool: + return self._classify_archive(data) == self._Demux.ANNS + + def _generate_categories(self) -> List[str]: + self._task = "detection" + resources = self._resources() + + archive_dp = resources[0].load(self._root) + dp = Filter(archive_dp, self._filter_detection_anns) dp = Mapper(dp, self._parse_detection_ann, input_col=1) return sorted({instance["name"] for _, anns in dp for instance in anns["object"]}) From 2d6ecc47f3a22cb419828ba6e6fb2c879debcd5d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 15:31:19 +0200 Subject: [PATCH 2/6] cleanup --- torchvision/prototype/datasets/_builtin/voc.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 5ea680ff76f..d9684e65b18 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -33,7 +33,7 @@ @register_info(NAME) -def info() -> Dict[str, Any]: +def _info() -> Dict[str, Any]: return dict(categories=CATEGORIES) @@ -55,6 +55,8 @@ def __init__( self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test")) self._task = self._verify_str_arg(task, "task", ("detection", "segmentation")) + self._categories: List[str] = _info()["categories"] + super().__init__(root, **kwargs) _TRAIN_VAL_ARCHIVES = { @@ -71,7 +73,8 @@ def __init__( def _resources(self) -> List[OnlineResource]: file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year] - return [HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256)] + archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256) + return [archive] @property def _anns_folder(self) -> str: @@ -115,7 +118,9 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: format="xyxy", image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), ), - labels=Label([CATEGORIES.index(instance["name"]) for instance in instances], categories=CATEGORIES), + labels=Label( + [self._categories.index(instance["name"]) for instance in instances], categories=self._categories + ), ) def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]: @@ -138,8 +143,9 @@ def _prepare_sample( ) def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + archive_dp = resource_dps[0] split_dp, images_dp, anns_dp = Demultiplexer( - resource_dps[0], + archive_dp, 3, self._classify_archive, drop_none=True, From 489fd91c2e9bf0702b5a223f9b8ef449dc4fa1c3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 16:22:19 +0200 Subject: [PATCH 3/6] revert unrelated mock data changes --- test/builtin_dataset_mocks.py | 87 +++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 40 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 83e0d70d29c..ad979b6bd84 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -695,22 +695,8 @@ def semeion(info, root, config): return num_samples -@register_mock( - configs=[ - *combinations_grid( - split=("train", "val", "trainval"), - year=("2007", "2008", "2009", "2010", "2011", "2012"), - task=("detection", "segmentation"), - ), - *combinations_grid( - split=("test",), - year=("2007",), - task=("detection", "segmentation"), - ), - ], -) -def voc(root, options): - TRAIN_VAL_FILE_NAMES = { +class VOCMockData: + _TRAIN_VAL_FILE_NAMES = { "2007": "VOCtrainval_06-Nov-2007.tar", "2008": "VOCtrainval_14-Jul-2008.tar", "2009": "VOCtrainval_11-May-2009.tar", @@ -718,11 +704,12 @@ def voc(root, options): "2011": "VOCtrainval_25-May-2011.tar", "2012": "VOCtrainval_11-May-2012.tar", } - TEST_FILE_NAMES = { + _TEST_FILE_NAMES = { "2007": "VOCtest_06-Nov-2007.tar", } - def make_split_files(root, *, year, trainval): + @classmethod + def _make_split_files(cls, root, *, year, trainval): split_folder = root / "ImageSets" if trainval: @@ -746,14 +733,16 @@ def make_split_files(root, *, year, trainval): return sorted(set(itertools.chain(*ids_map.values()))), {split: len(ids) for split, ids in ids_map.items()} - def make_detection_anns_folder(root, name, *, file_name_fn, num_examples): + @classmethod + def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples): folder = root / name folder.mkdir(parents=True, exist_ok=True) for idx in range(num_examples): - make_detection_ann_file(folder, file_name_fn(idx)) + cls._make_detection_ann_file(folder, file_name_fn(idx)) - def make_detection_ann_file(root, name): + @classmethod + def _make_detection_ann_file(cls, root, name): def add_child(parent, name, text=None): child = ET.SubElement(parent, name) child.text = str(text) @@ -783,25 +772,43 @@ def add_bndbox(obj): with open(root / name, "wb") as fh: fh.write(ET.tostring(annotation)) - year = options["year"] - trainval = options["split"] != "test" - - archive_folder = root - if year == "2011": - archive_folder /= "TrainVal" - data_folder = archive_folder / "VOCdevkit" / f"VOC{year}" - data_folder.mkdir(parents=True, exist_ok=True) - - ids, num_samples_map = make_split_files(data_folder, year=year, trainval=trainval) - for make_folder_fn, name, suffix in [ - (create_image_folder, "JPEGImages", ".jpg"), - (create_image_folder, "SegmentationClass", ".png"), - (make_detection_anns_folder, "Annotations", ".xml"), - ]: - make_folder_fn(data_folder, name, file_name_fn=lambda idx: ids[idx] + suffix, num_examples=len(ids)) - make_tar(root, (TRAIN_VAL_FILE_NAMES if trainval else TEST_FILE_NAMES)[year], data_folder) - - return num_samples_map[options["split"]] + @classmethod + def generate(cls, root, *, year, trainval): + archive_folder = root + if year == "2011": + archive_folder /= "TrainVal" + data_folder = archive_folder / "VOCdevkit" / f"VOC{year}" + data_folder.mkdir(parents=True, exist_ok=True) + + ids, num_samples_map = cls._make_split_files(data_folder, year=year, trainval=trainval) + for make_folder_fn, name, suffix in [ + (create_image_folder, "JPEGImages", ".jpg"), + (create_image_folder, "SegmentationClass", ".png"), + (cls._make_detection_anns_folder, "Annotations", ".xml"), + ]: + make_folder_fn(data_folder, name, file_name_fn=lambda idx: ids[idx] + suffix, num_examples=len(ids)) + make_tar(root, (cls._TRAIN_VAL_FILE_NAMES if trainval else cls._TEST_FILE_NAMES)[year], data_folder) + + return num_samples_map + + +@register_mock( + configs=[ + *combinations_grid( + split=("train", "val", "trainval"), + year=("2007", "2008", "2009", "2010", "2011", "2012"), + task=("detection", "segmentation"), + ), + *combinations_grid( + split=("test",), + year=("2007",), + task=("detection", "segmentation"), + ), + ], +) +def voc(root, config): + trainval = config["split"] != "test" + return VOCMockData.generate(root, year=config["year"], trainval=trainval)[config["split"]] class CelebAMockData: From 39af08eaebe6eb7cd15266df45d65d52393045ab Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 16:26:05 +0200 Subject: [PATCH 4/6] remove categories annotations --- torchvision/prototype/datasets/_builtin/imagenet.py | 4 ++-- torchvision/prototype/datasets/_builtin/voc.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index fb507af01b0..638878d5ec3 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -59,8 +59,8 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> N info = _info() categories, wnids = info["categories"], info["wnids"] - self._categories: List[str] = categories - self._wnids: List[str] = wnids + self._categories = categories + self._wnids = wnids self._wnid_to_category = dict(zip(wnids, categories)) super().__init__(root) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index d9684e65b18..d8590a7062e 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -55,7 +55,7 @@ def __init__( self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test")) self._task = self._verify_str_arg(task, "task", ("detection", "segmentation")) - self._categories: List[str] = _info()["categories"] + self._categories = _info()["categories"] super().__init__(root, **kwargs) From be8729309ec1925921dfa74df0a336ef10a17ba5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 16:28:04 +0200 Subject: [PATCH 5/6] move properties to constructor --- torchvision/prototype/datasets/_builtin/voc.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index d8590a7062e..06e84c43cf8 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -55,6 +55,9 @@ def __init__( self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test")) self._task = self._verify_str_arg(task, "task", ("detection", "segmentation")) + self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass" + self._split_folder = "Main" if task == "detection" else "Segmentation" + self._categories = _info()["categories"] super().__init__(root, **kwargs) @@ -76,14 +79,6 @@ def _resources(self) -> List[OnlineResource]: archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256) return [archive] - @property - def _anns_folder(self) -> str: - return "Annotations" if self._task == "detection" else "SegmentationClass" - - @property - def _split_folder(self) -> str: - return "Main" if self._task == "detection" else "Segmentation" - def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool: path = pathlib.Path(data[0]) return name in path.parent.parts[-depth:] From 3902312f73953e00023efa4ff7b98453049682a4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 16:43:47 +0200 Subject: [PATCH 6/6] readd homepage --- torchvision/prototype/datasets/_builtin/voc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 06e84c43cf8..d000bdbe0e7 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -39,6 +39,10 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) class VOC(Dataset2): + """ + - **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/ + """ + def __init__( self, root: Union[str, pathlib.Path],