Skip to content

Commit

Permalink
add a test case for gather_all_tensors_ddp in #3253
Browse files Browse the repository at this point in the history
  • Loading branch information
ShomyLiu authored and justusschock committed Sep 3, 2020
1 parent 1dacdd8 commit ba8afb6
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/metrics/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_numpy_metric_conversion,
_tensor_metric_conversion,
sync_ddp_if_available,
gather_all_tensors_if_available,
tensor_metric,
numpy_metric
)
Expand Down Expand Up @@ -134,6 +135,17 @@ def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
'Sync-Reduce does not work properly with DDP and Tensors'


def _ddp_test_gather_all_tensors(rank, worldsize):
_setup_ddp(rank, worldsize)

tensor = torch.tensor([rank])
gather_tensors = gather_all_tensors_if_available(tensor)
mannual_tensors = [torch.tensor([i]) for i in range(worldsize)]

for t1, t2 in zip(gather_tensors, mannual_tensors):
assert(t1.equal(t2))


@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
def test_sync_reduce_ddp():
"""Make sure sync-reduce works with DDP"""
Expand Down Expand Up @@ -164,6 +176,16 @@ def test_sync_reduce_simple():
'Sync-Reduce does not work properly without DDP and Tensors'


@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
def test_gather_all_tensors_ddp():
"""Make sure gather_all_tensors works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()

worldsize = 2
mp.spawn(_ddp_test_gather_all_tensors, args=(worldsize, ), nprocs=worldsize)


def _test_tensor_metric(is_ddp: bool):
@tensor_metric()
def tensor_test_metric(*args, **kwargs):
Expand Down

0 comments on commit ba8afb6

Please sign in to comment.