
Cleanup prototype dataset implementation #5774

Merged · 10 commits · Apr 7, 2022
2 changes: 1 addition & 1 deletion test/builtin_dataset_mocks.py
@@ -68,7 +68,7 @@ def prepare(self, home, config):

mock_info = self._parse_mock_info(self.mock_data_fn(root, config))

-with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
+with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"):
required_file_names = {
resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
}
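
A note on why the patch above works: mocking out Dataset.__init__ lets datasets.load(...) construct the dataset object without touching the file system, while class-level methods such as _resources() remain callable on the uninitialized instance. A minimal, self-contained illustration of the pattern (the Fragile class is hypothetical, not part of torchvision):

import unittest.mock


class Fragile:
    def __init__(self):
        raise RuntimeError("would normally hit the file system")

    def describe(self):
        return "methods still work on an uninitialized instance"


# return_value=None is needed here because type.__call__ requires
# __init__ to return None
with unittest.mock.patch.object(Fragile, "__init__", return_value=None):
    obj = Fragile()  # no RuntimeError: __init__ is a no-op mock
    print(obj.describe())
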
2 changes: 1 addition & 1 deletion test/test_prototype_builtin_datasets.py
@@ -59,7 +59,7 @@ def test_smoke(self, test_home, dataset_mock, config):

dataset = datasets.load(dataset_mock.name, **config)

-if not isinstance(dataset, datasets.utils.Dataset2):
+if not isinstance(dataset, datasets.utils.Dataset):
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")

@parametrize_dataset_mocks(DATASET_MOCKS)
231 changes: 0 additions & 231 deletions test/test_prototype_datasets_api.py

This file was deleted.

20 changes: 19 additions & 1 deletion test/test_prototype_datasets_utils.py
@@ -5,7 +5,7 @@
import torch
from datasets_utils import make_fake_flo_file
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
-from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
+from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile


@@ -101,3 +101,21 @@ def preprocess_sentinel(path):
assert redirected_resource.file_name == file_name
assert redirected_resource.sha256 == sha256_sentinel
assert redirected_resource._preprocess is preprocess_sentinel


+def test_missing_dependency_error():
+    class DummyDataset(Dataset):
+        def __init__(self):
+            super().__init__(root="root", dependencies=("fake_dependency",))
+
+        def _resources(self):
+            pass
+
+        def _datapipe(self, resource_dps):
+            pass
+
+        def __len__(self):
+            pass
+
+    with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"):
+        DummyDataset()
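
The new test pins down the error raised for a missing third-party dependency. The check it exercises presumably lives in Dataset.__init__; a minimal sketch under that assumption (not the actual torchvision source):

import importlib.util


class Dataset:
    def __init__(self, root, *, dependencies=()):
        # verify each declared third-party package is importable before
        # constructing the dataset
        for dependency in dependencies:
            if importlib.util.find_spec(dependency) is None:
                raise ModuleNotFoundError(
                    f"{type(self).__name__}() depends on the third-party package "
                    f"'{dependency}'. Please install it before using the dataset."
                )
        self._root = root
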
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_api.py
@@ -2,12 +2,12 @@
from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar

from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.utils import Dataset2
+from torchvision.prototype.datasets.utils import Dataset
from torchvision.prototype.utils._internal import add_suggestion


T = TypeVar("T")
-D = TypeVar("D", bound=Type[Dataset2])
+D = TypeVar("D", bound=Type[Dataset])

BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}

@@ -56,7 +56,7 @@ def info(name: str) -> Dict[str, Any]:
return find(BUILTIN_INFOS, name)


-def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2:
+def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset:
dataset_cls = find(BUILTIN_DATASETS, name)

if root is None:
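
Taken together, the _api.py changes keep the public entry points returning plain types: info() a dict, load() a Dataset. A hypothetical usage sketch (the dataset name and the split keyword are illustrative, inferred from the datasets touched in this PR):

from torchvision.prototype import datasets

meta = datasets.info("caltech101")  # a plain dict, e.g. {"categories": [...]}
print(len(meta["categories"]))

# root defaults to the torchvision home directory when omitted
dataset = datasets.load("caltech101", split="train")
for sample in dataset:
    break  # iterate to pull the first sample
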
18 changes: 6 additions & 12 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -9,29 +9,26 @@
Filter,
IterKeyZipper,
)
-from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
read_mat,
hint_sharding,
hint_shuffling,
-BUILTIN_DIR,
+read_categories_file,
)
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage

from .._api import register_dataset, register_info


-CALTECH101_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech101.categories"))
-
-
@register_info("caltech101")
def _caltech101_info() -> Dict[str, Any]:
-return dict(categories=CALTECH101_CATEGORIES)
+return dict(categories=read_categories_file("caltech101"))


@register_dataset("caltech101")
-class Caltech101(Dataset2):
+class Caltech101(Dataset):
"""
- **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101
- **dependencies**:
@@ -152,16 +149,13 @@ def _generate_categories(self) -> List[str]:
return sorted({pathlib.Path(path).parent.name for path, _ in dp})


-CALTECH256_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech256.categories"))
-
-
@register_info("caltech256")
def _caltech256_info() -> Dict[str, Any]:
-return dict(categories=CALTECH256_CATEGORIES)
+return dict(categories=read_categories_file("caltech256"))


@register_dataset("caltech256")
-class Caltech256(Dataset2):
+class Caltech256(Dataset):
"""
- **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256
"""
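
The caltech changes replace the eager module-level DatasetInfo.read_categories_file(BUILTIN_DIR / "*.categories") calls with a read_categories_file(name) helper, so the category files are only read when an info function is invoked. A sketch of what such a helper could look like, assuming it resolves *.categories files by dataset name (the real one lives in torchvision.prototype.datasets.utils._internal):

import csv
import pathlib

# assumed layout: the bundled *.categories files sit next to the builtin datasets
BUILTIN_DIR = pathlib.Path(__file__).parent / "_builtin"


def read_categories_file(name):
    path = BUILTIN_DIR / f"{name}.categories"
    with open(path, newline="") as file:
        rows = list(csv.reader(file))
    # single-column files collapse to a flat list of category names
    return [row[0] if len(row) == 1 else tuple(row) for row in rows]
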
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -10,7 +10,7 @@
IterKeyZipper,
)
from torchvision.prototype.datasets.utils import (
-Dataset2,
+Dataset,
GDriveResource,
OnlineResource,
)
@@ -68,7 +68,7 @@ def _info() -> Dict[str, Any]:


@register_dataset(NAME)
-class CelebA(Dataset2):
+class CelebA(Dataset):
"""
- **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
"""
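
Finally, the @register_info / @register_dataset decorators used throughout the _builtin modules feed the BUILTIN_INFOS / BUILTIN_DATASETS dicts shown in _api.py. A sketch of how such decorators could be wired up (an assumption based only on the code visible in this diff):

from typing import Any, Callable, Dict, Type

BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}
BUILTIN_DATASETS: Dict[str, Type] = {}


def register_info(name):
    def wrapper(fn):
        BUILTIN_INFOS[name] = fn()  # eagerly evaluate and cache the info dict
        return fn

    return wrapper


def register_dataset(name):
    def wrapper(dataset_cls):
        BUILTIN_DATASETS[name] = dataset_cls  # map the public name to the class
        return dataset_cls

    return wrapper
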