Make all_reduce consistent for both NCCL and GLOO #18235

Merged
merged 24 commits, Aug 9, 2023
Changes from 18 commits
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -189,6 +189,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue that would prevent the user from setting the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))


- Fixed an issue with `Fabric.all_reduce()` not performing an inplace operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))


## [2.0.5] - 2023-07-07

4 changes: 3 additions & 1 deletion src/lightning/fabric/fabric.py
@@ -542,11 +542,13 @@ def all_reduce(
) -> Union[Tensor, Dict, List, Tuple]:
"""Reduce tensors or collections of tensors from multiple processes.

The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor.
This method needs to be called on all processes and the tensors need to have the same shape across all
processes, otherwise your program will stall forever.

Args:
data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof. Tensor will be
modified in-place.
group: the process group to reduce results across. Defaults to all processes (world).
reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp.
Some strategies may limit the choices here.
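For context, a minimal usage sketch of the in-place behaviour documented above (not part of this diff; it assumes the script is started across two processes, e.g. with `torchrun --nproc_per_node=2`):

```python
import torch
from lightning.fabric import Fabric

# Sketch only: assumes two processes launched externally (e.g. torchrun).
fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
fabric.launch()

value = torch.tensor(float(fabric.global_rank))  # 0.0 on rank 0, 1.0 on rank 1
reduced = fabric.all_reduce(value, reduce_op="mean")

# The reduction is applied in-place, so `value` itself now holds the mean (0.5)
# on every rank; `reduced` holds the same result.
print(fabric.global_rank, value.item(), reduced.item())
```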
27 changes: 16 additions & 11 deletions src/lightning/fabric/utilities/distributed.py
@@ -105,7 +105,9 @@ def _sync_ddp_if_available(


def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
"""Function to reduce the tensors from several DDP processes to one main process.
"""Reduces a tensor across several distributed processes.

This operation is performed in-place, meaning the result will be placed back into the input tensor on all processes.

Args:
result: The value to sync and reduce (typically tensor or number)
@@ -114,24 +116,25 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

Return:
reduced value
The reduced value.
"""
divide_by_world_size = False

if group is None:
group = torch.distributed.group.WORLD
reduce_op = "avg" if reduce_op == "mean" else reduce_op
group = torch.distributed.group.WORLD if group is None else group

op: Optional[ReduceOp]
if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
if reduce_op.lower() == "avg" and torch.distributed.get_backend(group) == "gloo":
# The GLOO backend does not support the `ReduceOp.AVG` operation
op = ReduceOp.SUM # type: ignore[assignment]
divide_by_world_size = True
else:
op = getattr(ReduceOp, reduce_op.upper())
else:
op = reduce_op

# WA for HPU. HPU doesn't support Long types, forcefully set it to float
# HPU doesn't support Long types, forcefully set it to float
# TODO: move this to the `lightning_habana` package
if (
package_available("habana_frameworks")
and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
@@ -147,11 +150,13 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
# Sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
world_size = torch.distributed.get_world_size(group)

if divide_by_world_size:
result = result / torch.distributed.get_world_size(group)

return result
if not divide_by_world_size:
return result
if not torch.is_floating_point(result):
return result.copy_(result / world_size)
return result.div_(world_size)


class _AllGather(torch.autograd.Function):
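The hunk above emulates `ReduceOp.AVG` on GLOO by reducing with `ReduceOp.SUM` and then dividing by the world size in-place; the floating-point check is needed because `div_` cannot write a float result into an integer tensor. A standalone sketch of just that division step (the helper name is ours, not Lightning's):

```python
import torch

def _divide_inplace(result: torch.Tensor, world_size: int) -> torch.Tensor:
    # Hypothetical helper mirroring the tail of `_sync_ddp` above: after a SUM
    # all_reduce, turn the sum into a mean without replacing the tensor object.
    if not torch.is_floating_point(result):
        # `div_` would raise (a float result can't be written into an integer
        # tensor), so divide out-of-place and copy the truncated result back in.
        return result.copy_(result / world_size)
    return result.div_(world_size)

summed = torch.tensor(3.0)    # e.g. 1.0 + 2.0 gathered from two ranks
assert _divide_inplace(summed, 2) is summed and summed.item() == 1.5

summed_int = torch.tensor(3)  # integer tensors are truncated by the cast back
assert _divide_inplace(summed_int, 2).item() == 1
```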
@@ -237,7 +237,7 @@ def update(self, value: _VALUE, batch_size: int) -> None:

def compute(self) -> Tensor:
if self.is_tensor:
value = self.meta.sync(self.value)
value = self.meta.sync(self.value.clone()) # `clone` because `sync` is in-place
if self.meta.is_mean_reduction:
cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
return value / cumulated_batch_size
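The `clone()` above is needed because `sync` now reduces in-place: passing the accumulator directly would overwrite it and corrupt later reductions. A toy illustration with plain tensors (stand-in function, not Lightning code):

```python
import torch

def fake_sync(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for an in-place reduction (e.g. summing across 2 processes).
    return t.mul_(2)

accumulated = torch.tensor(3.0)

synced = fake_sync(accumulated.clone())  # what the diff above does
assert accumulated.item() == 3.0         # running value is untouched
assert synced.item() == 6.0

fake_sync(accumulated)                   # without the clone ...
assert accumulated.item() == 6.0         # ... the accumulator itself is overwritten
```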
33 changes: 30 additions & 3 deletions tests/tests_fabric/utilities/test_distributed.py
@@ -7,7 +7,7 @@
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from lightning.fabric.utilities.distributed import _gather_all_tensors
from lightning.fabric.utilities.distributed import _gather_all_tensors, _sync_ddp
from tests_fabric.helpers.runif import RunIf


@@ -62,20 +62,47 @@ def _test_all_gather_uneven_tensors_multidim(strategy):
assert (val == torch.ones_like(val)).all()


def _test_all_reduce(strategy):
rank = strategy.local_rank
device = strategy.root_device
world_size = strategy.num_processes

for dtype in (torch.long, torch.int, torch.float, torch.half):
# max
tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
expected = torch.tensor(2, device=device, dtype=dtype)
result = _sync_ddp(tensor, reduce_op="max")
assert torch.equal(result, expected)
assert result is tensor # inplace
# sum
tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
expected = torch.tensor(sum(range(1, world_size + 1)), device=device, dtype=dtype)
result = _sync_ddp(tensor, reduce_op="sum")
assert torch.equal(result, expected)
assert result is tensor # inplace
# average
tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
expected = torch.tensor(sum(range(1, world_size + 1)) / 2, device=device, dtype=dtype)
result = _sync_ddp(tensor, reduce_op="avg")
assert torch.equal(result, expected)
assert result is tensor # inplace


@RunIf(skip_windows=True)
@pytest.mark.parametrize(
"process",
[
_test_all_gather_uneven_tensors_multidim,
_test_all_gather_uneven_tensors,
_test_all_reduce,
],
)
@pytest.mark.parametrize(
"devices",
[
pytest.param([torch.device("cuda:0"), torch.device("cuda:1")], marks=RunIf(min_cuda_gpus=2)),
[torch.device("cpu")] * 2,
[torch.device("cpu"), torch.device("cpu")],
],
)
def test_gather_all_tensors(devices, process):
def test_collective_operations(devices, process):
spawn_launch(process, devices)
@@ -356,7 +356,7 @@ def on_train_epoch_end(self) -> None:
assert metrics["callback"]["tracking"] == expected
assert computed_value == 2

assert self.results["training_step.tracking_2"].value == total * devices
assert self.results["training_step.tracking_2"].value == total
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
assert metrics["callback"]["tracking_2"] == expected
assert computed_value == 2
self.has_validated_sum = True