migrate coco prototype (#5473)
* migrate coco prototype

* revert unrelated change

* add kwargs to super constructor call

* remove unneeded changes

* fix docstring position

* make kwargs explicit

* add dependencies to docstring

* fix missing dependency message
pmeier authored Apr 6, 2022
1 parent 2ed549d commit 42bc682
Showing 6 changed files with 118 additions and 72 deletions.
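
In the new API, a dataset is a plain class registered under a name and configured through constructor keyword arguments instead of a DatasetConfig. A minimal usage sketch, assuming the prototype registry entry point torchvision.prototype.datasets.load that earlier migrations introduced:

from torchvision.prototype import datasets

# Sketch against the prototype API, not the stable torchvision.datasets namespace.
dataset = datasets.load("coco", split="val", year="2017", annotations="instances")

for sample in dataset:
    # each sample is a plain dict assembled by Coco._prepare_sample
    print(sorted(sample.keys()))
    break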
12 changes: 9 additions & 3 deletions test/builtin_dataset_mocks.py
@@ -600,9 +600,15 @@ def generate(
return num_samples


# @register_mock
def coco(info, root, config):
return CocoMockData.generate(root, year=config.year, num_samples=5)
@register_mock(
configs=combinations_grid(
split=("train", "val"),
year=("2017", "2014"),
annotations=("instances", "captions", None),
)
)
def coco(root, config):
return CocoMockData.generate(root, year=config["year"], num_samples=5)
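
combinations_grid expands the given option values into one mock configuration per combination. A rough, self-contained reimplementation, assuming the helper simply takes the Cartesian product of its keyword arguments:

import itertools
from typing import Any, Dict, List, Sequence

def combinations_grid(**options: Sequence[Any]) -> List[Dict[str, Any]]:
    # combinations_grid(split=("train", "val"), year=("2017",))
    # -> [{"split": "train", "year": "2017"}, {"split": "val", "year": "2017"}]
    return [dict(zip(options, values)) for values in itertools.product(*options.values())]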


class SBDMockData:
145 changes: 82 additions & 63 deletions torchvision/prototype/datasets/_builtin/coco.py
@@ -1,8 +1,8 @@
import functools
import pathlib
import re
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union

import torch
from torchdata.datapipes.iter import (
@@ -16,11 +16,10 @@
UnBatcher,
)
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
Dataset2,
)
from torchvision.prototype.datasets.utils._internal import (
MappingIterator,
@@ -32,27 +31,51 @@
hint_shuffling,
)
from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage
from torchvision.prototype.utils._internal import FrozenMapping


class Coco(Dataset):
def _make_info(self) -> DatasetInfo:
name = "coco"
categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))

return DatasetInfo(
name,
dependencies=("pycocotools",),
categories=categories,
homepage="https://cocodataset.org/",
valid_options=dict(
split=("train", "val"),
year=("2017", "2014"),
annotations=(*self._ANN_DECODERS.keys(), None),
),
extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))),

from .._api import register_dataset, register_info


NAME = "coco"


@register_info(NAME)
def _info() -> Dict[str, Any]:
categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
return dict(categories=categories, super_categories=super_categories)
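
register_info and register_dataset record the info callable and the dataset class under a shared name. A simplified sketch of the pattern, assuming dict-backed registries like those in _api.py:

from typing import Callable, Dict

_DATASET_REGISTRY: Dict[str, type] = {}

def register_dataset(name: str) -> Callable[[type], type]:
    # register_info works the same way, but stores the info callable instead.
    def wrapper(cls: type) -> type:
        _DATASET_REGISTRY[name] = cls
        return cls
    return wrapper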


@register_dataset(NAME)
class Coco(Dataset2):
"""
- **homepage**: https://cocodataset.org/
- **dependencies**:
- `pycocotools <https://github.com/cocodataset/cocoapi>`_
"""

def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
year: str = "2017",
annotations: Optional[str] = "instances",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val"})
self._year = self._verify_str_arg(year, "year", {"2017", "2014"})
self._annotations = (
self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys())
if annotations is not None
else None
)

info = _info()
categories, super_categories = info["categories"], info["super_categories"]
self._categories = categories
self._category_to_super_category = dict(zip(categories, super_categories))

super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check)
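
With valid_options gone, validation happens eagerly in the constructor; hypothetical misuses under the checks above would fail before any download is attempted:

# Hypothetical values, rejected by _verify_str_arg at construction time:
# Coco("~/datasets/coco", year="2016")         -> ValueError (year not in {"2017", "2014"})
# Coco("~/datasets/coco", annotations="boxes") -> ValueError (not a key of _ANN_DECODERS)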

_IMAGE_URL_BASE = "http://images.cocodataset.org/zips"

_IMAGES_CHECKSUMS = {
@@ -69,14 +92,14 @@ def _make_info(self) -> DatasetInfo:
"2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268",
}

def resources(self, config: DatasetConfig) -> List[OnlineResource]:
def _resources(self) -> List[OnlineResource]:
images = HttpResource(
f"{self._IMAGE_URL_BASE}/{config.split}{config.year}.zip",
sha256=self._IMAGES_CHECKSUMS[(config.year, config.split)],
f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip",
sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)],
)
meta = HttpResource(
f"{self._META_URL_BASE}/annotations_trainval{config.year}.zip",
sha256=self._META_CHECKSUMS[config.year],
f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip",
sha256=self._META_CHECKSUMS[self._year],
)
return [images, meta]
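
For a given instance the two archives resolve deterministically from self._split and self._year; for example:

# Sketch: the image archive URL for split="val", year="2017".
_IMAGE_URL_BASE = "http://images.cocodataset.org/zips"
split, year = "val", "2017"
assert f"{_IMAGE_URL_BASE}/{split}{year}.zip" == "http://images.cocodataset.org/zips/val2017.zip"
# The annotations come from f"{_META_URL_BASE}/annotations_trainval{year}.zip";
# _META_URL_BASE is defined outside the lines shown in this hunk.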

@@ -110,10 +133,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st
format="xywh",
image_size=image_size,
),
labels=Label(labels, categories=self.categories),
super_categories=[
self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels
],
labels=Label(labels, categories=self._categories),
super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
ann_ids=[ann["id"] for ann in anns],
)

@@ -134,9 +155,14 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str,
fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
)

def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool:
def _filter_meta_files(self, data: Tuple[str, Any]) -> bool:
match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations)
return bool(
match
and match["split"] == self._split
and match["year"] == self._year
and match["annotations"] == self._annotations
)
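
The filter leans on _META_FILE_PATTERN (built just above from the _ANN_DECODERS keys) to pull annotations, split, and year out of the JSON file name. A self-contained check of that pattern, with the two decoder names written out:

import re

pattern = re.compile(r"(?P<annotations>(instances|captions))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json")
match = pattern.match("instances_train2017.json")
assert match is not None
assert (match["annotations"], match["split"], match["year"]) == ("instances", "train", "2017")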

def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
key, _ = data
@@ -157,38 +183,26 @@ def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
def _prepare_sample(
self,
data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
*,
annotations: str,
) -> Dict[str, Any]:
ann_data, image_data = data
anns, image_meta = ann_data

sample = self._prepare_image(image_data)
# this method is only called if we have annotations
annotations = cast(str, self._annotations)
sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
return sample

def _make_datapipe(
self,
resource_dps: List[IterDataPipe],
*,
config: DatasetConfig,
) -> IterDataPipe[Dict[str, Any]]:
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
images_dp, meta_dp = resource_dps

if config.annotations is None:
if self._annotations is None:
dp = hint_shuffling(images_dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._prepare_image)

meta_dp = Filter(
meta_dp,
functools.partial(
self._filter_meta_files,
split=config.split,
year=config.year,
annotations=config.annotations,
),
)
meta_dp = Filter(meta_dp, self._filter_meta_files)
meta_dp = JsonParser(meta_dp)
meta_dp = Mapper(meta_dp, getitem(1))
meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
@@ -216,26 +230,31 @@ def _make_datapipe(
ref_key_fn=getitem("id"),
buffer_size=INFINITE_BUFFER_SIZE,
)

dp = IterKeyZipper(
anns_dp,
images_dp,
key_fn=getitem(1, "file_name"),
ref_key_fn=path_accessor("name"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._prepare_sample)
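
The second IterKeyZipper pairs each annotation group with its image file: getitem(1, "file_name") reads the file name from the image meta traveling with the annotations, while path_accessor("name") extracts it from the image datapipe's path. A toy illustration with made-up values:

import pathlib

# Made-up items mirroring the shapes of the two zipper inputs above:
ann_item = (["<ann dicts>"], {"id": 397133, "file_name": "000000397133.jpg"})
image_item = ("/data/coco/val2017/000000397133.jpg", "<file handle>")

assert ann_item[1]["file_name"] == pathlib.Path(image_item[0]).name  # keys match -> zipped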

def __len__(self) -> int:
return {
("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
("train", "2014"): defaultdict(lambda: 82_783, instances=82_081),
("val", "2017"): defaultdict(lambda: 5_000, instances=4_952),
("val", "2014"): defaultdict(lambda: 40_504, instances=40_137),
}[(self._split, self._year)][
self._annotations # type: ignore[index]
]
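
The defaultdict values make the split's full image count the default, with a smaller count for instances, presumably because images without instance annotations are dropped by the annotation/image join. Spelled out for val/2017:

# Values from the table above:
# Coco(root, split="val", year="2017", annotations=None)        -> len() == 5_000
# Coco(root, split="val", year="2017", annotations="captions")  -> len() == 5_000
# Coco(root, split="val", year="2017", annotations="instances") -> len() == 4_952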

return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations))

def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
config = self.default_config
resources = self.resources(config)
def _generate_categories(self) -> Tuple[Tuple[str, str]]:
self._annotations = "instances"
resources = self._resources()

dp = resources[1].load(root)
dp = Filter(
dp,
functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
)
dp = resources[1].load(self._root)
dp = Filter(dp, self._filter_meta_files)
dp = JsonParser(dp)

_, meta = next(iter(dp))
14 changes: 12 additions & 2 deletions torchvision/prototype/datasets/_builtin/imagenet.py
@@ -54,7 +54,17 @@ class ImageNetDemux(enum.IntEnum):

@register_dataset(NAME)
class ImageNet(Dataset2):
def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:
"""
- **homepage**: https://www.image-net.org/
"""

def __init__(
self,
root: Union[str, pathlib.Path],
*,
split: str = "train",
skip_integrity_check: bool = False,
) -> None:
self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})

info = _info()
@@ -63,7 +73,7 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> N
self._wnids = wnids
self._wnid_to_category = dict(zip(wnids, categories))

super().__init__(root)
super().__init__(root, skip_integrity_check=skip_integrity_check)
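
Same plumbing as in Coco: the flag is forwarded to the base class, which passes it on to resource.load. A hypothetical call that trusts existing files on disk and skips checksumming the large archives:

dataset = ImageNet("~/datasets/imagenet", split="val", skip_integrity_check=True)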

_IMAGES_CHECKSUMS = {
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/voc.py
@@ -50,7 +50,7 @@ def __init__(
split: str = "train",
year: str = "2012",
task: str = "detection",
**kwargs: Any,
skip_integrity_check: bool = False,
) -> None:
self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
if split == "test" and year != "2007":
@@ -64,7 +73,7 @@

self._categories = _info()["categories"]

super().__init__(root, **kwargs)
super().__init__(root, skip_integrity_check=skip_integrity_check)

_TRAIN_VAL_ARCHIVES = {
"2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/generate_category_files.py
@@ -51,7 +51,7 @@ def parse_args(argv=None):


if __name__ == "__main__":
args = parse_args(["-f", "imagenet"])
args = parse_args()

try:
main(*args.names, force=args.force)
13 changes: 12 additions & 1 deletion torchvision/prototype/datasets/utils/_dataset.py
@@ -196,7 +196,18 @@ def _verify_str_arg(
) -> str:
return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)

def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
def __init__(
self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = ()
) -> None:
for dependency in dependencies:
try:
importlib.import_module(dependency)
except ModuleNotFoundError:
raise ModuleNotFoundError(
f"{type(self).__name__}() depends on the third-party package '{dependency}'. "
f"Please install it, for example with `pip install {dependency}`."
) from None

self._root = pathlib.Path(root).expanduser().resolve()
resources = [
resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
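Together with dependencies=("pycocotools",) passed from Coco.__init__, a missing third-party package now fails fast at construction time. A hypothetical session without pycocotools installed:

#   >>> Coco("~/datasets/coco")
#   ModuleNotFoundError: Coco() depends on the third-party package 'pycocotools'.
#   Please install it, for example with `pip install pycocotools`.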
