
Fix gather_all_tensors #220

Merged
merged 16 commits on May 4, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -43,6 +43,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `PSNR` not working with `DDP` ([#214](https://github.com/PyTorchLightning/metrics/pull/214))


- Fixed metric calculation with unequal batch sizes ([#220](https://github.com/PyTorchLightning/metrics/pull/220))


## [0.3.1] - 2021-04-21

- Cleaning remaining inconsistency and fix PL develop integration (
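For readers skimming the diff, here is a hypothetical sketch (not part of this PR) of the situation the new changelog entry describes: a metric with a `dist_reduce_fx="cat"` state whose per-rank length differs when the data does not split evenly across DDP processes, which is exactly when `gather_all_tensors` must cope with unequally sized tensors. The `ConcatPreds` class below is illustrative only.

```python
import torch
from torchmetrics import Metric


class ConcatPreds(Metric):
    """Illustrative metric whose state length can differ between ranks."""

    def __init__(self):
        super().__init__()
        # list state; on DDP sync each rank's tensors are gathered and concatenated
        self.add_state("preds", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor) -> None:
        self.preds.append(preds)

    def compute(self) -> torch.Tensor:
        # after dist sync the state may already be a single concatenated tensor
        return self.preds if isinstance(self.preds, torch.Tensor) else torch.cat(self.preds)
```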
32 changes: 31 additions & 1 deletion tests/bases/test_ddp.py
@@ -20,6 +20,7 @@
from tests.helpers import seed_all
from tests.helpers.testers import DummyMetric, setup_ddp
from torchmetrics import Metric
from torchmetrics.utilities.distributed import gather_all_tensors

seed_all(42)

@@ -56,8 +57,37 @@ def _test_ddp_sum_cat(rank, worldsize):
assert dummy.bar == worldsize


def _test_ddp_gather_uneven_tensors(rank, worldsize):
setup_ddp(rank, worldsize)
tensor = torch.ones(rank)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
for idx in range(worldsize):
assert len(result[idx]) == idx
assert (result[idx] == torch.ones_like(result[idx])).all()


def _test_ddp_gather_uneven_tensors_multidim(rank, worldsize):
setup_ddp(rank, worldsize)
tensor = torch.ones(rank + 1, 2 - rank)
result = gather_all_tensors(tensor)
assert len(result) == worldsize
for idx in range(worldsize):
val = result[idx]
assert val.shape == (idx + 1, 2 - idx)
assert (val == torch.ones_like(val)).all()


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat])
@pytest.mark.parametrize(
"process", [
_test_ddp_cat,
_test_ddp_sum,
_test_ddp_sum_cat,
_test_ddp_gather_uneven_tensors,
_test_ddp_gather_uneven_tensors_multidim,
]
)
def test_ddp(process):
torch.multiprocessing.spawn(process, args=(2, ), nprocs=2)

43 changes: 36 additions & 7 deletions torchmetrics/utilities/distributed.py
@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor


@@ -88,7 +89,13 @@ def class_reduce(num: Tensor, denom: Tensor, weights: Tensor, class_reduction: s
)


def gather_all_tensors(result: Union[Tensor], group: Optional[Any] = None):
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
return gathered_result


def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
"""
Function to gather all tensors from several ddp processes onto a list that
is broadcasted to all processes
@@ -108,11 +115,33 @@ def gather_all_tensors(result: Union[Tensor], group: Optional[Any] = None):
result = result.contiguous()

world_size = torch.distributed.get_world_size(group)

gathered_result = [torch.zeros_like(result) for _ in range(world_size)]

# sync and broadcast all
torch.distributed.barrier(group=group)
torch.distributed.all_gather(gathered_result, result, group)

# if the tensor is scalar, things are easy
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)

# 1. Gather sizes of all tensors
local_size = torch.tensor(result.shape, device=result.device)
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
torch.distributed.all_gather(local_sizes, local_size, group=group)
max_size = torch.stack(local_sizes).max(dim=0).values
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)

# 2. If shapes are all the same, then do a simple gather:
if all_sizes_equal:
return _simple_gather_all_tensors(result, group, world_size)

# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = F.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
return gathered_result
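To close, a minimal usage sketch (not from this PR) mirroring the new tests: it assumes a process group has already been initialised for `world_size` processes (e.g. via `torch.distributed.init_process_group` with the `gloo` backend), and shows that each rank's tensor comes back with its original, unequal shape.

```python
import torch
from torchmetrics.utilities.distributed import gather_all_tensors


def demo(rank: int, world_size: int) -> None:
    # each rank holds a tensor with a different first dimension
    local = torch.ones(rank + 1, 2)
    gathered = gather_all_tensors(local)
    assert len(gathered) == world_size
    for idx, tensor in enumerate(gathered):
        # the padding used internally for the all_gather is stripped again,
        # so every entry comes back with the shape it had on its source rank
        assert tensor.shape == (idx + 1, 2)
        assert torch.all(tensor == 1)
```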