[Bugfix] Fix location of unscale in mixed precision plugin #9606

Merged · 5 commits · Sep 27, 2021
Changes from all commits
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -5,13 +5,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [unreleased] - 2021-??-??

- Moved the gradient unscaling in `NativeMixedPrecisionPlugin` from `pre_optimizer_step` to `post_backward` ([#9606](https://github.com/PyTorchLightning/pytorch-lightning/pull/9606))
- Fixed gradient unscaling being called too late, causing gradient clipping and gradient norm tracking to be applied incorrectly ([#9606](https://github.com/PyTorchLightning/pytorch-lightning/pull/9606))


## [1.4.8] - 2021-09-22

- Fixed error reporting in DDP process reconciliation when processes are launched by an external agent (#9389)
- Added PL_RECONCILE_PROCESS environment variable to enable process reconciliation regardless of cluster environment settings (#9389)
- Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))
- Fixed back-compatibility for saving hyperparameters from a single container and inferring its argument name by reverting [#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125) ([#9642](https://github.com/PyTorchLightning/pytorch-lightning/pull/9642))


## [1.4.7] - 2021-09-14

- Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))
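A note on why the placement recorded in the changelog above matters: with native AMP the loss is scaled before `backward()`, so the resulting gradients are scaled too, and `clip_grad_norm_` or gradient-norm logging must only run after `GradScaler.unscale_`. Below is a minimal standalone sketch of the ordering this PR restores, following the standard `torch.cuda.amp.GradScaler` recipe; the model, data, and loop are illustrative placeholders (not Lightning code), and a CUDA device is assumed.

```python
import torch

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()                            # gradients are scaled here
    scaler.unscale_(optimizer)                               # unscale right after backward ...
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # ... so clipping sees true gradient values
    scaler.step(optimizer)                                   # skips the update if grads are inf/nan
    scaler.update()
```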
6 changes: 3 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -264,7 +264,7 @@ def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OUTPUT]:
"""
return self.training_type_plugin.validation_step_end(output)

def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
def backward(self, closure_loss: Tensor, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any) -> Tensor:
"""Forwards backward-calls to the precision plugin.

Args:
@@ -273,9 +273,9 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
self.training_type_plugin.pre_backward(closure_loss)
closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)

self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, *args, **kwargs)

closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss, optimizer)
self.training_type_plugin.post_backward(closure_loss)

return closure_loss
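The extra `optimizer` argument travels this far because `GradScaler.unscale_` is a per-optimizer operation: each optimizer's gradients are unscaled and inf-checked independently, so `post_backward` needs to know which optimizer it is acting on. A small standalone illustration of that per-optimizer behaviour (plain PyTorch, not Lightning internals; a CUDA device is assumed):

```python
import torch

net_a = torch.nn.Linear(4, 4).cuda()
net_b = torch.nn.Linear(4, 4).cuda()
opt_a = torch.optim.SGD(net_a.parameters(), lr=0.1)
opt_b = torch.optim.SGD(net_b.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    loss = net_b(net_a(torch.randn(2, 4, device="cuda"))).sum()
scaler.scale(loss).backward()

scaler.unscale_(opt_a)  # only the gradients owned by opt_a are unscaled here
scaler.unscale_(opt_b)  # opt_b's gradients must be unscaled separately
scaler.step(opt_a)
scaler.step(opt_b)
scaler.update()
```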
14 changes: 13 additions & 1 deletion pytorch_lightning/plugins/precision/native_amp.py
@@ -55,7 +55,10 @@ def pre_optimizer_step(
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
result = lambda_closure() # native amp does not support closures
self.scaler.unscale_(optimizer)
if not model.automatic_optimization:
# unscale in manual optimization as user does not rely on lightning
# to call backward, but does call LightningOptimizer.step
self.scaler.unscale_(optimizer)
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
skipped_backward = result is None
# in manual optimization, the closure does not return a value
@@ -65,6 +68,15 @@
self.scaler.update()
return False

def post_backward(
self, model: "pl.LightningModule", closure_loss: torch.Tensor, optimizer: Optimizer
) -> torch.Tensor:
ret_val = super().post_backward(model, closure_loss, optimizer)
# unscale here to have it inside the closure before the grad tracking and clipping
if model.automatic_optimization and not model.trainer.fit_loop.should_accumulate():
self.scaler.unscale_(optimizer)
return ret_val

@contextmanager
def train_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
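The `automatic_optimization` checks above exist because the two modes unscale at different points: in automatic optimization Lightning runs `backward` inside the closure, so `post_backward` can unscale before clipping and norm tracking; in manual optimization the user triggers `backward` and `optimizer.step()` themselves, so the scaler can only unscale in `pre_optimizer_step`, right before the step. A generic sketch of the manual-optimization usage that the `pre_optimizer_step` branch covers, built on the `BoringModel` test helper and not taken from this PR:

```python
from tests.helpers import BoringModel


class ManualAMPModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()     # LightningOptimizer wrapping the configured optimizer
        opt.zero_grad()
        loss = self(batch).sum()
        self.manual_backward(loss)  # the precision plugin scales the loss before backward
        opt.step()                  # pre_optimizer_step unscales here, then GradScaler.step runs
```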
7 changes: 5 additions & 2 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -79,8 +79,10 @@ def backward(
else:
closure_loss.backward(*args, **kwargs)

def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
"""Run after precision plugin executes backward
def post_backward(
self, model: "pl.LightningModule", closure_loss: Tensor, optimizer: torch.optim.Optimizer
) -> Tensor:
"""Run after precision plugin executes backward.

Args:
model: the model to be optimized
@@ -89,6 +91,7 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
# once backward has been applied, release graph
closure_loss = closure_loss.detach()
model.trainer.call_hook("on_after_backward")

return closure_loss

def pre_optimizer_step(
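Because the base `post_backward` now takes the optimizer, custom `PrecisionPlugin` subclasses that override it have to accept and forward the extra argument. A minimal, hypothetical override showing the new signature; `GradNormTrackingPlugin` and its `last_grad_norm` attribute are made up for illustration:

```python
import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import PrecisionPlugin


class GradNormTrackingPlugin(PrecisionPlugin):
    def post_backward(
        self, model: "pl.LightningModule", closure_loss: torch.Tensor, optimizer: Optimizer
    ) -> torch.Tensor:
        closure_loss = super().post_backward(model, closure_loss, optimizer)
        # for the default 32-bit plugin the gradients are already plain values here
        grads = [p.grad for p in model.parameters() if p.grad is not None]
        if grads:
            self.last_grad_norm = torch.norm(torch.stack([g.norm() for g in grads])).item()
        return closure_loss
```

An instance could then be passed through `Trainer(plugins=GradNormTrackingPlugin())`, the same way the test below passes its custom AMP plugin.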
41 changes: 40 additions & 1 deletion tests/plugins/test_amp_plugins.py
@@ -18,7 +18,8 @@
import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from tests.helpers import BoringModel
@@ -174,3 +175,41 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
model = BoringModel()
trainer.fit(model)


class GradientUnscaleNativeAMPPlugin(NativeMixedPrecisionPlugin):
_was_scaled_finite = 0

def post_backward(self, model, closure_loss, optimizer) -> torch.Tensor:
assert not isinstance(optimizer, LightningOptimizer)
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
ret_val = super().post_backward(model, closure_loss, optimizer)
norm_after = torch.nn.utils.clip_grad_norm_(model.parameters(), 2)

# after unscaling, the norm should be smaller than the norm of the still-scaled gradients (the scale factor is > 1)
if not (torch.isinf(norm_before) or torch.isnan(norm_before)):
assert norm_after < norm_before * 10
# during the initial phase of finding an appropriate scale, AMP skips optimizer steps whose gradients
# are non-finite; count the finite cases so the test can assert that at least one step was usable
self._was_scaled_finite += 1
return ret_val


@RunIf(min_gpus=1, amp_native=True)
def test_correct_native_grad_unscaling(tmpdir):
"""Test that the gradient clipping gets applied at the appropriate place when using mixed precision plugins."""
seed_everything(42)
plugin = GradientUnscaleNativeAMPPlugin()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=4,
max_epochs=1,
precision=16,
amp_backend="native",
gpus=1,
plugins=plugin,
)
assert isinstance(trainer.precision_plugin, GradientUnscaleNativeAMPPlugin)
model = BoringModel()
trainer.fit(model)
assert plugin._was_scaled_finite
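
The comparison inside `post_backward` above works because `GradScaler` starts from a large scale (`init_scale=2.**16` by default), so the norm of the still-scaled gradients dwarfs the unscaled norm. Back-of-the-envelope arithmetic with hypothetical numbers:

```python
init_scale = 2.0 ** 16                # torch.cuda.amp.GradScaler's default init_scale
true_norm = 0.5                       # hypothetical unscaled gradient norm
norm_before = init_scale * true_norm  # roughly what clip_grad_norm_ reports before unscaling
norm_after = true_norm                # roughly what it reports after scaler.unscale_ (ignoring the clip in between)
assert norm_after < norm_before * 10  # the loose bound asserted by the test holds easily
```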