Cleanup prototype dataset implementation (#5774)
* Remove Dataset2 class

* Move read_categories_file out of DatasetInfo

* Remove FrozenBunch and FrozenMapping

* Remove test_prototype_datasets_api.py and move missing dep test somewhere else

* ufmt

* Let read_categories_file accept names instead of paths (see the sketch below)

* Mypy

* flake8

* Fix category file reading

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
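
A quick sketch of the read_categories_file change called out above, based on the caltech diff below: the helper now takes a builtin dataset name and resolves the .categories file itself, instead of callers building the path through DatasetInfo (the bare `categories` assignment is illustrative).

# Before: callers resolved the path and unzipped the rows themselves.
CALTECH101_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech101.categories"))

# After: pass the dataset name; path resolution lives inside the helper.
categories = read_categories_file("caltech101")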
NicolasHug and pmeier committed Apr 7, 2022
1 parent 5062a32 commit 3be12c7
Showing 31 changed files with 121 additions and 607 deletions.
2 changes: 1 addition & 1 deletion test/builtin_dataset_mocks.py
@@ -68,7 +68,7 @@ def prepare(self, home, config):

        mock_info = self._parse_mock_info(self.mock_data_fn(root, config))

-        with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
+        with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"):
            required_file_names = {
                resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
            }
2 changes: 1 addition & 1 deletion test/test_prototype_builtin_datasets.py
@@ -59,7 +59,7 @@ def test_smoke(self, test_home, dataset_mock, config):

        dataset = datasets.load(dataset_mock.name, **config)

-        if not isinstance(dataset, datasets.utils.Dataset2):
+        if not isinstance(dataset, datasets.utils.Dataset):
            raise AssertionError(f"Loading the dataset should return a Dataset, but got {type(dataset)} instead.")

    @parametrize_dataset_mocks(DATASET_MOCKS)
231 changes: 0 additions & 231 deletions test/test_prototype_datasets_api.py

This file was deleted.

20 changes: 19 additions & 1 deletion test/test_prototype_datasets_utils.py
@@ -5,7 +5,7 @@
import torch
from datasets_utils import make_fake_flo_file
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
-from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
+from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile


@@ -101,3 +101,21 @@ def preprocess_sentinel(path):
    assert redirected_resource.file_name == file_name
    assert redirected_resource.sha256 == sha256_sentinel
    assert redirected_resource._preprocess is preprocess_sentinel


+def test_missing_dependency_error():
+    class DummyDataset(Dataset):
+        def __init__(self):
+            super().__init__(root="root", dependencies=("fake_dependency",))
+
+        def _resources(self):
+            pass
+
+        def _datapipe(self, resource_dps):
+            pass
+
+        def __len__(self):
+            pass
+
+    with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"):
+        DummyDataset()
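
The new test above relies on Dataset.__init__ verifying its declared dependencies at construction time. A minimal sketch of such a check, assuming importlib-based lookup (the helper name _verify_dependencies is hypothetical, not torchvision's actual internal):

import importlib

def _verify_dependencies(cls_name: str, dependencies: tuple) -> None:
    # Import each declared third-party package up front so a missing
    # dependency fails fast with an actionable message.
    for dependency in dependencies:
        try:
            importlib.import_module(dependency)
        except ModuleNotFoundError as error:
            raise ModuleNotFoundError(
                f"{cls_name} depends on the third-party package '{dependency}'. "
                f"Please install it, for example with `pip install {dependency}`."
            ) from error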
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_api.py
@@ -2,12 +2,12 @@
from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar

from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.utils import Dataset2
+from torchvision.prototype.datasets.utils import Dataset
from torchvision.prototype.utils._internal import add_suggestion


T = TypeVar("T")
-D = TypeVar("D", bound=Type[Dataset2])
+D = TypeVar("D", bound=Type[Dataset])

BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {}

@@ -56,7 +56,7 @@ def info(name: str) -> Dict[str, Any]:
    return find(BUILTIN_INFOS, name)


-def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2:
+def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset:
    dataset_cls = find(BUILTIN_DATASETS, name)

    if root is None:
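
With the Dataset2 → Dataset rename, load keeps the same call shape; a usage sketch (the dataset name and root below are illustrative, and root falls back to the torchvision home directory when omitted):

from torchvision.prototype import datasets

# load() looks the class up in BUILTIN_DATASETS and instantiates it.
dataset = datasets.load("caltech101", root="/data/caltech101")
for sample in dataset:
    ...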
18 changes: 6 additions & 12 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -9,29 +9,26 @@
    Filter,
    IterKeyZipper,
)
-from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    INFINITE_BUFFER_SIZE,
    read_mat,
    hint_sharding,
    hint_shuffling,
-    BUILTIN_DIR,
+    read_categories_file,
)
from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage

from .._api import register_dataset, register_info


-CALTECH101_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech101.categories"))
-
-
@register_info("caltech101")
def _caltech101_info() -> Dict[str, Any]:
-    return dict(categories=CALTECH101_CATEGORIES)
+    return dict(categories=read_categories_file("caltech101"))


@register_dataset("caltech101")
-class Caltech101(Dataset2):
+class Caltech101(Dataset):
    """
    - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101
    - **dependencies**:
@@ -152,16 +149,13 @@ def _generate_categories(self) -> List[str]:
        return sorted({pathlib.Path(path).parent.name for path, _ in dp})


-CALTECH256_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech256.categories"))
-
-
@register_info("caltech256")
def _caltech256_info() -> Dict[str, Any]:
-    return dict(categories=CALTECH256_CATEGORIES)
+    return dict(categories=read_categories_file("caltech256"))


@register_dataset("caltech256")
-class Caltech256(Dataset2):
+class Caltech256(Dataset):
    """
    - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256
    """
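
For reference, the two-decorator pattern above is what the builtins follow after this cleanup; a hypothetical minimal dataset sketched in the same module layout (my_dataset, its URL, categories, and length are made up):

from typing import Any, Dict, List

from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource

from .._api import register_dataset, register_info


@register_info("my_dataset")
def _my_dataset_info() -> Dict[str, Any]:
    # Static metadata lives in the info registry, separate from the class.
    return dict(categories=["cat", "dog"])


@register_dataset("my_dataset")
class MyDataset(Dataset):
    def _resources(self) -> List[OnlineResource]:
        # Archives to download and verify before the datapipe is built.
        return [HttpResource("https://example.com/my_dataset.tar.gz")]

    def _datapipe(self, resource_dps):
        # Wire the extracted resources into a sample-producing datapipe.
        raise NotImplementedError

    def __len__(self) -> int:
        return 10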
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/celeba.py
@@ -10,7 +10,7 @@
    IterKeyZipper,
)
from torchvision.prototype.datasets.utils import (
-    Dataset2,
+    Dataset,
    GDriveResource,
    OnlineResource,
)
@@ -68,7 +68,7 @@ def _info() -> Dict[str, Any]:


@register_dataset(NAME)
-class CelebA(Dataset2):
+class CelebA(Dataset):
    """
    - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
    """