diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 4465f978887ec..4f2c3e349eca3 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -192,6 +192,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue that would prevent the user to set the multiprocessing start method after importing lightning ([#18177](https://github.com/Lightning-AI/lightning/pull/18177))
 
+- Fixed an issue with `Fabric.all_reduce()` not performing an in-place operation for all backends consistently ([#18235](https://github.com/Lightning-AI/lightning/pull/18235))
+
 
 ## [2.0.5] - 2023-07-07
 
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 0f229a5538c54..60d0af7b57eae 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -557,11 +557,13 @@ def all_reduce(
     ) -> Union[Tensor, Dict, List, Tuple]:
         """Reduce tensors or collections of tensors from multiple processes.
 
+        The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor.
         This method needs to be called on all processes and the tensors need to have the same shape across all
         processes, otherwise your program will stall forever.
 
         Args:
-            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
+            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof. Tensors will
+                be modified in-place.
             group: the process group to reduce results across. Defaults to all processes (world).
             reduce_op: the reduction operation. Defaults to 'mean'. Can also be a string 'sum' or ReduceOp.
                 Some strategies may limit the choices here.
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index c7f52161c4c38..0c261a9f09905 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -107,7 +107,9 @@ def _sync_ddp_if_available(
 
 
 def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None) -> Tensor:
-    """Function to reduce the tensors from several DDP processes to one main process.
+    """Reduces a tensor across several distributed processes.
+
+    This operation is performed in-place, meaning the result will be placed back into the input tensor on all processes.
 
     Args:
         result: The value to sync and reduce (typically tensor or number)
@@ -116,17 +118,17 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
             Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
 
     Return:
-        reduced value
+        The reduced value.
 
     """
     divide_by_world_size = False
-
-    if group is None:
-        group = torch.distributed.group.WORLD
+    group = torch.distributed.group.WORLD if group is None else group
 
     op: Optional[ReduceOp]
     if isinstance(reduce_op, str):
-        if reduce_op.lower() in ("avg", "mean"):
+        reduce_op = "avg" if reduce_op == "mean" else reduce_op
+        if reduce_op.lower() == "avg" and torch.distributed.get_backend(group) == "gloo":
+            # The GLOO backend does not support the `ReduceOp.AVG` operation
             op = ReduceOp.SUM  # type: ignore[assignment]
             divide_by_world_size = True
         else:
@@ -134,7 +136,8 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
     else:
         op = reduce_op
 
-    # WA for HPU. HPU doesn't support Long types, forcefully set it to float
+    # HPU doesn't support Long types, forcefully set it to float
+    # TODO: move this to the `lightning_habana` package
     if (
         package_available("habana_frameworks")
         and os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
@@ -150,11 +153,15 @@ def _sync_ddp(result: Tensor, group: Optional[Any] = None, reduce_op: Optional[U
     # Sync all processes before reduction
     torch.distributed.barrier(group=group)
     torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
+    world_size = torch.distributed.get_world_size(group)
 
-    if divide_by_world_size:
-        result = result / torch.distributed.get_world_size(group)
-
-    return result
+    if not divide_by_world_size:
+        return result
+    # `torch.distributed.all_reduce` is in-place, so we also divide in-place to leave the modified tensor
+    # with the expected value
+    if not torch.is_floating_point(result):
+        return result.copy_(result / world_size)
+    return result.div_(world_size)
 
 
 class _AllGather(torch.autograd.Function):
diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
index c74f34dbe1644..bf574be3d26a3 100644
--- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
+++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py
@@ -237,7 +237,7 @@ def update(self, value: _VALUE, batch_size: int) -> None:
 
     def compute(self) -> Tensor:
         if self.is_tensor:
-            value = self.meta.sync(self.value)
+            value = self.meta.sync(self.value.clone())  # `clone` because `sync` is in-place
             if self.meta.is_mean_reduction:
                 cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
                 return value / cumulated_batch_size
diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index e6259f5d21dba..49b47bf54407d 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -7,7 +7,7 @@
 from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.strategies import DDPStrategy
 from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher
-from lightning.fabric.utilities.distributed import _gather_all_tensors
+from lightning.fabric.utilities.distributed import _gather_all_tensors, _sync_ddp
 from tests_fabric.helpers.runif import RunIf
 
 
@@ -62,20 +62,47 @@ def _test_all_gather_uneven_tensors_multidim(strategy):
         assert (val == torch.ones_like(val)).all()
 
 
+def _test_all_reduce(strategy):
+    rank = strategy.local_rank
+    device = strategy.root_device
+    world_size = strategy.num_processes
+
+    for dtype in (torch.long, torch.int, torch.float, torch.half):
+        # max
+        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
+        expected = torch.tensor(world_size, device=device, dtype=dtype)
+        result = _sync_ddp(tensor, reduce_op="max")
+        assert torch.equal(result, expected)
+        assert result is tensor  # inplace
+        # sum
+        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
+        expected = torch.tensor(sum(range(1, world_size + 1)), device=device, dtype=dtype)
+        result = _sync_ddp(tensor, reduce_op="sum")
+        assert torch.equal(result, expected)
+        assert result is tensor  # inplace
+        # average
+        tensor = torch.tensor(rank + 1, device=device, dtype=dtype)
+        expected = torch.tensor(sum(range(1, world_size + 1)) / world_size, device=device, dtype=dtype)
+        result = _sync_ddp(tensor, reduce_op="avg")
+        assert torch.equal(result, expected)
+        assert result is tensor  # inplace
+
+
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize(
     "process",
     [
         _test_all_gather_uneven_tensors_multidim,
         _test_all_gather_uneven_tensors,
+        _test_all_reduce,
     ],
 )
 @pytest.mark.parametrize(
     "devices",
     [
         pytest.param([torch.device("cuda:0"), torch.device("cuda:1")], marks=RunIf(min_cuda_gpus=2)),
-        [torch.device("cpu")] * 2,
+        [torch.device("cpu"), torch.device("cpu")],
     ],
 )
-def test_gather_all_tensors(devices, process):
+def test_collective_operations(devices, process):
     spawn_launch(process, devices)
diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py
index fd3cf57fd3773..d00b24962ead7 100644
--- a/tests/tests_pytorch/core/test_metric_result_integration.py
+++ b/tests/tests_pytorch/core/test_metric_result_integration.py
@@ -356,7 +356,7 @@ def on_train_epoch_end(self) -> None:
         assert metrics["callback"]["tracking"] == expected
         assert computed_value == 2
 
-        assert self.results["training_step.tracking_2"].value == total * devices
+        assert self.results["training_step.tracking_2"].value == total
         assert metrics["callback"]["tracking_2"] == expected
         assert computed_value == 2
         self.has_validated_sum = True
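
Example (not part of the diff): a minimal sketch of the in-place contract this PR standardizes, and of why `result.py` now clones before syncing. It assumes a single-process "gloo" group so it runs standalone (the address and port values are arbitrary); with more processes, every rank would see the reduced value written into its own input tensor. The mean emulation at the end mirrors the gloo branch of `_sync_ddp`, since gloo does not support `ReduceOp.AVG`.

import os

import torch
import torch.distributed as dist

# Single-process "gloo" group so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

tensor = torch.tensor([1.0, 2.0, 3.0])
original = tensor.clone()  # keep a copy first: `all_reduce` overwrites the input (the `result.py` fix)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # in-place: `tensor` itself now holds the reduced sum

# Mean emulation as in the gloo branch of `_sync_ddp`: reduce with SUM, then
# divide in-place so the reduced value stays in the input tensor.
world_size = dist.get_world_size()
if torch.is_floating_point(tensor):
    tensor.div_(world_size)
else:
    # integer dtypes: `div_` would error (true division yields floats), so round-trip through `copy_`
    tensor.copy_(tensor / world_size)

dist.destroy_process_group()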