diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index f8dc3a0542b..8a929b6907c 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -7,6 +7,7 @@ import torch from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair +from torch.utils.data import DataLoader from torch.utils.data.graph import traverse from torch.utils.data.graph_settings import get_all_graph_pipes from torchdata.datapipes.iter import Shuffler, ShardingFilter @@ -109,19 +110,39 @@ def test_transformable(self, test_home, dataset_mock, config): next(iter(dataset.map(transforms.Identity()))) - @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") + @pytest.mark.parametrize("only_datapipe", [False, True]) @parametrize_dataset_mocks(DATASET_MOCKS) - def test_serializable(self, test_home, dataset_mock, config): + def test_traversable(self, test_home, dataset_mock, config, only_datapipe): dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + traverse(dataset, only_datapipe=only_datapipe) + + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_serializable(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) dataset = datasets.load(dataset_mock.name, **config) pickle.dumps(dataset) + @pytest.mark.parametrize("num_workers", [0, 1]) + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_data_loader(self, test_home, dataset_mock, config, num_workers): + dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + + dl = DataLoader( + dataset, + batch_size=2, + num_workers=num_workers, + collate_fn=lambda batch: batch, + ) + + next(iter(dl)) + # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680 # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now. - @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") @parametrize_dataset_mocks(DATASET_MOCKS) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):