40 changes: 18 additions & 22 deletions test/test_prototype_builtin_datasets.py
@@ -34,21 +34,31 @@ def test_coverage():
)


# TODO: replace this with a simple call to datasets.load() as soon as all datasets are migrated and thus
# datasets.load2() can be merged into datasets.load()
def _load_dataset(name, **options):
try:
return datasets.load2(name, **options)
except ValueError:
return datasets.load(name, **options)


class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)

# TODO: check for Dataset2 after all datasets are migrated
if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)
dataset = _load_dataset(dataset_mock.name, **config)

try:
sample = next(iter(dataset))
@@ -65,32 +75,19 @@ def test_sample(self, test_home, dataset_mock, config):
def test_num_samples(self, test_home, dataset_mock, config):
mock_info = dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)
dataset = _load_dataset(dataset_mock.name, **config)

num_samples = 0
for _ in dataset:
num_samples += 1

assert num_samples == mock_info["num_samples"]

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_decoding(self, test_home, dataset_mock, config):

Review comment from the PR author: I've removed this test since it is obsolete after #5287.

dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)

undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
if undecoded_features:
raise AssertionError(
f"The values of key(s) "
f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
)

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)
dataset = _load_dataset(dataset_mock.name, **config)

vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
if vanilla_tensors:
@@ -103,7 +100,7 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
def test_transformable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)
dataset = _load_dataset(dataset_mock.name, **config)

next(iter(dataset.map(transforms.Identity())))

@@ -138,16 +135,15 @@ def scan(graph):
yield from scan(sub_graph)

dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)
dataset = _load_dataset(dataset_mock.name, **config)

if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_save_load(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
dataset = _load_dataset(dataset_mock.name, **config)
sample = next(iter(dataset))

with io.BytesIO() as buffer:
@@ -161,7 +157,7 @@ class TestQMNIST:
def test_extra_label(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)
dataset = _load_dataset(dataset_mock.name, **config)

sample = next(iter(dataset))
for key, type in (
@@ -186,7 +182,7 @@ def test_label_matches_path(self, test_home, dataset_mock, config):

dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)
dataset = _load_dataset(dataset_mock.name, **config)

for sample in dataset:
label_from_path = int(Path(sample["path"]).parent.name)
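The tests above all go through the `_load_dataset` fallback. A minimal sketch of its behavior, assuming "imagenet" has already been migrated to the new API and "coco" has not (the migration status of each name is illustrative only):

```python
dataset = _load_dataset("imagenet", split="val")  # datasets.load2() finds it in the new registry
dataset = _load_dataset("coco", split="train")    # load2() raises ValueError, so datasets.load() handles it
```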
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/__init__.py
@@ -11,5 +11,5 @@
from ._home import home

# Load this last, since some parts depend on the above being loaded first
from ._api import list_datasets, info, load # usort: skip
from ._api import list_datasets, info, load, register_info, register_dataset, load2 # usort: skip
from ._folder import from_data_folder, from_image_folder
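After this change the prototype namespace exposes the old and new entry points side by side. A quick sketch of the new public surface (nothing here touches disk):

```python
from torchvision.prototype.datasets import (
    load,              # old API, still backed by the DATASETS registry
    load2,             # new API; to be folded into load() once all datasets are migrated
    register_dataset,  # class decorator, accepts Dataset2 subclasses only
    register_info,     # caches the returned info dict at registration time
)
```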
63 changes: 55 additions & 8 deletions torchvision/prototype/datasets/_api.py
@@ -1,12 +1,12 @@
import os
from typing import Any, Dict, List
import pathlib
from typing import Any, Dict, List, Callable, Type, Optional, Union

from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, Dataset2
from torchvision.prototype.utils._internal import add_suggestion

from . import _builtin

DATASETS: Dict[str, Dataset] = {}

@@ -15,11 +15,6 @@ def register(dataset: Dataset) -> None:
DATASETS[dataset.name] = dataset


for name, obj in _builtin.__dict__.items():
if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset:
register(obj())


def list_datasets() -> List[str]:
return sorted(DATASETS.keys())

@@ -57,3 +52,55 @@ def load(
root = os.path.join(home(), dataset.name)

return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)


BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}


def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]:
def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
BUILTIN_INFOS[name] = fn()
return fn

return wrapper


def info2(name: str) -> Dict[str, Any]:
try:
return BUILTIN_INFOS[name]
except KeyError:
        raise ValueError(f"Unknown dataset '{name}'.")


BUILTIN_DATASETS = {}


def register_dataset(name: str) -> Callable[[Type], Type]:
def wrapper(dataset_cls: Type) -> Type:
if not issubclass(dataset_cls, Dataset2):
            raise TypeError(f"{dataset_cls.__name__} is not a subclass of Dataset2, so it cannot be registered.")

BUILTIN_DATASETS[name] = dataset_cls

return dataset_cls

return wrapper


def load2(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **options: Any) -> Dataset2:
try:
dataset_cls = BUILTIN_DATASETS[name]
except KeyError:
        raise ValueError(f"Unknown dataset '{name}'.")

if root is None:
root = pathlib.Path(home()) / name

return dataset_cls(root, **options)


from . import _builtin

for name, obj in _builtin.__dict__.items():
if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset:
register(obj())
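Taken together, the two registries wire up as follows. A minimal sketch with an illustrative "fake" dataset that needs no downloads (the name, categories, and samples are made up for demonstration):

```python
from typing import Any, Dict, List

from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from torchvision.prototype.datasets import load2, register_dataset, register_info
from torchvision.prototype.datasets.utils import Dataset2, OnlineResource


@register_info("fake")
def _fake_info() -> Dict[str, Any]:
    # evaluated once at registration time and cached in BUILTIN_INFOS
    return dict(categories=["cat", "dog"])


@register_dataset("fake")  # raises TypeError for anything that is not a Dataset2 subclass
class Fake(Dataset2):
    def _resources(self) -> List[OnlineResource]:
        return []  # nothing to download or verify

    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
        return IterableWrapper([dict(label=0), dict(label=1)])

    def __len__(self) -> int:
        return 2


dataset = load2("fake", root="fake")  # without root, this defaults to home() / "fake"
assert len(dataset) == 2
assert next(iter(dataset)) == dict(label=0)
```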
50 changes: 49 additions & 1 deletion torchvision/prototype/datasets/_builtin/imagenet.py
@@ -1,7 +1,8 @@
import functools
import pathlib
import re
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union

from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer
from torchdata.datapipes.iter import TarArchiveReader
@@ -11,6 +12,7 @@
DatasetInfo,
OnlineResource,
ManualDownloadResource,
Dataset2,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
@@ -25,6 +27,8 @@
from torchvision.prototype.features import Label, EncodedImage
from torchvision.prototype.utils._internal import FrozenMapping

from .._api import register_dataset, register_info


class ImageNetResource(ManualDownloadResource):
def __init__(self, **kwargs: Any) -> None:
@@ -201,3 +205,47 @@ def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]:
categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])
return categories_and_wnids


NAME = "imagenet"


@register_info(NAME)
def _info() -> Dict[str, Any]:
categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
return dict(categories=categories, wnids=wnids)


@register_dataset(NAME)
class ImageNet2(Dataset2):
def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:
if split not in {"train", "val", "test"}:
            raise ValueError(f"Unknown split '{split}'. Valid splits are 'train', 'val', and 'test'.")
self._split = split

info = _info()
categories, wnids = info["categories"], info["wnids"]

self._old_style_dataset = ImageNet()
self._old_style_config = self._old_style_dataset.info.make_config(split=self._split)

self.categories = categories
self.info = SimpleNamespace(
            # materialize the mappings; a bare zip() can only be iterated once and supports no lookups
            wnid_to_category=dict(zip(wnids, categories)),
            category_to_wnid=dict(zip(categories, wnids)),
)

super().__init__(root)

def _resources(self) -> List[OnlineResource]:
return self._old_style_dataset.resources(self._old_style_config)

def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
return self._old_style_dataset._make_datapipe(resource_dps, config=self._old_style_config)

def __len__(self) -> int:
return {
"train": 1_281_167,
"val": 50_000,
"test": 100_000,
}[self._split]
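With the class above registered, loading goes through `load2`. A hypothetical call, assuming the ImageNet archives were already placed under `root` (ImageNetResource is a ManualDownloadResource, so nothing is fetched automatically):

```python
from torchvision.prototype.datasets import load2

dataset = load2("imagenet", root="~/datasets/imagenet", split="val")  # root is expanded and resolved
assert len(dataset) == 50_000
sample = next(iter(dataset))  # a dict produced by the old-style datapipe, with (at least) an image and a label
```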
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/utils/__init__.py
@@ -1,4 +1,4 @@
from . import _internal # usort: skip
from ._dataset import DatasetConfig, DatasetInfo, Dataset
from ._dataset import DatasetConfig, DatasetInfo, Dataset, Dataset2
from ._query import SampleQuery
from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource
26 changes: 25 additions & 1 deletion torchvision/prototype/datasets/utils/_dataset.py
@@ -4,7 +4,7 @@
import itertools
import os
import pathlib
from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection
from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection, Iterator

from torch.utils.data import IterDataPipe
from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str
@@ -181,3 +181,27 @@ def load(

def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]:
raise NotImplementedError


class Dataset2(IterDataPipe[Dict[str, Any]], abc.ABC):
def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None:
self._root = pathlib.Path(root).expanduser().resolve()
resources = [
resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources()
]
self._dp = self._datapipe(resources)

def __iter__(self) -> Iterator[Dict[str, Any]]:
yield from self._dp

@abc.abstractmethod
def _resources(self) -> List[OnlineResource]:
pass

@abc.abstractmethod
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
pass

@abc.abstractmethod
def __len__(self) -> int:
pass
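Since Dataset2 is itself an IterDataPipe, construction eagerly resolves the root, loads every resource, and builds the datapipe once; iteration then just delegates, so instances compose directly with datapipe adapters and DataLoader. A short sketch, reusing the toy Fake dataset from the registration example above:

```python
from torch.utils.data import DataLoader

dataset = Fake("fake")  # builds its datapipe eagerly in __init__
shifted = dataset.map(lambda sample: dict(label=sample["label"] + 1))
loader = DataLoader(shifted, batch_size=2, collate_fn=lambda batch: batch)
assert next(iter(loader)) == [dict(label=1), dict(label=2)]
```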