Make all_reduce consistent for both NCCL and GLOO (#18235)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
awaelchli and carmocca authored Aug 9, 2023
1 parent 27d9125 commit 70e31b6
Showing 6 changed files with 55 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -192,6 +192,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))


- Fixed an issue with `Fabric.all_reduce()` not performing an inplace operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))


## [2.0.5] - 2023-07-07

4 changes: 3 additions & 1 deletion src/lightning/fabric/fabric.py
@@ -557,11 +557,13 @@ def all_reduce(
) -> Union[Tensor, Dict, List, Tuple]:
"""Reduce tensors or collections of tensors from multiple processes.
The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor.
This method needs to be called on all processes and the tensors need to have the same shape across all
processes, otherwise your program will stall forever.
Args:
data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof. Tensor will be
modified in-place.
group: the process group to reduce results across. Defaults to all processes (world).
reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp.
Some strategies may limit the choices here.
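A minimal usage sketch (not part of the diff) of the behavior documented above: after this change, `Fabric.all_reduce` mutates the input tensor consistently on every backend, and the returned object shares its storage. The two-process CPU/gloo launch below is illustrative only.

import torch
from lightning.fabric import Fabric

def run(fabric: Fabric) -> None:
    # Each rank contributes rank + 1, so the mean across two processes is 1.5.
    tensor = torch.tensor(float(fabric.global_rank + 1))
    reduced = fabric.all_reduce(tensor, reduce_op="mean")
    # The reduction is applied in-place: the input tensor now holds the result too.
    assert torch.equal(tensor, torch.tensor(1.5))
    assert torch.equal(reduced, tensor)

if __name__ == "__main__":
    Fabric(accelerator="cpu", devices=2, strategy="ddp").launch(run)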
29 changes: 18 additions & 11 deletions src/lightning/fabric/utilities/distributed.py
@@ -107,7 +107,9 @@ def _sync_ddp_if_available(


def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
"""Function to reduce the tensors from several DDP processes to one main process.
"""Reduces a tensor across several distributed processes.
This operation is performed in-place, meaning the result will be placed back into the input tensor on all processes.
Args:
result: The value to sync and reduce (typically tensor or number)
@@ -116,25 +118,26 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
Return:
reduced value
The reduced value.
"""
divide_by_world_size = False

if group is None:
group = torch.distributed.group.WORLD
group = torch.distributed.group.WORLD if group is None else group

op: Optional[ReduceOp]
if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
reduce_op = "avg" if reduce_op == "mean" else reduce_op
if reduce_op.lower() == "avg" and torch.distributed.get_backend(group) == "gloo":
# The GLOO backend does not support the `ReduceOp.AVG` operation
op = ReduceOp.SUM # type: ignore[assignment]
divide_by_world_size = True
else:
op = getattr(ReduceOp, reduce_op.upper())
else:
op = reduce_op

# WA for HPU. HPU doesn't support Long types, forcefully set it to float
# HPU doesn't support Long types, forcefully set it to float
# TODO: move this to the `lightning_habana` package
if (
package_available("habana_frameworks")
and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
@@ -150,11 +153,15 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
# Sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
world_size = torch.distributed.get_world_size(group)

if divide_by_world_size:
result = result / torch.distributed.get_world_size(group)

return result
if not divide_by_world_size:
return result
# `torch.distributed.all_reduce` is in-place, so we should do the division in-place to leave the modified tensors
# with the expected value
if not torch.is_floating_point(result):
return result.copy_(result / world_size)
return result.div_(world_size)


class _AllGather(torch.autograd.Function):
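For readers skimming the hunk above, the standalone sketch below shows the same pattern in isolation: GLOO has no `ReduceOp.AVG`, so the mean is emulated with an in-place SUM followed by an in-place division, and integer tensors are written back with `copy_` because `div_` would change their dtype. The helper name is hypothetical, and it assumes `torch.distributed` is already initialized (e.g. with the GLOO backend).

import torch
import torch.distributed as dist

def _all_reduce_mean_inplace(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # SUM is supported by every backend and reduces in-place into `tensor`.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    world_size = dist.get_world_size(group)
    if not torch.is_floating_point(tensor):
        # `tensor / world_size` produces a float tensor; `copy_` casts the result
        # back into the original integer storage, keeping the operation in-place.
        return tensor.copy_(tensor / world_size)
    return tensor.div_(world_size)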
@@ -237,7 +237,7 @@ def update(self, value: _VALUE, batch_size: int) -> None:

def compute(self) -> Tensor:
if self.is_tensor:
value = self.meta.sync(self.value)
value = self.meta.sync(self.value.clone()) # `clone` because `sync` is in-place
if self.meta.is_mean_reduction:
cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
return value / cumulated_batch_size
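A toy illustration (the names below are made up, not Lightning APIs) of why the `.clone()` above is needed now that the sync is in-place: reducing `self.value` directly would also overwrite the locally accumulated metric state.

import torch

def inplace_sync(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the distributed reduction: mutates its argument in place.
    return t.mul_(2)

value = torch.tensor(3.0)                 # locally accumulated metric state
synced = inplace_sync(value.clone())      # reduce a copy so `value` stays intact
assert value.item() == 3.0 and synced.item() == 6.0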
33 changes: 30 additions & 3 deletions tests/tests_fabric/utilities/test_distributed.py
@@ -7,7 +7,7 @@
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import _gather_all_tensors
from lightning.fabric.utilities.distributed import _gather_all_tensors, _sync_ddp
from tests_fabric.helpers.runif import RunIf


@@ -62,20 +62,47 @@ def _test_all_gather_uneven_tensors_multidim(strategy):
assert (val == torch.ones_like(val)).all()


def _test_all_reduce(strategy):
rank = strategy.local_rank
device = strategy.root_device
world_size = strategy.num_processes

for dtype in (torch.long, torch.int, torch.float, torch.half):
# max
tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
expected = torch.tensor(2, device=device, dtype=dtype)
result = _sync_ddp(tensor, reduce_op="max")
assert torch.equal(result, expected)
assert result is tensor # inplace
# sum
tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
expected = torch.tensor(sum(range(1, world_size + 1)), device=device, dtype=dtype)
result = _sync_ddp(tensor, reduce_op="sum")
assert torch.equal(result, expected)
assert result is tensor # inplace
# average
tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
expected = torch.tensor(sum(range(1, world_size + 1)) / 2, device=device, dtype=dtype)
result = _sync_ddp(tensor, reduce_op="avg")
assert torch.equal(result, expected)
assert result is tensor # inplace


@RunIf(skip_windows=True)
@pytest.mark.parametrize(
"process",
[
_test_all_gather_uneven_tensors_multidim,
_test_all_gather_uneven_tensors,
_test_all_reduce,
],
)
@pytest.mark.parametrize(
"devices",
[
pytest.param([torch.device("cuda:0"), torch.device("cuda:1")], marks=RunIf(min_cuda_gpus=2)),
[torch.device("cpu")] * 2,
[torch.device("cpu"), torch.device("cpu")],
],
)
def test_gather_all_tensors(devices, process):
def test_collective_operations(devices, process):
spawn_launch(process, devices)
2 changes: 1 addition & 1 deletion tests/tests_pytorch/core/test_metric_result_integration.py
@@ -356,7 +356,7 @@ def on_train_epoch_end(self) -> None:
assert metrics["callback"]["tracking"] == expected
assert computed_value == 2

assert self.results["training_step.tracking_2"].value == total * devices
assert self.results["training_step.tracking_2"].value == total
assert metrics["callback"]["tracking_2"] == expected
assert computed_value == 2
self.has_validated_sum = True
