migrate coco prototype #5473

Merged
12 changes: 9 additions & 3 deletions test/builtin_dataset_mocks.py
@@ -600,9 +600,15 @@ def generate(
        return num_samples


-# @register_mock
-def coco(info, root, config):
-    return CocoMockData.generate(root, year=config.year, num_samples=5)
+@register_mock(
+    configs=combinations_grid(
+        split=("train", "val"),
+        year=("2017", "2014"),
+        annotations=("instances", "captions", None),
+    )
+)
+def coco(root, config):
+    return CocoMockData.generate(root, year=config["year"], num_samples=5)


class SBDMockData:
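Note: combinations_grid expands its keyword arguments into the Cartesian product of the option values, so the mock above is registered for 2 x 2 x 3 = 12 configurations. A minimal sketch of that behavior (the helper's implementation is not shown in this diff, so treat this as illustrative):

import itertools
from typing import Any, Dict, List

def combinations_grid(**kwargs: Any) -> List[Dict[str, Any]]:
    # One dict per combination of option values, e.g.
    # combinations_grid(split=("train", "val"), year=("2017",))
    # -> [{"split": "train", "year": "2017"}, {"split": "val", "year": "2017"}]
    return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

configs = combinations_grid(
    split=("train", "val"),
    year=("2017", "2014"),
    annotations=("instances", "captions", None),
)
assert len(configs) == 12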
145 changes: 82 additions & 63 deletions torchvision/prototype/datasets/_builtin/coco.py
@@ -1,8 +1,8 @@
import functools
import pathlib
import re
-from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union

import torch
from torchdata.datapipes.iter import (
@@ -16,11 +16,10 @@
    UnBatcher,
)
from torchvision.prototype.datasets.utils import (
-    Dataset,
-    DatasetConfig,
    DatasetInfo,
    HttpResource,
    OnlineResource,
+    Dataset2,
)
from torchvision.prototype.datasets.utils._internal import (
    MappingIterator,
@@ -32,27 +31,51 @@
    hint_shuffling,
)
from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage
-from torchvision.prototype.utils._internal import FrozenMapping


-class Coco(Dataset):
-    def _make_info(self) -> DatasetInfo:
-        name = "coco"
-        categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
-
-        return DatasetInfo(
-            name,
-            dependencies=("pycocotools",),
-            categories=categories,
-            homepage="https://cocodataset.org/",
-            valid_options=dict(
-                split=("train", "val"),
-                year=("2017", "2014"),
-                annotations=(*self._ANN_DECODERS.keys(), None),
-            ),
-            extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))),
-        )
+from .._api import register_dataset, register_info
+
+
+NAME = "coco"
+
+
+@register_info(NAME)
+def _info() -> Dict[str, Any]:
+    categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+    return dict(categories=categories, super_categories=super_categories)
+
+
+@register_dataset(NAME)
+class Coco(Dataset2):
+    """
+    - **homepage**: https://cocodataset.org/
+    - **dependencies**:
+        - `pycocotools <https://github.com/cocodataset/cocoapi>`_
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        year: str = "2017",
+        annotations: Optional[str] = "instances",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", {"train", "val"})
+        self._year = self._verify_str_arg(year, "year", {"2017", "2014"})
+        self._annotations = (
+            self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys())
+            if annotations is not None
+            else None
+        )
+
+        info = _info()
+        categories, super_categories = info["categories"], info["super_categories"]
+        self._categories = categories
+        self._category_to_super_category = dict(zip(categories, super_categories))
+
+        super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check)

    _IMAGE_URL_BASE = "http://images.cocodataset.org/zips"

    _IMAGES_CHECKSUMS = {
@@ -69,14 +92,14 @@ def _make_info(self) -> DatasetInfo:
        "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268",
    }

-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
        images = HttpResource(
-            f"{self._IMAGE_URL_BASE}/{config.split}{config.year}.zip",
-            sha256=self._IMAGES_CHECKSUMS[(config.year, config.split)],
+            f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip",
+            sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)],
        )
        meta = HttpResource(
-            f"{self._META_URL_BASE}/annotations_trainval{config.year}.zip",
-            sha256=self._META_CHECKSUMS[config.year],
+            f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip",
+            sha256=self._META_CHECKSUMS[self._year],
        )
        return [images, meta]

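Note: with the defaults split="train" and year="2017", the two f-strings above resolve as follows (assuming _META_URL_BASE, which is collapsed out of this diff, points at the standard http://images.cocodataset.org/annotations endpoint):

split, year = "train", "2017"
image_archive = f"http://images.cocodataset.org/zips/{split}{year}.zip"
# -> http://images.cocodataset.org/zips/train2017.zip
meta_archive = f"http://images.cocodataset.org/annotations/annotations_trainval{year}.zip"
# -> http://images.cocodataset.org/annotations/annotations_trainval2017.zip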
@@ -110,10 +133,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
                format="xywh",
                image_size=image_size,
            ),
-            labels=Label(labels, categories=self.categories),
-            super_categories=[
-                self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels
-            ],
+            labels=Label(labels, categories=self._categories),
+            super_categories=[self._category_to_super_category[self._categories[label]] for label in labels],
            ann_ids=[ann["id"] for ann in anns],
        )

@@ -134,9 +155,14 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]:
        fr"(?P<annotations>({'|'.join(_ANN_DECODERS.keys())}))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
    )

-    def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool:
+    def _filter_meta_files(self, data: Tuple[str, Any]) -> bool:
        match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name)
-        return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations)
+        return bool(
+            match
+            and match["split"] == self._split
+            and match["year"] == self._year
+            and match["annotations"] == self._annotations
+        )

    def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]:
        key, _ = data
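
Note: _filter_meta_files now reads self._split, self._year, and self._annotations from the instance instead of receiving them through functools.partial. For reference, this is what _META_FILE_PATTERN matches once the decoder names are substituted in (self-contained sketch; the real pattern is built from _ANN_DECODERS.keys()):

import re

META_FILE_PATTERN = re.compile(
    r"(?P<annotations>(instances|captions))_(?P<split>[a-zA-Z]+)(?P<year>\d+)[.]json"
)

match = META_FILE_PATTERN.match("instances_train2017.json")
assert match is not None
assert (match["annotations"], match["split"], match["year"]) == ("instances", "train", "2017")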
@@ -157,38 +183,26 @@ def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
    def _prepare_sample(
        self,
        data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]],
-        *,
-        annotations: str,
    ) -> Dict[str, Any]:
        ann_data, image_data = data
        anns, image_meta = ann_data

        sample = self._prepare_image(image_data)
+        # this method is only called if we have annotations
+        annotations = cast(str, self._annotations)
        sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta))
        return sample

-    def _make_datapipe(
-        self,
-        resource_dps: List[IterDataPipe],
-        *,
-        config: DatasetConfig,
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        images_dp, meta_dp = resource_dps

-        if config.annotations is None:
+        if self._annotations is None:
            dp = hint_shuffling(images_dp)
            dp = hint_sharding(dp)
+            dp = hint_shuffling(dp)
            return Mapper(dp, self._prepare_image)

-        meta_dp = Filter(
-            meta_dp,
-            functools.partial(
-                self._filter_meta_files,
-                split=config.split,
-                year=config.year,
-                annotations=config.annotations,
-            ),
-        )
+        meta_dp = Filter(meta_dp, self._filter_meta_files)
        meta_dp = JsonParser(meta_dp)
        meta_dp = Mapper(meta_dp, getitem(1))
        meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp)
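
Note: because the configuration lives on the instance, the datapipe can filter with the bound method self._filter_meta_files directly instead of rebuilding a functools.partial per config. A toy, standalone illustration of the same Filter usage (a lambda stands in for the bound method):

from torchdata.datapipes.iter import Filter, IterableWrapper

names = IterableWrapper(["instances_train2017.json", "captions_val2014.json"])
dp = Filter(names, lambda name: "train2017" in name)
assert list(dp) == ["instances_train2017.json"]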
@@ -216,26 +230,31 @@ def _make_datapipe(
            ref_key_fn=getitem("id"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
-
        dp = IterKeyZipper(
            anns_dp,
            images_dp,
            key_fn=getitem(1, "file_name"),
            ref_key_fn=path_accessor("name"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
+        return Mapper(dp, self._prepare_sample)
+
+    def __len__(self) -> int:
+        return {
+            ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
+            ("train", "2014"): defaultdict(lambda: 82_783, instances=82_081),
+            ("val", "2017"): defaultdict(lambda: 5_000, instances=4_952),
+            ("val", "2014"): defaultdict(lambda: 40_504, instances=40_137),
+        }[(self._split, self._year)][
+            self._annotations  # type: ignore[index]
+        ]

-        return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations))
-
-    def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
-        config = self.default_config
-        resources = self.resources(config)
+    def _generate_categories(self) -> Tuple[Tuple[str, str]]:
+        self._annotations = "instances"
+        resources = self._resources()

-        dp = resources[1].load(root)
-        dp = Filter(
-            dp,
-            functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
-        )
+        dp = resources[1].load(self._root)
+        dp = Filter(dp, self._filter_meta_files)
        dp = JsonParser(dp)

        _, meta = next(iter(dp))
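Note: __len__ hard-codes the known sample counts per split/year: the defaultdict default covers annotations="captions" and annotations=None (every image in the split), while the explicit instances entry accounts for the slightly smaller number of images that carry instance annotations. Putting the migrated API together (hypothetical usage sketch; the root path is a placeholder):

from torchvision.prototype.datasets._builtin.coco import Coco

dataset = Coco("path/to/coco", split="train", year="2017", annotations="instances")
assert len(dataset) == 117_266

dataset = Coco("path/to/coco", split="train", year="2017", annotations=None)
assert len(dataset) == 118_287

Coco("path/to/coco", split="test")  # raises ValueError via _verify_str_arg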
14 changes: 12 additions & 2 deletions torchvision/prototype/datasets/_builtin/imagenet.py
@@ -54,7 +54,17 @@ class ImageNetDemux(enum.IntEnum):

@register_dataset(NAME)
class ImageNet(Dataset2):
-    def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:
+    """
+    - **homepage**: https://www.image-net.org/
+    """

[Inline review comment from pmeier (author): "Small fix to ImageNet that I found while working on COCO."]

+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
        self._split = self._verify_str_arg(split, "split", {"train", "val", "test"})

        info = _info()
@@ -63,7 +73,7 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:
        self._wnids = wnids
        self._wnid_to_category = dict(zip(wnids, categories))

-        super().__init__(root)
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _IMAGES_CHECKSUMS = {
        "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/voc.py
@@ -50,7 +50,7 @@ def __init__(
        split: str = "train",
        year: str = "2012",
        task: str = "detection",
-        **kwargs: Any,
+        skip_integrity_check: bool = False,
    ) -> None:
        self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012"))
        if split == "test" and year != "2007":
@@ -64,7 +64,7 @@

        self._categories = _info()["categories"]

-        super().__init__(root, **kwargs)
+        super().__init__(root, skip_integrity_check=skip_integrity_check)

    _TRAIN_VAL_ARCHIVES = {
        "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/generate_category_files.py
@@ -51,7 +51,7 @@ def parse_args(argv=None):


if __name__ == "__main__":
-    args = parse_args(["-f", "imagenet"])
+    args = parse_args()

    try:
        main(*args.names, force=args.force)
13 changes: 12 additions & 1 deletion torchvision/prototype/datasets/utils/_dataset.py
@@ -196,7 +196,18 @@ def _verify_str_arg(
    ) -> str:
        return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg)

-    def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
+    def __init__(
+        self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = ()
+    ) -> None:
+        for dependency in dependencies:
+            try:
+                importlib.import_module(dependency)
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    f"{type(self).__name__}() depends on the third-party package '{dependency}'. "
+                    f"Please install it, for example with `pip install {dependency}`."
+                ) from None
+
        self._root = pathlib.Path(root).expanduser().resolve()
        resources = [
            resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
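Note: the new dependencies hook in Dataset2.__init__ turns a missing third-party package into an immediate failure at construction time. Since Coco now passes dependencies=("pycocotools",), the expected behavior is (sketch; error text taken from the diff above):

try:
    dataset = Coco("path/to/coco")
except ModuleNotFoundError as error:
    # Coco() depends on the third-party package 'pycocotools'.
    # Please install it, for example with `pip install pycocotools`.
    print(error)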