Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gather_all_tensors #220

Merged
merged 16 commits into from
May 4, 2021
Merged
32 changes: 31 additions & 1 deletion tests/bases/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tests.helpers import seed_all
from tests.helpers.testers import DummyMetric, setup_ddp
from torchmetrics import Metric
from torchmetrics.utilities.distributed import gather_all_tensors

seed_all(42)

Expand Down Expand Up @@ -56,8 +57,37 @@ def _test_ddp_sum_cat(rank, worldsize):
assert dummy.bar == worldsize


def _test_ddp_gather_uneven_tensors(rank, worldsize):
    """Check that ``gather_all_tensors`` works with 1d tensors whose length differs per rank.

    Rank ``r`` contributes a tensor of ``r`` ones, so every process must
    receive a list of ``worldsize`` tensors with lengths ``0, 1, ...``.
    """
    setup_ddp(rank, worldsize)
    payload = torch.ones(rank)
    gathered = gather_all_tensors(payload)
    assert len(gathered) == worldsize
    for rank_idx, rank_tensor in enumerate(gathered):
        assert len(rank_tensor) == rank_idx
        assert (rank_tensor == torch.ones_like(rank_tensor)).all()


def _test_ddp_gather_uneven_tensors2(rank, worldsize):
    """Check that ``gather_all_tensors`` works with 2d tensors that differ in BOTH dimensions.

    Rank ``r`` contributes a ``(r + 1, 2 - r)`` tensor of ones, exercising the
    pad-gather-truncate path of the implementation.
    """
    setup_ddp(rank, worldsize)
    local = torch.ones(rank + 1, 2 - rank)
    gathered = gather_all_tensors(local)
    assert len(gathered) == worldsize
    for rank_idx in range(worldsize):
        rank_tensor = gathered[rank_idx]
        assert rank_tensor.shape == (rank_idx + 1, 2 - rank_idx)
        assert (rank_tensor == torch.ones_like(rank_tensor)).all()


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize(
    "process", [
        _test_ddp_cat,
        _test_ddp_sum,
        _test_ddp_sum_cat,
        _test_ddp_gather_uneven_tensors,
        _test_ddp_gather_uneven_tensors2,
    ]
)
def test_ddp(process):
    """Spawn a 2-process group and run ``process(rank, worldsize)`` on each rank."""
    torch.multiprocessing.spawn(process, args=(2, ), nprocs=2)

Expand Down
43 changes: 38 additions & 5 deletions torchmetrics/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor


Expand Down Expand Up @@ -88,6 +89,12 @@ def class_reduce(num: Tensor, denom: Tensor, weights: Tensor, class_reduction: s
)


def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int):
    """Gather ``result`` from every process in ``group``.

    All processes are assumed to contribute tensors of identical shape;
    returns a list with one tensor per rank.
    """
    placeholders = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(placeholders, result, group)
    return placeholders


def gather_all_tensors(result: Tensor, group: Optional[Any] = None):
    """
    Function to gather all tensors from several ddp processes onto a list that
    is broadcasted to all processes.

    Works even when the tensor shapes differ between processes: shapes are
    synced first, each local tensor is padded to the per-dimension maximum,
    gathered, and then truncated back to its original shape.

    Args:
        result: the value to sync
        group: the process group to gather results from. Defaults to all processes (world)

    Return:
        list with size equal to the process group where element ``i``
        corresponds to the ``result`` tensor from process ``i``
    """
    if group is None:
        group = torch.distributed.group.WORLD

    # convert tensors to contiguous format
    result = result.contiguous()

    world_size = torch.distributed.get_world_size(group)
    torch.distributed.barrier(group=group)

    # if the tensor is scalar, there is nothing to pad -- gather directly
    if result.ndim == 0:
        return _simple_gather_all_tensors(result, group, world_size)

    # 1. Gather sizes of all tensors
    local_size = torch.tensor(result.shape, device=result.device)
    local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    torch.distributed.all_gather(local_sizes, local_size, group=group)
    max_size = torch.stack(local_sizes).max(dim=0).values
    all_sizes_equal = all((size == max_size).all() for size in local_sizes)

    # 2. If shapes are all the same, then do a simple gather:
    if all_sizes_equal:
        return _simple_gather_all_tensors(result, group, world_size)

    # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
    pad_dims = []
    pad_by = (max_size - local_size).detach().cpu()
    # F.pad takes pad widths for the LAST dimension first, hence the reversal;
    # pad only on the right side of each dimension.
    for val in reversed(pad_by):
        pad_dims.append(0)
        pad_dims.append(val.item())
    result_padded = F.pad(result, pad_dims)
    gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result_padded, group)
    for idx, item_size in enumerate(local_sizes):
        # truncate each gathered tensor back to the shape its sender reported;
        # index with a tuple of slices (list-of-slices indexing is deprecated)
        slice_param = tuple(slice(dim_size) for dim_size in item_size)
        gathered_result[idx] = gathered_result[idx][slice_param]
    return gathered_result