migrate VOC prototype dataset #5743

Merged: 7 commits, Apr 5, 2022

21 changes: 17 additions & 4 deletions test/builtin_dataset_mocks.py
@@ -792,10 +792,23 @@ def generate(cls, root, *, year, trainval):
         return num_samples_map


-# @register_mock
-def voc(info, root, config):
-    trainval = config.split != "test"
-    return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split]
+@register_mock(
+    configs=[
+        *combinations_grid(
+            split=("train", "val", "trainval"),
+            year=("2007", "2008", "2009", "2010", "2011", "2012"),
+            task=("detection", "segmentation"),
+        ),
+        *combinations_grid(
+            split=("test",),
+            year=("2007",),
+            task=("detection", "segmentation"),
+        ),
+    ],
+)
+def voc(root, config):
+    trainval = config["split"] != "test"
+    return VOCMockData.generate(root, year=config["year"], trainval=trainval)[config["split"]]


 class CelebAMockData:
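
For reference, `combinations_grid` expands keyword option sequences into one config dict per combination. A minimal sketch of an equivalent helper (the real one is defined elsewhere in the repo), assuming plain dict configs:

import itertools

def combinations_grid(**params):
    # One dict per element of the Cartesian product of the option sequences.
    return [dict(zip(params, values)) for values in itertools.product(*params.values())]

combinations_grid(split=("test",), year=("2007",), task=("detection", "segmentation"))
# -> [{'split': 'test', 'year': '2007', 'task': 'detection'},
#     {'split': 'test', 'year': '2007', 'task': 'segmentation'}]

This mirrors how the second grid in the hunk above contributes exactly the two 2007 test configs.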
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/imagenet.py
@@ -59,8 +59,8 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:

         info = _info()
         categories, wnids = info["categories"], info["wnids"]
-        self._categories: List[str] = categories
-        self._wnids: List[str] = wnids
+        self._categories = categories
+        self._wnids = wnids
         self._wnid_to_category = dict(zip(wnids, categories))

         super().__init__(root)
186 changes: 110 additions & 76 deletions torchvision/prototype/datasets/_builtin/voc.py
@@ -1,6 +1,7 @@
+import enum
 import functools
 import pathlib
-from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable
+from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Union
 from xml.etree import ElementTree

 from torchdata.datapipes.iter import (
@@ -12,48 +13,58 @@
     LineReader,
 )
 from torchvision.datasets import VOCDetection
-from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
-    DatasetInfo,
-    HttpResource,
-    OnlineResource,
-)
+from torchvision.prototype.datasets.utils import DatasetInfo, OnlineResource, HttpResource, Dataset2
 from torchvision.prototype.datasets.utils._internal import (
     path_accessor,
     getitem,
     INFINITE_BUFFER_SIZE,
     path_comparator,
     hint_sharding,
     hint_shuffling,
+    BUILTIN_DIR,
 )
 from torchvision.prototype.features import BoundingBox, Label, EncodedImage

+from .._api import register_dataset, register_info
+
+NAME = "voc"
+
+CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    return dict(categories=CATEGORIES)


-class VOCDatasetInfo(DatasetInfo):
-    def __init__(self, *args: Any, **kwargs: Any):
-        super().__init__(*args, **kwargs)
-        self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007")
+@register_dataset(NAME)
+class VOC(Dataset2):
+    """
+    - **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/
+    """

-    def make_config(self, **options: Any) -> DatasetConfig:
-        config = super().make_config(**options)
-        if config.split == "test" and config.year != "2007":
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        year: str = "2012",
+        task: str = "detection",
+        **kwargs: Any,
+    ) -> None:
+        self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
+        if split == "test" and year != "2007":
             raise ValueError("`split='test'` is only available for `year='2007'`")
+        else:
+            self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
+        self._task = self._verify_str_arg(task, "task", ("detection", "segmentation"))

-        return config
+        self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass"
+        self._split_folder = "Main" if task == "detection" else "Segmentation"
+
+        self._categories = _info()["categories"]

-
-class VOC(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        return VOCDatasetInfo(
-            "voc",
-            homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
-            valid_options=dict(
-                split=("train", "val", "trainval", "test"),
-                year=("2012", "2007", "2008", "2009", "2010", "2011"),
-                task=("detection", "segmentation"),
-            ),
-        )
+        super().__init__(root, **kwargs)

     _TRAIN_VAL_ARCHIVES = {
         "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
@@ -67,31 +78,27 @@ def _make_info(self) -> DatasetInfo:
"2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
}

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
file_name, sha256 = (self._TEST_ARCHIVES if config.split == "test" else self._TRAIN_VAL_ARCHIVES)[config.year]
archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{config.year}/{file_name}", sha256=sha256)
def _resources(self) -> List[OnlineResource]:
file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year]
archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256)
return [archive]

_ANNS_FOLDER = dict(
detection="Annotations",
segmentation="SegmentationClass",
)
_SPLIT_FOLDER = dict(
detection="Main",
segmentation="Segmentation",
)

def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool:
path = pathlib.Path(data[0])
return name in path.parent.parts[-depth:]

def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]:
class _Demux(enum.IntEnum):
SPLIT = 0
IMAGES = 1
ANNS = 2

def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
if self._is_in_folder(data, name="ImageSets", depth=2):
return 0
return self._Demux.SPLIT
elif self._is_in_folder(data, name="JPEGImages"):
return 1
elif self._is_in_folder(data, name=self._ANNS_FOLDER[config.task]):
return 2
return self._Demux.IMAGES
elif self._is_in_folder(data, name=self._anns_folder):
return self._Demux.ANNS
else:
return None

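The `_Demux` enum above names the outputs of torchdata's `Demultiplexer`, which routes each archive entry to one of several datapipes based on what the classifier returns (`None` drops the entry). A standalone sketch of the pattern with made-up paths, using the functional `demux` form:

from torchdata.datapipes.iter import IterableWrapper

paths = IterableWrapper(
    ["VOC2012/ImageSets/Main/train.txt", "VOC2012/JPEGImages/0001.jpg", "VOC2012/Annotations/0001.xml"]
)

def classify(path):
    # Return the index of the target datapipe, or None to drop the element.
    if "ImageSets" in path:
        return 0  # SPLIT
    if "JPEGImages" in path:
        return 1  # IMAGES
    if "Annotations" in path:
        return 2  # ANNS
    return None

# buffer_size=-1 corresponds to the INFINITE_BUFFER_SIZE used in the dataset.
split_dp, images_dp, anns_dp = paths.demux(3, classify, drop_none=True, buffer_size=-1)
print(list(split_dp))  # ['VOC2012/ImageSets/Main/train.txt']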
@@ -111,7 +118,7 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
                 image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))),
             ),
             labels=Label(
-                [self.categories.index(instance["name"]) for instance in instances], categories=self.categories
+                [self._categories.index(instance["name"]) for instance in instances], categories=self._categories
             ),
         )

@@ -121,38 +128,31 @@ def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
     def _prepare_sample(
         self,
         data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]],
-        *,
-        prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]],
     ) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
         _, image_data = split_and_image_data
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data

         return dict(
-            prepare_ann_fn(ann_buffer),
+            (self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer),
             image_path=image_path,
             image=EncodedImage.from_file(image_buffer),
             ann_path=ann_path,
         )

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         split_dp, images_dp, anns_dp = Demultiplexer(
             archive_dp,
             3,
-            functools.partial(self._classify_archive, config=config),
+            self._classify_archive,
             drop_none=True,
             buffer_size=INFINITE_BUFFER_SIZE,
         )

-        split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task]))
-        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
+        split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder))
+        split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_shuffling(split_dp)
         split_dp = hint_sharding(split_dp)
@@ -166,25 +166,59 @@ def _make_datapipe(
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(
-            dp,
-            functools.partial(
-                self._prepare_sample,
-                prepare_ann_fn=self._prepare_detection_ann
-                if config.task == "detection"
-                else self._prepare_segmentation_ann,
-            ),
-        )
-
-    def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool:
-        return self._classify_archive(data, config=config) == 2
-
-    def _generate_categories(self, root: pathlib.Path) -> List[str]:
-        config = self.info.make_config(task="detection")
-
-        resource = self.resources(config)[0]
-        dp = resource.load(pathlib.Path(root) / self.name)
-        dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config))
+        return Mapper(dp, self._prepare_sample)

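For context, the `dp` mapped here is produced by the `IterKeyZipper` calls in the collapsed part of this hunk, which join each name listed in the split file with the image and annotation files sharing that stem. A self-contained sketch of such a key-based join, with hypothetical data:

from torchdata.datapipes.iter import IterableWrapper, IterKeyZipper

names = IterableWrapper(["0001", "0002"])  # e.g. the lines of train.txt
files = IterableWrapper([("JPEGImages/0002.jpg", b"..."), ("JPEGImages/0001.jpg", b"...")])

def stem(data):
    # Path stem of a (path, payload) tuple, mimicking path_accessor("stem").
    path, _ = data
    return path.rsplit("/", 1)[-1].rsplit(".", 1)[0]

dp = IterKeyZipper(names, files, key_fn=lambda name: name, ref_key_fn=stem, buffer_size=-1)
print(list(dp))  # [('0001', ('JPEGImages/0001.jpg', b'...')), ('0002', ('JPEGImages/0002.jpg', b'...'))]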
+    def __len__(self) -> int:
+        return {
+            ("train", "2007", "detection"): 2_501,
+            ("train", "2007", "segmentation"): 209,
+            ("train", "2008", "detection"): 2_111,
+            ("train", "2008", "segmentation"): 511,
+            ("train", "2009", "detection"): 3_473,
+            ("train", "2009", "segmentation"): 749,
+            ("train", "2010", "detection"): 4_998,
+            ("train", "2010", "segmentation"): 964,
+            ("train", "2011", "detection"): 5_717,
+            ("train", "2011", "segmentation"): 1_112,
+            ("train", "2012", "detection"): 5_717,
+            ("train", "2012", "segmentation"): 1_464,
+            ("val", "2007", "detection"): 2_510,
+            ("val", "2007", "segmentation"): 213,
+            ("val", "2008", "detection"): 2_221,
+            ("val", "2008", "segmentation"): 512,
+            ("val", "2009", "detection"): 3_581,
+            ("val", "2009", "segmentation"): 750,
+            ("val", "2010", "detection"): 5_105,
+            ("val", "2010", "segmentation"): 964,
+            ("val", "2011", "detection"): 5_823,
+            ("val", "2011", "segmentation"): 1_111,
+            ("val", "2012", "detection"): 5_823,
+            ("val", "2012", "segmentation"): 1_449,
+            ("trainval", "2007", "detection"): 5_011,
+            ("trainval", "2007", "segmentation"): 422,
+            ("trainval", "2008", "detection"): 4_332,
+            ("trainval", "2008", "segmentation"): 1_023,
+            ("trainval", "2009", "detection"): 7_054,
+            ("trainval", "2009", "segmentation"): 1_499,
+            ("trainval", "2010", "detection"): 10_103,
+            ("trainval", "2010", "segmentation"): 1_928,
+            ("trainval", "2011", "detection"): 11_540,
+            ("trainval", "2011", "segmentation"): 2_223,
+            ("trainval", "2012", "detection"): 11_540,
+            ("trainval", "2012", "segmentation"): 2_913,
+            ("test", "2007", "detection"): 4_952,
+            ("test", "2007", "segmentation"): 210,
+        }[(self._split, self._year, self._task)]

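Assuming the registry wiring above exposes the dataset through the prototype `datasets.load` entry point, the hard-coded lengths can be checked roughly like this (hypothetical usage; root handling elided):

from torchvision.prototype import datasets

dataset = datasets.load("voc", split="train", year="2012", task="detection")
assert len(dataset) == 5_717  # value from the lookup table above
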
+    def _filter_anns(self, data: Tuple[str, Any]) -> bool:
+        return self._classify_archive(data) == self._Demux.ANNS
+
+    def _generate_categories(self) -> List[str]:
+        self._task = "detection"
+        resources = self._resources()
+
+        archive_dp = resources[0].load(self._root)
+        dp = Filter(archive_dp, self._filter_anns)
         dp = Mapper(dp, self._parse_detection_ann, input_col=1)

         return sorted({instance["name"] for _, anns in dp for instance in anns["object"]})
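
For intuition, `_parse_detection_ann` (backed by `VOCDetection.parse_voc_xml`) produces a nested dict of strings; the category scan at the end consumes roughly this shape (illustrative values, not real dataset contents):

anns = {
    "size": {"width": "500", "height": "375", "depth": "3"},
    "object": [
        {"name": "dog", "bndbox": {"xmin": "48", "ymin": "240", "xmax": "195", "ymax": "371"}},
        {"name": "person", "bndbox": {"xmin": "8", "ymin": "12", "xmax": "352", "ymax": "498"}},
    ],
}
print(sorted({instance["name"] for instance in anns["object"]}))  # ['dog', 'person']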