diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index f8dc3a0542b..06f86c0a016 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -1,5 +1,6 @@
 import functools
 import io
+import os
 import pickle
 from pathlib import Path
 
@@ -7,6 +8,7 @@
 import torch
 from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
 from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
+from torch.utils.data import DataLoader
 from torch.utils.data.graph import traverse
 from torch.utils.data.graph_settings import get_all_graph_pipes
 from torchdata.datapipes.iter import Shuffler, ShardingFilter
@@ -30,6 +32,23 @@ def test_home(mocker, tmp_path):
     yield tmp_path
 
 
+@pytest.fixture
+def ddp_fixture():
+    # Note: we only test DDP with world_size=1, but that should be enough for our purposes.
+    # If we ever need to go full DDP, we'll need to implement much more complex logic, similar to
+    # MultiProcessTestCase from torch core.
+
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29501"
+    torch.distributed.init_process_group(backend="gloo", world_size=1, rank=0)
+    torch.distributed.barrier()
+
+    yield
+
+    torch.distributed.barrier()
+    torch.distributed.destroy_process_group()
+
+
 def test_coverage():
     untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys()
     if untested_datasets:
@@ -109,7 +128,7 @@ def test_transformable(self, test_home, dataset_mock, config):
 
         next(iter(dataset.map(transforms.Identity())))
 
-    @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
+    # @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_serializable(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
@@ -118,10 +137,20 @@ def test_serializable(self, test_home, dataset_mock, config):
 
         pickle.dumps(dataset)
 
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_ddp(self, test_home, dataset_mock, config, ddp_fixture):
+        dataset_mock.prepare(test_home, config)
+
+        dataset = datasets.load(dataset_mock.name, **config)
+
+        dl = DataLoader(dataset, collate_fn=lambda batch: batch)
+
+        next(iter(dl))
+
     # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
     # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
     # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
-    @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
+    # @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
     @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
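
For context, here is a minimal, self-contained sketch of the single-process group lifecycle that ddp_fixture drives above: bootstrap via MASTER_ADDR/MASTER_PORT, init the gloo backend with world_size=1, synchronize, then tear down. This is not part of the patch; the script layout, the main() name, and the assumption that port 29501 is free are all illustrative.

import os

import torch.distributed as dist


def main():
    # Mirrors ddp_fixture: a single-process ("world_size=1") gloo group,
    # initialized from environment variables.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"  # assumption: this port is free
    dist.init_process_group(backend="gloo", world_size=1, rank=0)
    try:
        dist.barrier()  # trivially passes with a single rank, but exercises the code path
        # ... a test body would iterate a DataLoader over the datapipe here ...
        dist.barrier()
    finally:
        # Always tear down, so a later init_process_group call (e.g. the next
        # parametrized test) doesn't fail on an already-initialized group.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()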