Add distributed test
ejguan committed Aug 5, 2022
1 parent 0ce199a commit ffe3a1e
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions test/test_distributed.py
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import os
import sys
import pytest
import unittest

from functools import partial
from expecttest import TestCase

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)


from torchdata.datapipes.iter import IterableWrapper

TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = "29500"
DEFAULT_WORLD_SIZE = 2


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


def launch_distributed_training(backend, world_size, fn):
    os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR
    os.environ["MASTER_PORT"] = TEST_MASTER_PORT
    mp.spawn(
        fn,
        args=(world_size, backend),
        nprocs=world_size,
        join=True,
    )


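# NOTE: ``mp.spawn`` calls ``fn(i, *args)`` for each process index ``i`` in
# ``range(nprocs)``, so the static per-rank test methods below receive ``rank``
# as their first positional argument in addition to ``(world_size, backend)``.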
class DistributedTest(TestCase):
    @staticmethod
    def _test_fullsync(rank, world_size, backend):
        dist.init_process_group(backend, rank=rank, world_size=world_size)
        # Use a prime number of elements so the shards are guaranteed to be uneven
        data_length = 23
        dp = IterableWrapper(list(range(data_length))).sharding_filter().fullsync()
        torch.utils.data.graph_settings.apply_sharding(dp, world_size, rank)
        # ``fullsync`` stops every rank once the shortest shard is exhausted, so each
        # rank should yield exactly ``data_length // world_size`` elements.
        for _ in range(2):
            res = []
            for d in dp:
                res.append(d)
                # Simulate training synchronization
                dist.barrier()
            assert res == list(range(rank, data_length // world_size * world_size, world_size))

    @parametrize("backend", ["gloo", "nccl"] if torch.cuda.nccl.is_available([]) else ["gloo"])
    def test_fullsync(self, backend) -> None:
        world_size = DEFAULT_WORLD_SIZE if backend == "gloo" else torch.cuda.device_count()
        launch_distributed_training(backend, world_size, DistributedTest._test_fullsync)


instantiate_parametrized_tests(DistributedTest)


if __name__ == "__main__":
unittest.main()
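
Usage sketch (assuming torchdata and a distributed-enabled PyTorch build are installed): since the module ends in `unittest.main()`, the file can be run directly with `python test/test_distributed.py`, or collected by pytest, e.g. `pytest test/test_distributed.py -k fullsync`. The gloo parametrization spawns two CPU processes, while the nccl variant requires visible CUDA devices and uses one process per GPU.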
