Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: gather_all_tensors cross GPUs in DDP #3319

Merged
merged 2 commits into from
Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor],

world_size = torch.distributed.get_world_size(group)

gathered_result = world_size * [torch.zeros_like(result)]
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test to tests/metrics/test_converters.py that actually test that this function does what it is expected to do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, thanks for the advice, I will add a test for this RP.


# sync and broadcast all
torch.distributed.barrier(group=group)
Expand Down
22 changes: 22 additions & 0 deletions tests/metrics/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_numpy_metric_conversion,
_tensor_metric_conversion,
sync_ddp_if_available,
gather_all_tensors_if_available,
tensor_metric,
numpy_metric
)
Expand Down Expand Up @@ -134,6 +135,17 @@ def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
'Sync-Reduce does not work properly with DDP and Tensors'


def _ddp_test_gather_all_tensors(rank, worldsize):
_setup_ddp(rank, worldsize)

tensor = torch.tensor([rank])
gather_tensors = gather_all_tensors_if_available(tensor)
mannual_tensors = [torch.tensor([i]) for i in range(worldsize)]

for t1, t2 in zip(gather_tensors, mannual_tensors):
assert(t1.equal(t2))


@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
def test_sync_reduce_ddp():
"""Make sure sync-reduce works with DDP"""
Expand Down Expand Up @@ -164,6 +176,16 @@ def test_sync_reduce_simple():
'Sync-Reduce does not work properly without DDP and Tensors'


@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
def test_gather_all_tensors_ddp():
"""Make sure gather_all_tensors works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()

worldsize = 2
mp.spawn(_ddp_test_gather_all_tensors, args=(worldsize, ), nprocs=worldsize)


def _test_tensor_metric(is_ddp: bool):
@tensor_metric()
def tensor_test_metric(*args, **kwargs):
Expand Down