Remove the deprecated AllGatherGrad class (#16360)
carmocca authored and lantiga committed Jan 19, 2023
1 parent cf0952b commit 5d648e4
Showing 4 changed files with 3 additions and 47 deletions.
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed the `Trainer(ipus=...)` argument
* Removed the `Trainer(num_processes=...)` argument

+- Removed the deprecated `pytorch_lightning.utilities.AllGatherGrad` class ([#16360](https://github.com/Lightning-AI/lightning/pull/16360))
+
- Removed the deprecated `resume_from_checkpoint` Trainer argument ([#16167](https://github.com/Lightning-AI/lightning/pull/16167))

- Removed the deprecated automatic GPU selection ([#16184](https://github.com/Lightning-AI/lightning/pull/16184))
1 change: 0 additions & 1 deletion src/pytorch_lightning/utilities/__init__.py
@@ -17,7 +17,6 @@

from lightning_fabric.utilities import LightningEnum # noqa: F401
from lightning_fabric.utilities import move_data_to_device # noqa: F401
-from pytorch_lightning.utilities.distributed import AllGatherGrad # noqa: F401
from pytorch_lightning.utilities.enums import GradClipAlgorithmType # noqa: F401
from pytorch_lightning.utilities.grads import grad_norm # noqa: F401
from pytorch_lightning.utilities.imports import ( # noqa: F401
37 changes: 1 addition & 36 deletions src/pytorch_lightning/utilities/distributed.py
@@ -12,10 +12,9 @@
# limitations under the License.
"""Utilities that can be used with distributed training."""

-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional

import torch
-from torch import Tensor
from torch.nn.parallel.distributed import DistributedDataParallel

from lightning_fabric.utilities.distributed import _all_gather_ddp_if_available as new_all_gather_ddp_if_available
@@ -177,40 +176,6 @@ def gather_all_tensors(*args: Any, **kwargs: Any) -> Any:
return new_gather_all_tensors(*args, **kwargs)


-class AllGatherGrad(torch.autograd.Function):
-    """Gathers tensors from the whole group and stacks them.
-
-    This implementation is copied from PyTorch.
-
-    .. deprecated:: v1.8.0
-        This function has been deprecated in v1.8.0 in favor of :func:`torch.distributed.nn.functional.all_gather` and
-        will be removed in v2.0.0.
-    """
-
-    @staticmethod
-    def forward(  # type: ignore[override]
-        ctx: Any,
-        tensor: Tensor,
-        group: Optional["torch.distributed.ProcessGroup"] = None,
-    ) -> Tensor:
-        rank_zero_deprecation(
-            "`AllGatherGrad` has been deprecated in v1.8.0 and will be removed in v2.0.0."
-            " Use `torch.distributed.nn.functional.all_gather` instead.",
-            stacklevel=6,
-        )
-        ctx.group = group
-        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
-        gathered_tensor = torch.stack(gathered_tensor, dim=0)
-        return gathered_tensor
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
-        grad_output = torch.cat(grad_output)
-        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
-        return grad_output[torch.distributed.get_rank()], None
-
-
def get_default_process_group_backend_for_device(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.distributed.get_default_process_group_backend_for_device` has been deprecated"
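For code that still imported the removed class, the deprecation message names `torch.distributed.nn.functional.all_gather` as the replacement. The sketch below is illustrative only: the `all_gather_with_grad` helper is not part of the library, and it assumes an already initialized process group. Stacking the gathered tensors along a new leading dimension reproduces the autograd-aware result that `AllGatherGrad.apply` used to return.

```python
import torch
from torch.distributed.nn.functional import all_gather  # differentiable all_gather


def all_gather_with_grad(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Hypothetical drop-in for the removed ``AllGatherGrad.apply``.

    Requires an initialized process group; ``group=None`` uses the default group.
    """
    gathered = all_gather(tensor, group=group)  # one tensor per rank, gradients flow back
    return torch.stack(gathered, dim=0)  # same stacked shape AllGatherGrad produced
```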
10 changes: 0 additions & 10 deletions tests/tests_pytorch/deprecated_api/test_remove_2-0.py
@@ -315,16 +315,6 @@ def test_tuning_trainer_property():
trainer.tuning = True


-@RunIf(skip_windows=True)
-def test_v1_8_0_deprecated_all_gather_grad():
-    tensor1 = torch.ones(1, requires_grad=True)
-    with mock.patch("torch.distributed.all_gather"), mock.patch("torch.distributed.get_world_size", return_value=1):
-        from pytorch_lightning.utilities import AllGatherGrad
-
-        with pytest.deprecated_call(match="`AllGatherGrad` has been deprecated in v1.8"):
-            AllGatherGrad.apply(tensor1)
-
-
def test_v1_8_1_deprecated_rank_zero_only():
from pytorch_lightning.utilities.distributed import rank_zero_only
