[fbsync] Refactor and simplify prototype datasets (#5778)
Summary:
* refactor prototype datasets to inherit from IterDataPipe (#5448); see the sketch at the end of this summary

* refactor prototype datasets to inherit from IterDataPipe

* depend on new architecture

* fix missing file detection

* remove unrelated file

* reinstate decorator for mock registering

* options -> config

* remove passing of info to mock data functions

* refactor categories file generation

* fix imagenet

* fix prototype datasets data loading tests (#5711)

* reenable serialization test

* cleanup

* fix dill test

* trigger CI

* patch DILL_AVAILABLE for pickle serialization

* revert CI changes

* remove dill test and traversable test

* add data loader test

* parametrize over only_datapipe

* draw one sample rather than exhaust data loader

* cleanup

* trigger CI

* migrate VOC prototype dataset (#5743)

* migrate VOC prototype dataset

* cleanup

* revert unrelated mock data changes

* remove categories annotations

* move properties to constructor

* readd homepage

* migrate CIFAR prototype datasets (#5751)

* migrate country211 prototype dataset (#5753)

* migrate CLEVR prototype dataset (#5752)

* migrate coco prototype (#5473)

* migrate coco prototype

* revert unrelated change

* add kwargs to super constructor call

* remove unneeded changes

* fix docstring position

* make kwargs explicit

* add dependencies to docstring

* fix missing dependency message

* Migrate PCAM prototype dataset (#5745)

* Port PCAM

* skip_integrity_check

* Update torchvision/prototype/datasets/_builtin/pcam.py

* Address comments

* Migrate DTD prototype dataset (#5757)

* Migrate DTD prototype dataset

* Docstring

* Apply suggestions from code review

* Migrate GTSRB prototype dataset (#5746)

* Migrate GTSRB prototype dataset

* ufmt

* Address comments

* Apparently mypy doesn't know that __len__ returns ints. How cute.

* why is the CI not triggered??

* Update torchvision/prototype/datasets/_builtin/gtsrb.py

* migrate CelebA prototype dataset (#5750)

* migrate CelebA prototype dataset

* inline split_id

* Migrate Food101 prototype dataset (#5758)

* Migrate Food101 dataset

* Added length

* Update torchvision/prototype/datasets/_builtin/food101.py

* Migrate Fer2013 prototype dataset (#5759)

* Migrate Fer2013 prototype dataset

* Update torchvision/prototype/datasets/_builtin/fer2013.py

* Migrate EuroSAT prototype dataset (#5760)

* Migrate Semeion prototype dataset (#5761)

* migrate caltech prototype datasets (#5749)

* migrate caltech prototype datasets

* resolve third party dependencies

* Migrate Oxford Pets prototype dataset (#5764)

* Migrate Oxford Pets prototype dataset

* Update torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py

* migrate mnist prototype datasets (#5480)

* migrate MNIST prototype datasets

* Update torchvision/prototype/datasets/_builtin/mnist.py

* Migrate Stanford Cars prototype dataset (#5767)

* Migrate Stanford Cars prototype dataset

* Address comments

* fix category file generation (#5770)

* fix category file generation

* revert unrelated change

* revert unrelated change

* migrate cub200 prototype dataset (#5765)

* migrate cub200 prototype dataset

* address comments

* fix category-file-generation

* Migrate USPS prototype dataset (#5771)

* migrate SBD prototype dataset (#5772)

* migrate SBD prototype dataset

* reuse categories

* Migrate SVHN prototype dataset (#5769)

* add test to enforce __len__ is working on prototype datasets (#5742)

* reactivate special dataset tests

* add missing annotation

* Cleanup prototype dataset implementation (#5774)

* Remove Dataset2 class

* Move read_categories_file out of DatasetInfo

* Remove FrozenBunch and FrozenMapping

* Remove test_prototype_datasets_api.py and move missing dep test somewhere else

* ufmt

* Let read_categories_file accept names instead of paths

* Mypy

* flake8

* fix category file reading

* update prototype dataset README (#5777)

* update prototype dataset README

* fix header level

* Apply suggestions from code review

(Note: this ignores all push blocking failures!)
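
Taken together, these commits replace the old DatasetInfo/FrozenBunch machinery with a single Dataset base class that is itself an IterDataPipe and that every builtin dataset subclasses. Below is a minimal sketch of that structure, pieced together from the tests in this diff: Dataset, its root and dependencies arguments, HttpResource, and the _resources/_datapipe/__len__ hooks all appear in the tests, while the example class, URL, sha256 value, and the skip_integrity_check pass-through are assumptions for illustration only.

from torchvision.prototype.datasets.utils import Dataset, HttpResource


class ExampleDataset(Dataset):
    # Hypothetical dataset, shown only to illustrate the new architecture.

    def __init__(self, root, *, skip_integrity_check=False):
        # Third-party packages are declared up front; a missing one raises
        # ModuleNotFoundError at construction time (see test_missing_dependency_error below).
        super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)

    def _resources(self):
        # The archives/files the dataset downloads or expects under root.
        return [HttpResource("https://example.com/data.tar.gz", sha256="...")]

    def _datapipe(self, resource_dps):
        # Build and return the IterDataPipe that yields sample dicts.
        ...

    def __len__(self):
        # A concrete length is now part of the contract (see test_has_length below).
        return 10_000

Datasets built this way are loaded through datasets.load(name, **config) and can be iterated directly or wrapped in a plain DataLoader, which is what the reworked test_prototype_builtin_datasets.py below exercises.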

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095693

fbshipit-source-id: d57f2b4a89ef1c45f5e2783ea57bce08df5c561d

Co-authored-by: Philip Meier <github.pmeier@posteo.de>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
3 people authored and facebook-github-bot committed May 5, 2022
1 parent f9cc788 commit e581dd0
Showing 35 changed files with 1,680 additions and 1,494 deletions.
329 changes: 184 additions & 145 deletions test/builtin_dataset_mocks.py

Large diffs are not rendered by default.

70 changes: 46 additions & 24 deletions test/test_prototype_builtin_datasets.py
@@ -7,9 +7,10 @@
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
from torch.utils.data import DataLoader
from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
from torchdata.datapipes.iter import Shuffler, ShardingFilter
from torchvision._utils import sequence_to_str
from torchvision.prototype import transforms, datasets
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
@@ -42,14 +43,24 @@ def test_coverage():

@pytest.mark.filterwarnings("error")
class TestCommon:
@pytest.mark.parametrize("name", datasets.list_datasets())
def test_info(self, name):
try:
info = datasets.info(name)
except ValueError:
raise AssertionError("No info available.") from None

if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())):
raise AssertionError("Info should be a dictionary with string keys.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)

if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
if not isinstance(dataset, datasets.utils.Dataset):
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, test_home, dataset_mock, config):
@@ -76,24 +87,7 @@ def test_num_samples(self, test_home, dataset_mock, config):

dataset = datasets.load(dataset_mock.name, **config)

num_samples = 0
for _ in dataset:
num_samples += 1

assert num_samples == mock_info["num_samples"]

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_decoding(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)

undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)}
if undecoded_features:
raise AssertionError(
f"The values of key(s) "
f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
)
assert len(list(dataset)) == mock_info["num_samples"]

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
@@ -116,14 +110,36 @@ def test_transformable(self, test_home, dataset_mock, config):

next(iter(dataset.map(transforms.Identity())))

@pytest.mark.parametrize("only_datapipe", [False, True])
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config):
def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)

traverse(dataset, only_datapipe=only_datapipe)

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_serializable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)

pickle.dumps(dataset)

@pytest.mark.parametrize("num_workers", [0, 1])
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_data_loader(self, test_home, dataset_mock, config, num_workers):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)

dl = DataLoader(
dataset,
batch_size=2,
num_workers=num_workers,
collate_fn=lambda batch: batch,
)

next(iter(dl))

# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
@@ -132,7 +148,6 @@ def test_serializable(self, test_home, dataset_mock, config):
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):

dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)

if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
@@ -160,6 +175,13 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
# resolved
assert dp.buffer_size == INFINITE_BUFFER_SIZE

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_has_length(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)

assert len(dataset) > 0


@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
@@ -186,7 +208,7 @@ class TestGTSRB:
def test_label_matches_path(self, test_home, dataset_mock, config):
# We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
# This test makes sure that they're both the same
if config.split != "train":
if config["split"] != "train":
return

dataset_mock.prepare(test_home, config)
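For reference, a rough sketch of the public API these tests exercise; this is an assumption assembled from the assertions above (the "gtsrb" name and its split config key are taken from TestGTSRB), not an official usage example.

from torch.utils.data import DataLoader
from torchvision.prototype import datasets

print(datasets.list_datasets())                  # all registered builtin datasets
info = datasets.info("gtsrb")                    # a plain dict with string keys (test_info)
dataset = datasets.load("gtsrb", split="train")  # a Dataset, i.e. an IterDataPipe (test_smoke)

sample = next(iter(dataset))                     # one fully decoded sample dict (test_sample)
num_samples = len(dataset)                       # __len__ is now required (test_has_length)

# The datasets also work with a vanilla DataLoader (test_data_loader).
loader = DataLoader(dataset, batch_size=2, collate_fn=lambda batch: batch)
batch = next(iter(loader))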
231 changes: 0 additions & 231 deletions test/test_prototype_datasets_api.py

This file was deleted.

20 changes: 19 additions & 1 deletion test/test_prototype_datasets_utils.py
@@ -5,7 +5,7 @@
import torch
from datasets_utils import make_fake_flo_file
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile


@@ -101,3 +101,21 @@ def preprocess_sentinel(path):
assert redirected_resource.file_name == file_name
assert redirected_resource.sha256 == sha256_sentinel
assert redirected_resource._preprocess is preprocess_sentinel


def test_missing_dependency_error():
class DummyDataset(Dataset):
def __init__(self):
super().__init__(root="root", dependencies=("fake_dependency",))

def _resources(self):
pass

def _datapipe(self, resource_dps):
pass

def __len__(self):
pass

with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"):
DummyDataset()
