Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement imagenet prototype dataset as function #5565

Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pathlib
import pickle
import random
import unittest.mock
import xml.etree.ElementTree as ET
from collections import defaultdict, Counter

Expand All @@ -21,7 +20,6 @@
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor
from torchvision.prototype import datasets
from torchvision.prototype.utils._internal import sequence_to_str

make_tensor = functools.partial(_make_tensor, device="cpu")
make_scalar = functools.partial(make_tensor, ())
Expand Down Expand Up @@ -66,17 +64,17 @@ def prepare(self, home, config):

mock_info = self._parse_mock_info(self.mock_data_fn(root, config))

with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
required_file_names = {
resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
}
available_file_names = {path.name for path in root.glob("*")}
missing_file_names = required_file_names - available_file_names
if missing_file_names:
raise pytest.UsageError(
f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
f"for {config}, but they were not created by the mock data function."
)
# with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
# required_file_names = {
# resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
# }
# available_file_names = {path.name for path in root.glob("*")}
# missing_file_names = required_file_names - available_file_names
# if missing_file_names:
# raise pytest.UsageError(
# f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
# f"for {config}, but they were not created by the mock data function."
# )
Comment on lines +67 to +77
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the resources are now fully internal inside the dataset function, I don't see a good way to check if the mock data is set up correctly. One thing we could try is to patch

@abc.abstractmethod
def _download(self, root: pathlib.Path) -> None:
pass

on all subclasses to raise the error I commented out above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could load_images_dp be registered and mocked?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would require registering functions like that just for testing purposes. Not sure if we should go that way.

Copy link
Contributor

@ejguan ejguan Mar 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then, how about making info including another dictionary of Resources? And, load_images_dp could load resource from info.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be a possibility, yes. cc @NicolasHug


return mock_info

Expand Down
9 changes: 5 additions & 4 deletions test/test_prototype_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import Shuffler, ShardingFilter
from torchvision.prototype import transforms, datasets
from torchvision.prototype.datasets.utils._internal import TakerDataPipe
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@NivekT If I recall correctly, we wanted to upstream this to torchdata. Any progress on that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will have a look at the implementation sometime today

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can upstream Taker in torchdata. But, just a heads-up, we are aligning the API on the functionality of this DataPipe with the internal team. The name may be changed to Limiter with functional API as limit when alignment is settled down.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please tag me in the PR so I can make the changes here after it is landed.

from torchvision.prototype.utils._internal import sequence_to_str


Expand Down Expand Up @@ -51,8 +52,10 @@ def test_smoke(self, test_home, dataset_mock, config):

dataset = datasets.load(dataset_mock.name, **config)

if not isinstance(dataset, datasets.utils.Dataset2):
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
if not isinstance(dataset, TakerDataPipe):
raise AssertionError(
f"Loading the dataset should return an TakerDataPipe, but got {type(dataset)} instead."
)

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, test_home, dataset_mock, config):
Expand Down Expand Up @@ -100,7 +103,6 @@ def test_transformable(self, test_home, dataset_mock, config):

next(iter(dataset.map(transforms.Identity())))

@pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
Expand All @@ -109,7 +111,6 @@ def test_serializable(self, test_home, dataset_mock, config):

pickle.dumps(dataset)

@pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
@parametrize_dataset_mocks(DATASET_MOCKS)
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
Expand Down
17 changes: 9 additions & 8 deletions torchvision/prototype/datasets/_api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import pathlib
from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar
from typing import Any, Dict, List, Callable, Optional, Union, TypeVar

from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.utils import Dataset2
from torchvision.prototype.datasets.utils._internal import TakerDataPipe
from torchvision.prototype.utils._internal import add_suggestion


T = TypeVar("T")
D = TypeVar("D", bound=Type[Dataset2])

BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}

Expand All @@ -23,10 +22,12 @@ def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
BUILTIN_DATASETS = {}


def register_dataset(name: str) -> Callable[[D], D]:
def wrapper(dataset_cls: D) -> D:
BUILTIN_DATASETS[name] = dataset_cls
return dataset_cls
def register_dataset(
name: Optional[str] = None,
) -> Callable[[Callable[..., TakerDataPipe]], Callable[..., TakerDataPipe]]:
def wrapper(dataset_fn: Callable[..., TakerDataPipe]) -> Callable[..., TakerDataPipe]:
BUILTIN_DATASETS[name or dataset_fn.__name__] = dataset_fn
return dataset_fn

return wrapper

Expand Down Expand Up @@ -56,7 +57,7 @@ def info(name: str) -> Dict[str, Any]:
return find(BUILTIN_INFOS, name)


def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2:
def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> TakerDataPipe:
dataset_cls = find(BUILTIN_DATASETS, name)

if root is None:
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import imagenet
pmeier marked this conversation as resolved.
Show resolved Hide resolved
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
from .cifar import Cifar10, Cifar100
Expand All @@ -9,7 +10,6 @@
from .eurosat import EuroSAT
from .fer2013 import FER2013
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIITPet
from .pcam import PCAM
Expand Down
Loading