From ffe3a1e9d6f67fa94629f753a434d26ff4e87d8b Mon Sep 17 00:00:00 2001
From: erjia
Date: Thu, 4 Aug 2022 21:43:52 +0000
Subject: [PATCH] Add distributed test

---
 test/test_distributed.py | 73 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+)
 create mode 100644 test/test_distributed.py

diff --git a/test/test_distributed.py b/test/test_distributed.py
new file mode 100644
index 000000000..ddae7def3
--- /dev/null
+++ b/test/test_distributed.py
@@ -0,0 +1,73 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import os
+import sys
+import unittest
+
+from expecttest import TestCase
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
+)
+
+
+from torchdata.datapipes.iter import IterableWrapper
+
+TEST_MASTER_ADDR = "127.0.0.1"
+TEST_MASTER_PORT = "29500"
+DEFAULT_WORLD_SIZE = 2
+
+
+if not dist.is_available():
+    print("Distributed not available, skipping tests", file=sys.stderr)
+    sys.exit(0)
+
+
+def launch_distributed_training(backend, world_size, fn):
+    os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR
+    os.environ["MASTER_PORT"] = TEST_MASTER_PORT
+    mp.spawn(
+        fn,
+        args=(world_size, backend),
+        nprocs=world_size,
+        join=True,
+    )
+
+
+class DistributedTest(TestCase):
+    @staticmethod
+    def _test_fullsync(rank, world_size, backend):
+        dist.init_process_group(backend, rank=rank, world_size=world_size)
+        # Use a prime number so that data sharding across ranks is uneven
+        data_length = 23
+        # fullsync() synchronizes shards across ranks so no process hangs on uneven data
+        dp = IterableWrapper(list(range(data_length))).sharding_filter().fullsync()
+        torch.utils.data.graph_settings.apply_sharding(dp, world_size, rank)
+        for _ in range(2):
+            res = []
+            for d in dp:
+                res.append(d)
+                # Simulate training synchronization
+                dist.barrier()
+            assert res == list(range(rank, data_length // world_size * world_size, world_size))
+
+    @parametrize("backend", ["gloo", "nccl"] if torch.cuda.nccl.is_available([]) else ["gloo"])
+    def test_fullsync(self, backend) -> None:
+        world_size = DEFAULT_WORLD_SIZE if backend == "gloo" else torch.cuda.device_count()
+        launch_distributed_training(backend, world_size, DistributedTest._test_fullsync)
+
+
+instantiate_parametrized_tests(DistributedTest)
+
+
+if __name__ == "__main__":
+    unittest.main()