diff --git a/tests/unit/common.py b/tests/unit/common.py
new file mode 100644
index 000000000000..4109eb0abb65
--- /dev/null
+++ b/tests/unit/common.py
@@ -0,0 +1,100 @@
+import os
+import time
+
+import torch
+import torch.distributed as dist
+from torch.multiprocessing import Process
+
+import pytest
+
+# Worker timeout _after_ the first worker has completed.
+DEEPSPEED_UNIT_WORKER_TIMEOUT = 5
+
+
+def distributed_test(world_size=2):
+    """A decorator for executing a function (e.g., a unit test) in a distributed manner.
+    This decorator manages the spawning and joining of processes, initialization of
+    torch.distributed, and catching of errors.
+
+    Usage example:
+        @distributed_test(world_size=[2,3])
+        def my_test():
+            rank = dist.get_rank()
+            world_size = dist.get_world_size()
+            assert(rank < world_size)
+
+    Arguments:
+        world_size (int or list): number of ranks to spawn. Can be a list to spawn
+        multiple tests.
+    """
+    def dist_wrap(run_func):
+        """Second-level decorator for dist_test. This actually wraps the function. """
+        def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
+            """Initialize torch.distributed and execute the user function. """
+            os.environ['MASTER_ADDR'] = '127.0.0.1'
+            os.environ['MASTER_PORT'] = '29500'
+            dist.init_process_group(backend='nccl',
+                                    init_method='env://',
+                                    rank=local_rank,
+                                    world_size=num_procs)
+
+            if torch.cuda.is_available():
+                # Bind each worker to its own GPU; `local_rank` is this
+                # worker's rank (the name `rank` is not defined here).
+                torch.cuda.set_device(local_rank)
+
+            run_func(*func_args, **func_kwargs)
+
+        def dist_launcher(num_procs, *func_args, **func_kwargs):
+            """Launch processes and gracefully handle failures. """
+
+            # Spawn all workers on subprocesses.
+            processes = []
+            for local_rank in range(num_procs):
+                p = Process(target=dist_init,
+                            args=(local_rank,
+                                  num_procs,
+                                  *func_args),
+                            kwargs=func_kwargs)
+                p.start()
+                processes.append(p)
+
+            # Now loop and wait for a test to complete. The spin-wait here isn't a big
+            # deal because the number of processes will be O(#GPUs) << O(#CPUs).
+            any_done = False
+            while not any_done:
+                for p in processes:
+                    if not p.is_alive():
+                        any_done = True
+                        break
+
+            # Wait for all other processes to complete
+            for p in processes:
+                p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT)
+
+            failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
+            for rank, p in failed:
+                # If it still hasn't terminated, kill it because it hung.
+                if p.exitcode is None:
+                    p.terminate()
+                    pytest.fail(f'Worker {rank} hung.', pytrace=False)
+                if p.exitcode < 0:
+                    pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}',
+                                pytrace=False)
+                if p.exitcode > 0:
+                    pytest.fail(f'Worker {rank} exited with code {p.exitcode}',
+                                pytrace=False)
+
+        def run_func_decorator(*func_args, **func_kwargs):
+            """Entry point for @distributed_test(). """
+
+            if isinstance(world_size, int):
+                dist_launcher(world_size, *func_args, **func_kwargs)
+            elif isinstance(world_size, list):
+                for procs in world_size:
+                    dist_launcher(procs, *func_args, **func_kwargs)
+                    time.sleep(0.5)
+            else:
+                raise TypeError('world_size must be an integer or a list of integers.')
+
+        return run_func_decorator
+
+    return dist_wrap
diff --git a/tests/unit/test_dist.py b/tests/unit/test_dist.py
new file mode 100644
index 000000000000..2b794137dfd2
--- /dev/null
+++ b/tests/unit/test_dist.py
@@ -0,0 +1,28 @@
+import torch.distributed as dist
+
+from common import distributed_test
+
+import pytest
+
+
+@distributed_test(world_size=3)
+def test_init():
+    assert dist.is_initialized()
+    assert dist.get_world_size() == 3
+    assert dist.get_rank() < 3
+
+
+# Demonstration of pytest's parametrization
+@pytest.mark.parametrize('number,color', [(1138, 'purple')])
+def test_dist_args(number, color):
+    """Outer test function with inputs from pytest.mark.parametrize(). Uses a distributed
+    helper function.
+    """
+    @distributed_test(world_size=2)
+    def _test_dist_args_helper(x, color='red'):
+        assert dist.get_world_size() == 2
+        assert x == 1138
+        assert color == 'purple'
+
+    # Ensure that we can parse args to distributed_test decorated functions.
+    _test_dist_args_helper(number, color=color)