101 changes: 101 additions & 0 deletions tests/unit/common.py
@@ -0,0 +1,101 @@
import os
import time

import torch
import torch.distributed as dist
from torch.multiprocessing import Process

import pytest

# Worker timeout _after_ the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 5


def distributed_test(world_size=2):
"""A decorator for executing a function (e.g., a unit test) in a distributed manner.
This decorator manages the spawning and joining of processes, initialization of
torch.distributed, and catching of errors.

Usage example:
@distributed_test(worker_size=[2,3])
def my_test():
rank = dist.get_rank()
world_size = dist.get_world_size()
assert(rank < world_size)

Arguments:
world_size (int or list): number of ranks to spawn. Can be a list to spawn
multiple tests.
"""
def dist_wrap(run_func):
"""Second-level decorator for dist_test. This actually wraps the function. """
def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
"""Initialize torch.distributed and execute the user function. """
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(backend='nccl',
init_method='env://',
rank=local_rank,
world_size=num_procs)

# XXX temporarily disabled due to CUDA runtime error?
#if torch.cuda.is_available():
# torch.cuda.set_device(local_rank)

run_func(*func_args, **func_kwargs)

def dist_launcher(num_procs, *func_args, **func_kwargs):
"""Launch processes and gracefully handle failures. """

# Spawn all workers on subprocesses.
processes = []
for local_rank in range(num_procs):
p = Process(target=dist_init,
args=(local_rank,
num_procs,
*func_args),
kwargs=func_kwargs)
p.start()
processes.append(p)

# Now loop and wait for a test to complete. The spin-wait here isn't a big
# deal because the number of processes will be O(#GPUs) << O(#CPUs).
any_done = False
while not any_done:
for p in processes:
if not p.is_alive():
any_done = True
break

# Wait for all other processes to complete
for p in processes:
p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT)

failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
for rank, p in failed:
# If it still hasn't terminated, kill it because it hung.
if p.exitcode is None:
p.terminate()
pytest.fail(f'Worker {rank} hung.', pytrace=False)
if p.exitcode < 0:
pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}',
pytrace=False)
if p.exitcode > 0:
pytest.fail(f'Worker {rank} exited with code {p.exitcode}',
pytrace=False)

def run_func_decorator(*func_args, **func_kwargs):
"""Entry point for @distributed_test(). """

if isinstance(world_size, int):
dist_launcher(world_size, *func_args, **func_kwargs)
elif isinstance(world_size, list):
for procs in world_size:
dist_launcher(procs, *func_args, **func_kwargs)
time.sleep(0.5)
else:
raise TypeError(f'world_size must be an integer or a list of integers.')

return run_func_decorator

return dist_wrap
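
The docstring above also accepts a list of world sizes, which dist_launcher runs once per size; neither test below exercises that form. A minimal sketch of such a test (hypothetical, assuming it sits next to common.py so the import resolves and that GPUs/NCCL are available):

import torch.distributed as dist

from common import distributed_test


# Hypothetical example: run the same body once with 2 ranks and once with 3.
@distributed_test(world_size=[2, 3])
def test_world_sizes():
    # Each launch initializes its own process group, so the size varies per run.
    assert dist.is_initialized()
    assert dist.get_world_size() in (2, 3)
    assert dist.get_rank() < dist.get_world_size()
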
28 changes: 28 additions & 0 deletions tests/unit/test_dist.py
@@ -0,0 +1,28 @@
import torch.distributed as dist

from common import distributed_test

import pytest


@distributed_test(world_size=3)
def test_init():
    assert dist.is_initialized()
    assert dist.get_world_size() == 3
    assert dist.get_rank() < 3


# Demonstration of pytest's parametrization
@pytest.mark.parametrize('number,color', [(1138, 'purple')])
def test_dist_args(number, color):
    """Outer test function with inputs from pytest.mark.parametrize(). Uses a distributed
    helper function.
    """
    @distributed_test(world_size=2)
    def _test_dist_args_helper(x, color='red'):
        assert dist.get_world_size() == 2
        assert x == 1138
        assert color == 'purple'

    # Ensure that we can parse args to distributed_test decorated functions.
    _test_dist_args_helper(number, color=color)
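
These tests spawn worker processes and initialize torch.distributed with the 'nccl' backend, so they assume a CUDA-capable machine. A minimal sketch of invoking them programmatically (assuming pytest is installed and the working directory is the repository root):

import pytest

if __name__ == '__main__':
    # Equivalent to running `pytest tests/unit` from the repository root;
    # requires a CUDA/NCCL-capable host since dist_init uses backend='nccl'.
    raise SystemExit(pytest.main(['tests/unit']))
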