merge DatasetInfo into Dataset #5369

Closed · wants to merge 7 commits
1 change: 1 addition & 0 deletions .circleci/config.yml


1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -4,3 +4,4 @@ sphinx-copybutton>=0.3.1
sphinx-gallery>=0.9.0
sphinx==3.5.4
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
tabulate
91 changes: 90 additions & 1 deletion docs/source/conf.py
@@ -20,9 +20,13 @@
# import sys
# sys.path.insert(0, os.path.abspath('.'))

import pathlib
import textwrap

import pytorch_sphinx_theme
import torchvision

from tabulate import tabulate
from torchvision.prototype import datasets

# -- General configuration ------------------------------------------------

@@ -287,3 +291,88 @@ def inject_minigalleries(app, what, name, obj, options, lines):

def setup(app):
app.connect("autodoc-process-docstring", inject_minigalleries)


def generate_prototype_datasets_doc():
root = pathlib.Path(__file__).parent / "datasets_prototype"
# TODO: should be datasets.list()
names = ["imagenet", "voc"]
for name in names:
info = datasets.find(name)
with open(root / f"{name}.rst", "w") as file:
file.write(f"{name}\n{'=' * len(name)}\n\n")

if info.description:
file.write(f"{info.description}\n\n")

properties = []

if info.homepage:
properties.append(("Homepage", info.homepage))

if info.dependencies:
properties.append(
(
"Dependencies",
", ".join(
f"`{dependency} <https://pypi.org/project/{dependency}/>`_"
for dependency in info.dependencies
),
)
)

if properties:
file.write("General\n")
file.write("-------\n\n")

for key, value in properties:
file.write(f"- **{key}**: {value}\n")

file.write("\n")

if info.options:
file.write("Options\n")
file.write("-------\n\n")

table = tabulate(
[(option.name, f":class:`{option.annotation.__name__}`", option.doc) for option in info.options],
tablefmt="rst",
)
file.write(".. table::\n")
file.write(" :widths: 15 15 70\n\n")
file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")

if info.attributes:
file.write("Attributes\n")
file.write("----------\n\n")

table = tabulate([(attr, doc) for attr, doc in info.attributes.items()], tablefmt="rst")
file.write(".. table::\n")
file.write(" :widths: 20 80\n\n")
file.write(f"{textwrap.indent(table, ' ' * 4)}\n\n")

# TODO: this should probably be generated by jinja
with open(root / "index.rst", "w") as file:
file.write(
"\n".join(
[
"Prototype Datasets",
"==================",
"",
"API",
"---",
"",
"",
"Builtin datasets",
"----------------",
"",
".. toctree::",
" :maxdepth: 1",
"",
*[f" {name}" for name in names],
]
)
)


generate_prototype_datasets_doc()
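
For a concrete sense of the Options tables this generator writes, here is a minimal sketch using tabulate's rst format; the row mirrors the VOC split option defined later in this PR, but the exact contents are illustrative:

from tabulate import tabulate

# Roughly what generate_prototype_datasets_doc() emits for an Options table
# (column widths are computed by tabulate; the row below is illustrative).
rows = [("split", ":class:`str`", "'train', 'val', 'trainval', or 'test'")]
print(tabulate(rows, tablefmt="rst"))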
1 change: 1 addition & 0 deletions docs/source/datasets_prototype/.gitignore
@@ -0,0 +1 @@
*.rst
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -34,6 +34,7 @@ architectures, and common image transformations for computer vision.
transforms
models
datasets
datasets_prototype/index
utils
ops
io
22 changes: 17 additions & 5 deletions test/builtin_dataset_mocks.py
@@ -1,4 +1,5 @@
import collections.abc
import contextlib
import csv
import functools
import gzip
@@ -31,12 +32,20 @@

class DatasetMock:
def __init__(self, name, mock_data_fn):
if name not in ("imagenet", "voc"):
self.name = name
return

self.dataset = find(name)
self.info = self.dataset.info
self.name = self.info.name
self.name = self.dataset.name

self.mock_data_fn = mock_data_fn
self.configs = self.info._configs

self.configs = []
for combination in itertools.product(*[option.valid for option in self.dataset.options]):
options = dict(zip([option.name for option in self.dataset.options], combination))
with contextlib.suppress(Exception):
self.configs.append(self.dataset.make_config(**options))
Comment from the author on lines +44 to +48:

This is the only part of the test suite changes that needs reviewing. All other changes are just hacks to temporarily test only VOC and ImageNet.
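
For reference, the same enumeration written as a standalone helper; this is a sketch that assumes the `options`, `valid`, and `make_config` members introduced on `Dataset` in this PR:

import contextlib
import itertools

def enumerate_valid_configs(dataset):
    # Try every combination of option values and keep only the ones that
    # make_config() accepts; invalid combinations (e.g. VOC's split="test"
    # with a year other than "2007") raise and are silently skipped.
    names = [option.name for option in dataset.options]
    configs = []
    for combination in itertools.product(*[option.valid for option in dataset.options]):
        with contextlib.suppress(Exception):
            configs.append(dataset.make_config(**dict(zip(names, combination))))
    return configs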


def _parse_mock_info(self, mock_info):
if mock_info is None:
@@ -65,7 +74,8 @@ def prepare(self, home, config):
root = home / self.name
root.mkdir(exist_ok=True)

mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config))
# TODO: rename the `info` parameter of every mock_data_fn to `dataset`
mock_info = self._parse_mock_info(self.mock_data_fn(self.dataset, root, config))

available_file_names = {path.name for path in root.glob("*")}
required_file_names = {resource.file_name for resource in self.dataset.resources(config)}
@@ -110,6 +120,8 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
elif not isinstance(marks, collections.abc.Mapping):
raise pytest.UsageError()

dataset_mocks = {name: dataset_mocks[name] for name in ("imagenet", "voc") if name in dataset_mocks}

return pytest.mark.parametrize(
("dataset_mock", "config"),
[
@@ -432,7 +444,7 @@ def caltech256(info, root, config):

@register_mock
def imagenet(info, root, config):
wnids = tuple(info.extra.wnid_to_category.keys())
wnids = tuple(info.wnid_to_category.keys())
if config.split == "train":
images_root = root / "ILSVRC2012_img_train"

4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_api.py
@@ -66,9 +66,9 @@ def load(
dataset = find(name)

if decoder is DEFAULT_DECODER:
decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)
decoder = DEFAULT_DECODER_MAP.get(dataset.type)

config = dataset.info.make_config(**options)
config = dataset.make_config(**options)
root = os.path.join(home(), dataset.name)

return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
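
A usage sketch of the updated load() path; this assumes keyword arguments are forwarded as options, which the make_config(**options) call above suggests:

from torchvision.prototype import datasets

# find() resolves the dataset by name, make_config() validates the options,
# and load() returns the datapipe rooted at <datasets home>/voc.
dp = datasets.load("voc", split="train", year="2007", task="detection")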
31 changes: 16 additions & 15 deletions torchvision/prototype/datasets/_builtin/__init__.py
@@ -1,17 +1,18 @@
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .clevr import CLEVR
from .coco import Coco
from .cub200 import CUB200
from .dtd import DTD
from .fer2013 import FER2013
from .gtsrb import GTSRB
# from .caltech import Caltech101, Caltech256
# from .celeba import CelebA
# from .cifar import Cifar10, Cifar100
# from .clevr import CLEVR
# from .coco import Coco
# from .cub200 import CUB200
# from .dtd import DTD
# from .fer2013 import FER2013
# from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIITPet
from .pcam import PCAM
from .sbd import SBD
from .semeion import SEMEION
from .svhn import SVHN

# from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
# from .oxford_iiit_pet import OxfordIITPet
# from .pcam import PCAM
# from .sbd import SBD
# from .semeion import SEMEION
# from .svhn import SVHN
from .voc import VOC
61 changes: 36 additions & 25 deletions torchvision/prototype/datasets/_builtin/imagenet.py
@@ -9,7 +9,7 @@
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
DatasetOption,
OnlineResource,
ManualDownloadResource,
DatasetType,
@@ -25,7 +25,6 @@
hint_shuffling,
)
from torchvision.prototype.features import Label
from torchvision.prototype.utils._internal import FrozenMapping


class ImageNetResource(ManualDownloadResource):
@@ -34,41 +33,53 @@ def __init__(self, **kwargs: Any) -> None:


class ImageNet(Dataset):
def _make_info(self) -> DatasetInfo:
def __init__(self):
name = "imagenet"
categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories"))

return DatasetInfo(
categories, wnids = zip(*self.read_categories_file(BUILTIN_DIR / f"{name}.categories"))
super().__init__(
name,
DatasetOption("split", ("train", "val", "test")),
type=DatasetType.IMAGE,
description="""
The ImageNet dataset contains 14,197,122 annotated images according to the WordNet hierarchy. Since 2010
the dataset is used in the ImageNet Large Scale Visual Recognition Challenge (ILSVRC), a benchmark in image
classification and object detection. The publicly released dataset contains a set of manually annotated
training images. A set of test images is also released, with the manual annotations withheld. ILSVRC
annotations fall into one of two categories: (1) image-level annotation of a binary label for the presence
or absence of an object class in the image, e.g., "there are cars in this image" but "there are no tigers,"
and (2) object-level annotation of a tight bounding box and class label around an object instance in the
image, e.g., "there is a screwdriver centered at position (20,25) with width of 50 pixels and height of
30 pixels". The ImageNet project does not own the copyright of the images, therefore only thumbnails and
URLs of images are provided.

- Total number of non-empty WordNet synsets: 21841
- Total number of images: 14197122
- Number of images with bounding box annotations: 1,034,908
- Number of synsets with SIFT features: 1000
- Number of images with SIFT features: 1.2 million
""",
dependencies=("scipy",),
categories=categories,
homepage="https://www.image-net.org/",
valid_options=dict(split=("train", "val", "test")),
extra=dict(
wnid_to_category=FrozenMapping(zip(wnids, categories)),
category_to_wnid=FrozenMapping(zip(categories, wnids)),
sizes=FrozenMapping(
[
(DatasetConfig(split="train"), 1_281_167),
(DatasetConfig(split="val"), 50_000),
(DatasetConfig(split="test"), 100_000),
]
),
attributes=dict(
wnid_to_category="Mapping for WordNet IDs to human readable categories.",
category_to_wnid="Mapping for human readable categories to WordNet IDs.",
),
)
# TODO: handle num_samples
# sizes = FrozenMapping(
# [
# (DatasetConfig(split="train"), 1_281_167),
# (DatasetConfig(split="val"), 50_000),
# (DatasetConfig(split="test"), 100_000),
# ]
# ),
self.wnid_to_category = dict(zip(wnids, categories))
self.category_to_wnid = dict(zip(categories, wnids))

def supports_sharded(self) -> bool:
return True

@property
def category_to_wnid(self) -> Dict[str, str]:
return cast(Dict[str, str], self.info.extra.category_to_wnid)

@property
def wnid_to_category(self) -> Dict[str, str]:
return cast(Dict[str, str], self.info.extra.wnid_to_category)
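
With the FrozenMapping/extra indirection gone, the two mappings are plain dict attributes on the dataset object; a usage sketch (datasets.find() as used in conf.py above; the WordNet ID is the standard one for "tench"):

from torchvision.prototype import datasets

imagenet = datasets.find("imagenet")
imagenet.wnid_to_category["n01440764"]  # "tench", assuming the stock categories file
imagenet.category_to_wnid["tench"]      # "n01440764"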

_IMAGES_CHECKSUMS = {
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
"val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
44 changes: 25 additions & 19 deletions torchvision/prototype/datasets/_builtin/voc.py
@@ -17,10 +17,10 @@
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
DatasetType,
DatasetOption,
)
from torchvision.prototype.datasets.utils._internal import (
path_accessor,
@@ -33,10 +33,29 @@
from torchvision.prototype.features import BoundingBox


class VOCDatasetInfo(DatasetInfo):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007")
class VOC(Dataset):
def __init__(self):
super().__init__(
"voc",
DatasetOption(
"split",
("train", "val", "trainval", "test"),
doc="{options} ``'test'`` is only available for ``year='2007'``.",
),
DatasetOption("year", ("2007", "2008", "2009", "2010", "2011", "2012"), default="2012"),
DatasetOption("task", valid=("detection", "segmentation")),
type=DatasetType.IMAGE,
description="""
The PASCAL Visual Object Classes (VOC) 2012 dataset contains 20 object categories including vehicles,
household, animals, and other: aeroplane, bicycle, boat, bus, car, motorbike, train, bottle, chair, dining
table, potted plant, sofa, TV/monitor, bird, cat, cow, dog, horse, sheep, and person. Each image in this
dataset has pixel-level segmentation annotations, bounding box annotations, and object class annotations.
This dataset has been widely used as a benchmark for object detection, semantic segmentation, and
classification tasks. The PASCAL VOC dataset is split into three subsets: 1,464 images for training, 1,449
images for validation and a private testing set.
""",
homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
)

def make_config(self, **options: Any) -> DatasetConfig:
config = super().make_config(**options)
@@ -45,20 +64,6 @@ def make_config(self, **options: Any) -> DatasetConfig:

return config
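
The collapsed hunk above presumably carries over the split/year validation that VOCDatasetInfo did in its constructor (removed below); a hypothetical sketch of that check, matching the option doc string above:

def make_config(self, **options: Any) -> DatasetConfig:
    config = super().make_config(**options)
    # The test split was only released for the 2007 challenge.
    if config.split == "test" and config.year != "2007":
        raise ValueError("`split='test'` is only available for `year='2007'`")
    return config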


class VOC(Dataset):
def _make_info(self) -> DatasetInfo:
return VOCDatasetInfo(
"voc",
type=DatasetType.IMAGE,
homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
valid_options=dict(
split=("train", "val", "trainval", "test"),
year=("2012", "2007", "2008", "2009", "2010", "2011"),
task=("detection", "segmentation"),
),
)

_TRAIN_VAL_ARCHIVES = {
"2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
"2008": ("VOCtrainval_14-Jul-2008.tar", "7f0ca53c1b5a838fbe946965fc106c6e86832183240af5c88e3f6c306318d42e"),
@@ -158,4 +163,5 @@ def _make_datapipe(
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)

return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))