Add distributed test
ejguan committed Aug 8, 2022
1 parent 0ce199a commit 5fd4981
Showing 2 changed files with 89 additions and 5 deletions.
89 changes: 89 additions & 0 deletions test/test_distributed.py
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import os
import sys
import unittest

from functools import partial
from unittest import TestCase

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize

from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter.util.prefetch import PrefetchTimeoutError

TEST_MASTER_ADDR = "127.0.0.1"
TEST_MASTER_PORT = "29500"
DEFAULT_WORLD_SIZE = 2


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


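# Spawn `world_size` processes for the given backend; mp.spawn passes each
# process's rank as the first argument to `fn`.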
def launch_distributed_training(backend, world_size, fn):
    os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR
    os.environ["MASTER_PORT"] = TEST_MASTER_PORT
    mp.spawn(
        fn,
        args=(world_size, backend),
        nprocs=world_size,
        join=True,
    )


class DistributedTest(TestCase):
    @staticmethod
    def _test_fullsync(rank, world_size, backend):
        dist.init_process_group(backend, rank=rank, world_size=world_size)
        # Use a prime number of elements so the data cannot shard evenly across ranks
        data_length = 23
        dp = IterableWrapper(list(range(data_length))).sharding_filter()
        torch.utils.data.graph_settings.apply_sharding(dp, world_size, rank)

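        # fullsync stops every rank after the same number of elements, so the uneven
        # shards produced above cannot leave one rank hanging at the barrier below.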
        dp1 = dp.fullsync()
        for _ in range(2):
            res = []
            for d in dp1:
                res.append(d)
                # Simulate training synchronization
                dist.barrier()
            assert res == list(range(rank, data_length // world_size * world_size, world_size))

        # Timeout test: with a very short timeout, iteration should raise PrefetchTimeoutError
        dp2 = dp.fullsync(timeout=0.01)
        try:
            for _ in range(2):
                _ = list(dp2)
        except Exception as e:
            assert isinstance(e, PrefetchTimeoutError)
        else:
            raise AssertionError("Expected PrefetchTimeoutError to be raised")

    @parametrize(
        "backend",
        ["gloo", "nccl"] if torch.cuda.nccl.is_available([]) else ["gloo"],
    )
    def test_fullsync(self, backend) -> None:
        world_size = DEFAULT_WORLD_SIZE if backend == "gloo" else torch.cuda.device_count()
        launch_distributed_training(backend, world_size, DistributedTest._test_fullsync)


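# Materialize the parametrized test variants of DistributedTest as concrete test methods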
instantiate_parametrized_tests(DistributedTest)


if __name__ == "__main__":
    unittest.main()
5 changes: 0 additions & 5 deletions torchdata/datapipes/iter/util/prefetch.py
@@ -151,11 +151,6 @@ def _callback_fn(self, exp: Expected) -> None:
             self._cv.notify()
 
     def __iter__(self) -> Iterator[T_co]:
-        if not (dist.is_available() and dist.is_initialized()):
-            raise RuntimeError("Torch Distributed is required to be initialized")
-        self._process_group = dist.new_group(backend="gloo")
-        self._world_size = dist.get_world_size()
-
         assert self._executor is None
 
         if not (dist.is_available() and dist.is_initialized()):
