From 35d03b8984b2f1b6145fe41045db286e3e5e4d94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Sep 2021 18:15:50 +0200 Subject: [PATCH 1/5] move unscale to post_backward for automatic optimization --- CHANGELOG.md | 7 ++++ pytorch_lightning/accelerators/accelerator.py | 6 +-- .../plugins/precision/native_amp.py | 13 ++++++- .../plugins/precision/precision_plugin.py | 7 +++- tests/plugins/test_amp_plugins.py | 39 ++++++++++++++++++- 5 files changed, 65 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cfa086b8fc6e..3ac3cf47839a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased] - 2021-??-?? + +- Moved the gradient unscaling in `NativeMixedPrecisionPlugin` from `pre_optimizer_step` to `post_backward` ([#9606](https://github.com/PyTorchLightning/pytorch-lightning/pull/9606)) +- Fixed gradient unscaling being called too late, causing gradient clipping and gradient norm tracking to be applied incorrectly ([#9606](https://github.com/PyTorchLightning/pytorch-lightning/pull/9606)) + + ## [1.4.8] - 2021-09-22 - Fixed error reporting in DDP process reconciliation when processes are launched by an external agent (#9389) @@ -12,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554)) - Fixed back-compatibility for saving hyperparameters from a single container and inferring its argument name by reverting [#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125) ([#9642](https://github.com/PyTorchLightning/pytorch-lightning/pull/9642)) + ## [1.4.7] - 2021-09-14 - Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c1aa4281aaabb..8e6d2e2bc4406 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -264,7 +264,7 @@ def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OU """ return self.training_type_plugin.validation_step_end(output) - def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: + def backward(self, closure_loss: Tensor, optimizer: torch.optim.Optimizer, *args: Any, **kwargs: Any) -> Tensor: """Forwards backward-calls to the precision plugin. 
Args: @@ -273,9 +273,9 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: self.training_type_plugin.pre_backward(closure_loss) closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss) - self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) + self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, *args, **kwargs) - closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss) + closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss, optimizer) self.training_type_plugin.post_backward(closure_loss) return closure_loss diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 8da701f6494c1..1c4a55013357b 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -55,7 +55,8 @@ def pre_optimizer_step( " To request, please file a Github issue in PyTorch and tag @mcarilli" ) result = lambda_closure() # native amp does not support closures - self.scaler.unscale_(optimizer) + if not model.automatic_optimization: + self.scaler.unscale_(optimizer) super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) skipped_backward = result is None # in manual optimization, the closure does not return a value @@ -65,6 +66,16 @@ def pre_optimizer_step( self.scaler.update() return False + def post_backward( + self, model: "pl.LightningModule", closure_loss: torch.Tensor, optimizer: Optimizer + ) -> torch.Tensor: + ret_val = super().post_backward(model, closure_loss, optimizer) + # unscale here to have it inside the closure before the grad tracking and clipping + if model.automatic_optimization and not model.trainer.fit_loop._should_accumulate(): + self.scaler.unscale_(optimizer) + return ret_val + + @contextmanager def train_step_context(self) -> Generator[None, None, None]: """Enable autocast context""" diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 1261fea87c06e..d897613a4a62e 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -79,8 +79,10 @@ def backward( else: closure_loss.backward(*args, **kwargs) - def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor: - """Run after precision plugin executes backward + def post_backward( + self, model: "pl.LightningModule", closure_loss: Tensor, optimizer: torch.optim.Optimizer + ) -> Tensor: + """Run after precision plugin executes backward. 
Args: model: the model to be optimized @@ -89,6 +91,7 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te # once backward has been applied, release graph closure_loss = closure_loss.detach() model.trainer.call_hook("on_after_backward") + return closure_loss def pre_optimizer_step( diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index d5862635c7af9..8705280743382 100644 --- a/tests/plugins/test_amp_plugins.py +++ b/tests/plugins/test_amp_plugins.py @@ -18,7 +18,7 @@ import pytest import torch -from pytorch_lightning import Trainer +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from tests.helpers import BoringModel @@ -174,3 +174,40 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir): assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin) model = BoringModel() trainer.fit(model) + + +class GradientUnscaleNativeAMPPlugin(NativeMixedPrecisionPlugin): + _was_scaled_finite = 0 + + def post_backward(self, model, closure_loss, optimizer) -> torch.Tensor: + norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), 2) + ret_val = super().post_backward(model, closure_loss, optimizer) + norm_after = torch.nn.utils.clip_grad_norm_(model.parameters(), 2) + + # norm_after unscale should be smaller by scaling factor greater than 1 + if not (torch.isinf(norm_before) or torch.isnan(norm_before)): + assert norm_after < norm_before + # during initial phase of finding the appropriate scaling, AMP skips optimizer steps that have + # non-finite gradients; we count and assert that we had at least one finite gradient here + self._was_scaled_finite += 1 + return ret_val + + +@RunIf(min_gpus=1, amp_native=True) +def test_correct_native_grad_unscaling(tmpdir): + """Test that the gradient clipping gets applied at the appropriate place when using mixed precision plugins.""" + seed_everything(42) + plugin = GradientUnscaleNativeAMPPlugin() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=4, + max_epochs=1, + precision=16, + amp_backend="native", + gpus=1, + plugins=plugin, + ) + assert isinstance(trainer.precision_plugin, GradientUnscaleNativeAMPPlugin) + model = BoringModel() + trainer.fit(model) + assert plugin._was_scaled_finite From 26732fd51bb29a4bcec5415a81a00781ff0efc0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Sep 2021 18:20:33 +0200 Subject: [PATCH 2/5] fix should_accumulate attribute access --- pytorch_lightning/plugins/precision/native_amp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 1c4a55013357b..c5fb2d77410a6 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -71,11 +71,10 @@ def post_backward( ) -> torch.Tensor: ret_val = super().post_backward(model, closure_loss, optimizer) # unscale here to have it inside the closure before the grad tracking and clipping - if model.automatic_optimization and not model.trainer.fit_loop._should_accumulate(): + if model.automatic_optimization and not model.trainer.fit_loop.should_accumulate(): self.scaler.unscale_(optimizer) return ret_val - @contextmanager def train_step_context(self) -> Generator[None, None, None]: """Enable autocast context""" From 
23ad9f4515b037949f956715404f1fb24673e71f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Sep 2021 18:25:12 +0200 Subject: [PATCH 3/5] Update pytorch_lightning/plugins/precision/native_amp.py Co-authored-by: Sean Naren --- pytorch_lightning/plugins/precision/native_amp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index c5fb2d77410a6..aae8b801a946e 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -56,6 +56,8 @@ def pre_optimizer_step( ) result = lambda_closure() # native amp does not support closures if not model.automatic_optimization: + # unscale in manual optimization as user does not rely on lightning + # to call backward, but does call LightningOptimizer.step self.scaler.unscale_(optimizer) super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs) skipped_backward = result is None From 3bdbc1dab103fd193fc5472ec9510ff29c4d7268 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 24 Sep 2021 18:31:37 +0200 Subject: [PATCH 4/5] add assertion for LightningOptimizer --- tests/plugins/test_amp_plugins.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index 8705280743382..48d26f449fc96 100644 --- a/tests/plugins/test_amp_plugins.py +++ b/tests/plugins/test_amp_plugins.py @@ -19,6 +19,7 @@ import torch from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from tests.helpers import BoringModel @@ -180,6 +181,7 @@ class GradientUnscaleNativeAMPPlugin(NativeMixedPrecisionPlugin): _was_scaled_finite = 0 def post_backward(self, model, closure_loss, optimizer) -> torch.Tensor: + assert not isinstance(optimizer, LightningOptimizer) norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), 2) ret_val = super().post_backward(model, closure_loss, optimizer) norm_after = torch.nn.utils.clip_grad_norm_(model.parameters(), 2) From 225d5f4356ceebc95ebeb784b9ef5b50c2a5ce5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 27 Sep 2021 13:08:46 +0200 Subject: [PATCH 5/5] Update tests/plugins/test_amp_plugins.py Co-authored-by: thomas chaton --- tests/plugins/test_amp_plugins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index 48d26f449fc96..806cb753cb15b 100644 --- a/tests/plugins/test_amp_plugins.py +++ b/tests/plugins/test_amp_plugins.py @@ -188,7 +188,7 @@ def post_backward(self, model, closure_loss, optimizer) -> torch.Tensor: # norm_after unscale should be smaller by scaling factor greater than 1 if not (torch.isinf(norm_before) or torch.isnan(norm_before)): - assert norm_after < norm_before + assert norm_after < norm_before * 10 # during initial phase of finding the appropriate scaling, AMP skips optimizer steps that have # non-finite gradients; we count and assert that we had at least one finite gradient here self._was_scaled_finite += 1
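

For reference, a minimal standalone sketch of the ordering these patches enforce — unscale the gradients right after backward so that gradient clipping and gradient-norm tracking operate on true (unscaled) gradient magnitudes. This uses the plain torch.cuda.amp API rather than Lightning; the model, optimizer, and training loop below are placeholders for illustration only, and the example is written to also run as a no-op on a machine without a GPU.

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"

    model = torch.nn.Linear(10, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for _ in range(4):
        x = torch.randn(8, 10, device=device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            loss = model(x).sum()

        optimizer.zero_grad()
        scaler.scale(loss).backward()   # gradients are still scaled here

        # What the patches move into `post_backward`: unscale immediately after
        # backward, so that clipping and grad-norm tracking see true magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)

        scaler.step(optimizer)          # internally skips the step if grads are non-finite
        scaler.update()

The `should_accumulate()` guard in patch 2/5 is consistent with GradScaler's requirement that `unscale_` be called at most once per optimizer between `scale()` and `step()`: during gradient accumulation no optimizer step follows the intermediate backward calls, so unscaling is deferred until the accumulation window closes.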