Cleanup prototype dataset implementation #5774

Merged: 10 commits, Apr 7, 2022
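In short, this PR renames the prototype base class `Dataset2` to `Dataset`, removes the `DatasetInfo` / `FrozenMapping` / `FrozenBunch` machinery in favor of plain dicts, and replaces the `DatasetInfo.read_categories_file` static method with a free `read_categories_file` helper. Each builtin dataset now registers a plain info dict and a `Dataset` subclass via the `@register_info` and `@register_dataset` decorators, as the per-file diffs below show. A minimal sketch of the pattern the builtins converge on after this change (the dataset name, URL, and `_resources` body are illustrative assumptions, not taken from the diff):

from typing import Any, Dict, List

from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file
from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR

from .._api import register_dataset, register_info

NAME = "example"  # hypothetical dataset, for illustration only

# Categories come from a plain .categories file via the new free helper,
# replacing the removed DatasetInfo.read_categories_file static method.
CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))


@register_info(NAME)
def _info() -> Dict[str, Any]:
    # A plain dict replaces the removed DatasetInfo / FrozenBunch objects.
    return dict(categories=CATEGORIES)


@register_dataset(NAME)
class Example(Dataset):  # formerly a Dataset2 subclass
    def _resources(self) -> List[OnlineResource]:
        # Assumed body: builtin datasets declare their downloadable files here.
        return [HttpResource("https://example.com/example.tar.gz")]
    # Further Dataset hooks (datapipe construction etc.) are omitted from this sketch.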
2 changes: 1 addition & 1 deletion test/builtin_dataset_mocks.py
@@ -68,7 +68,7 @@ def prepare(self, home, config):

        mock_info = self._parse_mock_info(self.mock_data_fn(root, config))

-        with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
+        with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"):
            required_file_names = {
                resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
            }
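The one-line change above keeps the mock-preparation trick working after the rename: `Dataset.__init__` is patched out so `datasets.load(...)` can build an instance just to enumerate `_resources()` without doing any real work. A toy illustration of that technique, independent of torchvision (all names below are made up):

import unittest.mock


class Downloader:
    def __init__(self) -> None:
        raise RuntimeError("would hit the network")

    def resources(self) -> list:
        return ["images.tar.gz", "annotations.json"]


# Patching __init__ with a no-op lets us construct the object anyway and
# call the one method we care about.
with unittest.mock.patch.object(Downloader, "__init__", return_value=None):
    assert Downloader().resources() == ["images.tar.gz", "annotations.json"]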
2 changes: 1 addition & 1 deletion test/test_prototype_builtin_datasets.py
@@ -59,7 +59,7 @@ def test_smoke(self, test_home, dataset_mock, config):

        dataset = datasets.load(dataset_mock.name, **config)

-        if not isinstance(dataset, datasets.utils.Dataset2):
+        if not isinstance(dataset, datasets.utils.Dataset):
            raise AssertionError(f"Loading the dataset should return a Dataset, but got {type(dataset)} instead.")

    @parametrize_dataset_mocks(DATASET_MOCKS)
156 changes: 2 additions & 154 deletions test/test_prototype_datasets_api.py
@@ -2,163 +2,11 @@

import pytest
from torchvision.prototype import datasets
-from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch


def make_minimal_dataset_info(name="name", categories=None, **kwargs):
-    return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)
-
-
-class TestFrozenMapping:
-    @pytest.mark.parametrize(
-        ("args", "kwargs"),
-        [
-            pytest.param((dict(foo="bar", baz=1),), dict(), id="from_dict"),
-            pytest.param((), dict(foo="bar", baz=1), id="from_kwargs"),
-            pytest.param((dict(foo="bar"),), dict(baz=1), id="mixed"),
-        ],
-    )
-    def test_instantiation(self, args, kwargs):
-        FrozenMapping(*args, **kwargs)
-
-    def test_unhashable_items(self):
-        with pytest.raises(TypeError, match="unhashable type"):
-            FrozenMapping(foo=[])
-
-    def test_getitem(self):
-        options = dict(foo="bar", baz=1)
-        config = FrozenMapping(options)
-
-        for key, value in options.items():
-            assert config[key] == value
-
-    def test_getitem_unknown(self):
-        with pytest.raises(KeyError):
-            FrozenMapping()["unknown"]
-
-    def test_iter(self):
-        options = dict(foo="bar", baz=1)
-        assert set(iter(FrozenMapping(options))) == set(options.keys())
-
-    def test_len(self):
-        options = dict(foo="bar", baz=1)
-        assert len(FrozenMapping(options)) == len(options)
-
-    def test_immutable_setitem(self):
-        frozen_mapping = FrozenMapping()
-
-        with pytest.raises(RuntimeError, match="immutable"):
-            frozen_mapping["foo"] = "bar"
-
-    def test_immutable_delitem(
-        self,
-    ):
-        frozen_mapping = FrozenMapping(foo="bar")
-
-        with pytest.raises(RuntimeError, match="immutable"):
-            del frozen_mapping["foo"]
-
-    def test_eq(self):
-        options = dict(foo="bar", baz=1)
-        assert FrozenMapping(options) == FrozenMapping(options)
-
-    def test_ne(self):
-        options1 = dict(foo="bar", baz=1)
-        options2 = options1.copy()
-        options2["baz"] += 1
-
-        assert FrozenMapping(options1) != FrozenMapping(options2)
-
-    def test_repr(self):
-        options = dict(foo="bar", baz=1)
-        output = repr(FrozenMapping(options))
-
-        assert isinstance(output, str)
-        for key, value in options.items():
-            assert str(key) in output and str(value) in output
-
-
-class TestFrozenBunch:
-    def test_getattr(self):
-        options = dict(foo="bar", baz=1)
-        config = FrozenBunch(options)
-
-        for key, value in options.items():
-            assert getattr(config, key) == value
-
-    def test_getattr_unknown(self):
-        with pytest.raises(AttributeError, match="no attribute 'unknown'"):
-            datasets.utils.DatasetConfig().unknown
-
-    def test_immutable_setattr(self):
-        frozen_bunch = FrozenBunch()
-
-        with pytest.raises(RuntimeError, match="immutable"):
-            frozen_bunch.foo = "bar"
-
-    def test_immutable_delattr(
-        self,
-    ):
-        frozen_bunch = FrozenBunch(foo="bar")
-
-        with pytest.raises(RuntimeError, match="immutable"):
-            del frozen_bunch.foo
-
-    def test_repr(self):
-        options = dict(foo="bar", baz=1)
-        output = repr(FrozenBunch(options))
-
-        assert isinstance(output, str)
-        assert output.startswith("FrozenBunch")
-        for key, value in options.items():
-            assert f"{key}={value}" in output
-
-
-class TestDatasetInfo:
-    @pytest.fixture
-    def info(self):
-        return make_minimal_dataset_info(valid_options=dict(split=("train", "test"), foo=("bar", "baz")))
-
-    def test_default_config(self, info):
-        valid_options = info._valid_options
-        default_config = datasets.utils.DatasetConfig({key: values[0] for key, values in valid_options.items()})
-
-        assert info.default_config == default_config
-
-    @pytest.mark.parametrize(
-        ("valid_options", "options", "expected_error_msg"),
-        [
-            (dict(), dict(any_option=None), "does not take any options"),
-            (dict(split="train"), dict(unknown_option=None), "Unknown option 'unknown_option'"),
-            (dict(split="train"), dict(split="invalid_argument"), "Invalid argument 'invalid_argument'"),
-        ],
-    )
-    def test_make_config_invalid_inputs(self, info, valid_options, options, expected_error_msg):
-        info = make_minimal_dataset_info(valid_options=valid_options)
-
-        with pytest.raises(ValueError, match=expected_error_msg):
-            info.make_config(**options)
-
-    def test_check_dependencies(self):
-        dependency = "fake_dependency"
-        info = make_minimal_dataset_info(dependencies=(dependency,))
-        with pytest.raises(ModuleNotFoundError, match=dependency):
-            info.check_dependencies()
-
-    def test_repr(self, info):
-        output = repr(info)
-
-        assert isinstance(output, str)
-        assert "DatasetInfo" in output
-        for key, value in info._valid_options.items():
-            assert f"{key}={str(value)[1:-1]}" in output
-
-    @pytest.mark.parametrize("optional_info", ("citation", "homepage", "license"))
-    def test_repr_optional_info(self, optional_info):
-        sentinel = "sentinel"
-        info = make_minimal_dataset_info(**{optional_info: sentinel})
-
-        assert f"{optional_info}={sentinel}" in repr(info)
+    # TODO: remove this?
+    return dict(categories=categories or [], **kwargs)


class TestDataset:
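For context, the deleted tests above double as a spec for the removed helpers: `FrozenMapping` is a dict-like container that raises `RuntimeError` on any mutation, and `FrozenBunch` additionally exposes keys as attributes and prints as `FrozenBunch(key=value, ...)`. A rough sketch of what those tests describe, not the removed implementation itself:

from collections.abc import Mapping
from typing import Any, Dict, Iterator


class FrozenMapping(Mapping):
    """Read-only mapping: lookups work like a dict, mutation raises RuntimeError."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        data: Dict[Any, Any] = dict(*args, **kwargs)
        # Hashing the items up front surfaces the "unhashable type" TypeError
        # that the deleted test_unhashable_items expected.
        hash(tuple(data.items()))
        object.__setattr__(self, "_data", data)

    def __getitem__(self, key: Any) -> Any:
        return self._data[key]

    def __iter__(self) -> Iterator:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def __setitem__(self, key: Any, value: Any) -> None:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __delitem__(self, key: Any) -> None:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __repr__(self) -> str:
        return repr(self._data)


class FrozenBunch(FrozenMapping):
    """FrozenMapping whose keys are also readable as attributes."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from None

    def __setattr__(self, name: str, value: Any) -> None:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __delattr__(self, name: str) -> None:
        raise RuntimeError(f"'{type(self).__name__}' object is immutable")

    def __repr__(self) -> str:
        keys_and_values = ", ".join(f"{key}={value}" for key, value in self.items())
        return f"{type(self).__name__}({keys_and_values})"

Under this sketch, `FrozenBunch(foo="bar").foo` returns `"bar"` and `del FrozenBunch(foo="bar").foo` raises `RuntimeError`, matching the deleted tests.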
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_api.py
@@ -2,12 +2,12 @@
from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar

from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.utils import Dataset2
+from torchvision.prototype.datasets.utils import Dataset
from torchvision.prototype.utils._internal import add_suggestion


T = TypeVar("T")
-D = TypeVar("D", bound=Type[Dataset2])
+D = TypeVar("D", bound=Type[Dataset])

BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}

@@ -56,7 +56,7 @@ def info(name: str) -> Dict[str, Any]:
    return find(BUILTIN_INFOS, name)


-def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2:
+def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset:
    dataset_cls = find(BUILTIN_DATASETS, name)

    if root is None:
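`load` resolves a dataset by name through the module-level registries shown above. The diff only exposes the registries and `load`'s signature; a rough sketch of how the pieces plausibly fit together (the decorator bodies and the root fallback are assumptions, not shown in this excerpt):

import pathlib
from typing import Any, Callable, Dict, Optional, Union

# Module-level registries, as declared in _api.py.
BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}
BUILTIN_DATASETS: Dict[str, type] = {}


def register_info(name: str) -> Callable:
    """Assumed decorator body: call fn once and store its info dict under `name`."""

    def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
        BUILTIN_INFOS[name] = fn()
        return fn

    return wrapper


def register_dataset(name: str) -> Callable:
    """Assumed decorator body: store the Dataset subclass under `name`."""

    def wrapper(dataset_cls: type) -> type:
        BUILTIN_DATASETS[name] = dataset_cls
        return dataset_cls

    return wrapper


def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Any:
    # Signature copied from the diff; the body is sketched. The real find()
    # also suggests close matches for misspelled names via add_suggestion.
    dataset_cls = BUILTIN_DATASETS[name]
    if root is None:
        # Assumed fallback; the real code derives this from datasets.home().
        root = pathlib.Path.home() / ".cache" / "torchvision" / name
    return dataset_cls(root, **config)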
10 changes: 5 additions & 5 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -9,7 +9,7 @@
    Filter,
    IterKeyZipper,
)
-from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    read_mat,
@@ -22,7 +22,7 @@
from .._api import register_dataset, register_info


-CALTECH101_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech101.categories"))
+CALTECH101_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "caltech101.categories"))


@register_info("caltech101")
@@ -31,7 +31,7 @@ def _caltech101_info() -> Dict[str, Any]:


@register_dataset("caltech101")
-class Caltech101(Dataset2):
+class Caltech101(Dataset):
    """
    - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101
    - **dependencies**:
@@ -152,7 +152,7 @@ def _generate_categories(self) -> List[str]:
        return sorted({pathlib.Path(path).parent.name for path, _ in dp})


-CALTECH256_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech256.categories"))
+CALTECH256_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "caltech256.categories"))


@register_info("caltech256")
@@ -161,7 +161,7 @@ def _caltech256_info() -> Dict[str, Any]:


@register_dataset("caltech256")
-class Caltech256(Dataset2):
+class Caltech256(Dataset):
    """
    - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256
    """
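`DatasetInfo.read_categories_file` becomes a free helper that every builtin module calls at import time. Judging by its call sites (`zip(*read_categories_file(...))`, with COCO unpacking two columns), it returns one tuple of fields per line; a plausible sketch, where the csv format is an assumption based on those call sites:

import csv
import pathlib
from typing import List, Tuple


def read_categories_file(path: pathlib.Path) -> List[Tuple[str, ...]]:
    # One category per line; multi-column files (e.g. coco.categories with
    # category and super-category) yield multi-element tuples.
    with open(path, newline="") as file:
        return [tuple(row) for row in csv.reader(file)]


# Call-site pattern from the diff: keep the first column, ignore the rest.
# CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "caltech101.categories"))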
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -10,7 +10,7 @@
    IterKeyZipper,
)
from torchvision.prototype.datasets.utils import (
-    Dataset2,
+    Dataset,
    GDriveResource,
    OnlineResource,
)
@@ -68,7 +68,7 @@ def _info() -> Dict[str, Any]:


@register_dataset(NAME)
-class CelebA(Dataset2):
+class CelebA(Dataset):
    """
    - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
    """
8 changes: 4 additions & 4 deletions torchvision/prototype/datasets/_builtin/cifar.py
@@ -10,7 +10,7 @@
    Filter,
    Mapper,
)
-from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file
from torchvision.prototype.datasets.utils._internal import hint_shuffling, path_comparator, hint_sharding, BUILTIN_DIR
from torchvision.prototype.features import Label, Image

@@ -29,7 +29,7 @@ def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]:
        yield from iter(zip(image_arrays, category_idcs))


-class _CifarBase(Dataset2):
+class _CifarBase(Dataset):
    _FILE_NAME: str
    _SHA256: str
    _LABELS_KEY: str
@@ -92,7 +92,7 @@ def _generate_categories(self) -> List[str]:
        return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY])


-CIFAR10_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar10.categories"))
+CIFAR10_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "cifar10.categories"))


@register_info("cifar10")
@@ -118,7 +118,7 @@ def _is_data_file(self, data: Tuple[str, Any]) -> bool:
        return path.name.startswith("data" if self._split == "train" else "test")


-CIFAR100_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar100.categories"))
+CIFAR100_CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / "cifar100.categories"))


@register_info("cifar100")
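CIFAR-10 and CIFAR-100 share one `_CifarBase(Dataset)` whose subclasses only fill in class attributes such as `_FILE_NAME`, `_SHA256`, and `_LABELS_KEY`. A stripped-down sketch of that template-subclass pattern (the attribute values below are illustrative placeholders, not the real checksums):

class _CifarBase:
    # Subclasses configure the shared pipeline purely through class attributes.
    _FILE_NAME: str
    _SHA256: str
    _LABELS_KEY: str

    def describe(self) -> str:
        return f"{self._FILE_NAME} (sha256={self._SHA256[:8]}..., labels under {self._LABELS_KEY!r})"


class Cifar10(_CifarBase):
    _FILE_NAME = "cifar-10-python.tar.gz"
    _SHA256 = "0" * 64  # placeholder, not the real checksum
    _LABELS_KEY = "labels"


class Cifar100(_CifarBase):
    _FILE_NAME = "cifar-100-python.tar.gz"
    _SHA256 = "1" * 64  # placeholder, not the real checksum
    _LABELS_KEY = "fine_labels"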
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/clevr.py
@@ -2,7 +2,7 @@
from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union

from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher
-from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    hint_sharding,
@@ -24,7 +24,7 @@ def _info() -> Dict[str, Any]:


@register_dataset(NAME)
-class CLEVR(Dataset2):
+class CLEVR(Dataset):
    """
    - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/
    """
8 changes: 4 additions & 4 deletions torchvision/prototype/datasets/_builtin/coco.py
@@ -16,10 +16,10 @@
    UnBatcher,
)
from torchvision.prototype.datasets.utils import (
-    DatasetInfo,
    HttpResource,
    OnlineResource,
-    Dataset2,
+    Dataset,
+    read_categories_file,
)
from torchvision.prototype.datasets.utils._internal import (
    MappingIterator,
@@ -40,12 +40,12 @@

@register_info(NAME)
def _info() -> Dict[str, Any]:
-    categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+    categories, super_categories = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
    return dict(categories=categories, super_categories=super_categories)


@register_dataset(NAME)
-class Coco(Dataset2):
+class Coco(Dataset):
    """
    - **homepage**: https://cocodataset.org/
    - **dependencies**:
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_builtin/country211.py
@@ -2,15 +2,15 @@
from typing import Any, Dict, List, Tuple, Union

from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter
-from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource, read_categories_file
from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling, BUILTIN_DIR
from torchvision.prototype.features import EncodedImage, Label

from .._api import register_dataset, register_info

NAME = "country211"

-CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
+CATEGORIES, *_ = zip(*read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))


@register_info(NAME)
@@ -19,7 +19,7 @@ def _info() -> Dict[str, Any]:


@register_dataset(NAME)
-class Country211(Dataset2):
+class Country211(Dataset):
    """
    - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md
    """