Commit e4611ef

awaelchli and carmocca authored
Support fused Adam with mixed precision (#15555)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
1 parent bf4653e commit e4611ef

File tree

5 files changed, +179 -4 lines

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -61,6 +61,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with `WandbLogger(log_model=True|'all')` raising an error and not being able to serialize tensors in the metadata ([#15544](https://github.com/Lightning-AI/lightning/pull/15544))
 
+- Fixed the gradient unscaling logic when using `Trainer(precision=16)` and fused optimizers such as `Adam(..., fused=True)` ([#15544](https://github.com/Lightning-AI/lightning/pull/15544))
+
 - Fixed model state transfer in multiprocessing launcher when running multi-node ([#15567](https://github.com/Lightning-AI/lightning/pull/15567))
 
 
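For context, the second entry covers the setup sketched below: a fused Adam optimizer created in `configure_optimizers` and trained under 16-bit native AMP. This is a minimal illustration assuming a CUDA machine and torch>=1.13; `FusedAdamModel` is a placeholder name, and `BoringModel` is used only as a stand-in module (it is also what the new tests in this commit use).

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class FusedAdamModel(BoringModel):
    # Placeholder LightningModule; any module can opt into the fused kernel the same way.
    def configure_optimizers(self):
        # `fused=True` requires torch>=1.13 and CUDA parameters.
        return torch.optim.Adam(self.parameters(), lr=1e-3, fused=True)


# With this fix, the AMP plugin skips `scaler.unscale_()` for optimizers that
# unscale gradients inside their own `step()`, so this configuration trains correctly.
trainer = Trainer(accelerator="cuda", devices=1, precision=16, max_steps=5)
trainer.fit(FusedAdamModel())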

src/pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 32 additions & 4 deletions

@@ -16,13 +16,13 @@
 
 import torch
 from torch import Tensor
-from torch.optim import LBFGS
+from torch.optim import LBFGS, Optimizer
 
 import pytorch_lightning as pl
 from lightning_lite.accelerators.cuda import _patch_cuda_is_available
 from lightning_lite.utilities.types import Optimizable
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType, GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _TORCH_GREATER_EQUAL_1_10:
@@ -83,8 +83,13 @@ def optimizer_step( # type: ignore[override]
                 f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
         closure_result = closure()
-        # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
-        self.scaler.unscale_(optimizer)
+
+        if not _optimizer_handles_unscaling(optimizer):
+            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
+            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
+            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
+            self.scaler.unscale_(optimizer)
+
         self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
@@ -95,6 +100,19 @@ def optimizer_step( # type: ignore[override]
             return step_output
         return closure_result
 
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        clip_val: Union[int, float] = 0.0,
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
+    ) -> None:
+        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
+            raise RuntimeError(
+                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
+                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
+            )
+        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
+
     def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
         if _TORCH_GREATER_EQUAL_1_10:
             # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
@@ -116,3 +134,13 @@ def state_dict(self) -> Dict[str, Any]:
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         if self.scaler is not None:
             self.scaler.load_state_dict(state_dict)
+
+
+def _optimizer_handles_unscaling(optimizer: Any) -> bool:
+    """Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through the
+    :class:`torch.cuda.amp.GradScaler`.
+
+    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
+    value will only be reliable for built-in PyTorch optimizers.
+    """
+    return getattr(optimizer, "_step_supports_amp_scaling", False)
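The helper keys off PyTorch's internal `_step_supports_amp_scaling` attribute: optimizers whose `step()` consumes the scale factor directly (such as `torch.optim.Adam(..., fused=True)` on torch>=1.13) set it, which is why the plugin must not call `scaler.unscale_()` for them and why gradient clipping is rejected. Below is a minimal sketch of the same branching outside Lightning, assuming a CUDA device and torch>=1.13; the model and tensor shapes are placeholders.

import torch

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)  # sets `_step_supports_amp_scaling`
scaler = torch.cuda.amp.GradScaler()

with torch.autocast("cuda", dtype=torch.float16):
    loss = model(torch.randn(4, 8, device="cuda")).sum()
scaler.scale(loss).backward()

# Mirror of `_optimizer_handles_unscaling`: skip manual unscaling for fused optimizers.
if not getattr(optimizer, "_step_supports_amp_scaling", False):
    scaler.unscale_(optimizer)  # needed e.g. before clipping with a non-fused optimizer

scaler.step(optimizer)  # the fused step receives the scale and unscales internally
scaler.update()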

tests/tests_lite/plugins/precision/test_native_amp_integration.py

Lines changed: 36 additions & 0 deletions

@@ -18,6 +18,8 @@
 from tests_lite.helpers.models import BoringLite
 from tests_lite.helpers.runif import RunIf
 
+from lightning_lite import LightningLite, seed_everything
+
 
 class NativeMixedPrecisionModule(nn.Module):
     def __init__(self, expected_dtype):
@@ -70,3 +72,37 @@ def test_native_mixed_precision(accelerator, precision, expected_dtype):
     lite = NativeMixedPrecisionBoringLite(accelerator=accelerator, precision=precision)
     lite.expected_dtype = expected_dtype
     lite.run()
+
+
+@RunIf(min_torch="1.13", min_cuda_gpus=1)
+def test_native_mixed_precision_fused_optimizer_parity():
+    def run(fused=False):
+        seed_everything(1234)
+        lite = LightningLite(accelerator="cuda", precision=16, devices=1)
+
+        model = nn.Linear(10, 10).to(lite.device)  # TODO: replace with individual setup_model call
+        optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=fused)
+
+        model, optimizer = lite.setup(model, optimizer)
+        assert isinstance(lite._precision.scaler, torch.cuda.amp.GradScaler)
+
+        data = torch.randn(10, 10, device="cuda")
+        target = torch.randn(10, 10, device="cuda")
+
+        losses = []
+        for _ in range(5):
+            optimizer.zero_grad()
+            output = model(data)
+            loss = (output - target).abs().sum()
+            lite.backward(loss)
+            optimizer.step()
+            losses.append(loss.detach())
+        return torch.stack(losses), model.parameters()
+
+    losses, params = run(fused=False)
+    losses_fused, params_fused = run(fused=True)
+
+    # Both the regular and the fused version of Adam produce the same losses and model weights
+    torch.testing.assert_close(losses, losses_fused)
+    for p, q in zip(params, params_fused):
+        torch.testing.assert_close(p, q)

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import Mock
+
+import pytest
+from torch.optim import Optimizer
+
+from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
+from pytorch_lightning.utilities import GradClipAlgorithmType
+
+
+def test_clip_gradients():
+    """Test that `.clip_gradients()` is a no-op when clipping is disabled."""
+    optimizer = Mock(spec=Optimizer)
+    precision = NativeMixedPrecisionPlugin(precision=16, device="cuda:0", scaler=Mock())
+    precision.clip_grad_by_value = Mock()
+    precision.clip_grad_by_norm = Mock()
+    precision.clip_gradients(optimizer)
+    precision.clip_grad_by_value.assert_not_called()
+    precision.clip_grad_by_norm.assert_not_called()
+
+    precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE)
+    precision.clip_grad_by_value.assert_called_once()
+    precision.clip_grad_by_norm.assert_not_called()
+
+    precision.clip_grad_by_value.reset_mock()
+    precision.clip_grad_by_norm.reset_mock()
+
+    precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
+    precision.clip_grad_by_value.assert_not_called()
+    precision.clip_grad_by_norm.assert_called_once()
+
+
+def test_optimizer_amp_scaling_support_in_step_method():
+    """Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
+    gradient clipping (example: fused Adam)."""
+
+    optimizer = Mock(_step_supports_amp_scaling=True)
+    precision = NativeMixedPrecisionPlugin(precision=16, device="cuda:0", scaler=Mock())
+
+    with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
+        precision.clip_gradients(optimizer, clip_val=1.0)

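The second test pins down the user-facing failure mode: with the new `clip_gradients` override, asking for gradient clipping while using a fused optimizer is expected to raise the `RuntimeError` above rather than silently clipping still-scaled gradients. A hedged sketch of how that would surface at the Trainer level, assuming CUDA and torch>=1.13; `FusedAdamModel` is the same placeholder module as in the earlier sketch.

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel


class FusedAdamModel(BoringModel):
    # Placeholder module, as in the sketch after the CHANGELOG diff.
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3, fused=True)


# `gradient_clip_val > 0` routes through the plugin's `clip_gradients`, which is
# expected to raise: "... does not allow for gradient clipping ... 'fused' optimizer?"
trainer = Trainer(accelerator="cuda", devices=1, precision=16, gradient_clip_val=1.0, max_steps=1)
trainer.fit(FusedAdamModel())
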
Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from lightning_lite import seed_everything
+from pytorch_lightning import Trainer
+from pytorch_lightning.demos.boring_classes import BoringModel
+from tests_pytorch.helpers.runif import RunIf
+
+
+class FusedOptimizerParityModel(BoringModel):
+    def __init__(self, fused=False):
+        super().__init__()
+        self.fused = fused
+
+    def configure_optimizers(self):
+        assert isinstance(self.trainer.precision_plugin.scaler, torch.cuda.amp.GradScaler)
+        return torch.optim.Adam(self.parameters(), lr=1.0, fused=self.fused)
+
+
+@RunIf(min_torch="1.13", min_cuda_gpus=1)
+def test_native_mixed_precision_fused_optimizer_parity(tmpdir):
+    def run(fused=False):
+        seed_everything(1234)
+        model = FusedOptimizerParityModel(fused)
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            accelerator="cuda",
+            devices=1,
+            precision=16,
+            max_steps=5,
+            logger=False,
+            enable_checkpointing=False,
+            enable_progress_bar=False,
+            enable_model_summary=False,
+        )
+        trainer.fit(model)
+        return model.parameters()
+
+    params = run(fused=False)
+    params_fused = run(fused=True)
+
+    # Both the regular and the fused version of Adam produce the same model weights
+    for p, q in zip(params, params_fused):
+        torch.testing.assert_close(p, q)
