diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 343a79c76b17f..b453fd1861f5e 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
+
+- Fixed gradient clipping not working with fused optimizers when using ``bf16-mixed`` precision ([#21435](https://github.com/Lightning-AI/pytorch-lightning/issues/21435))
 -
 
 ### Deprecated
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 5ea62233e1f69..ce6986678d577 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -104,7 +104,7 @@ def clip_gradients(
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
-        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
+        if clip_val > 0 and self.scaler is not None and _optimizer_handles_unscaling(optimizer):
             raise RuntimeError(
                 f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                 " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index 3894c4256e0b8..0acca49844a1a 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -44,15 +44,31 @@ def test_clip_gradients():
     precision.clip_grad_by_norm.assert_called_once()
 
 
-def test_optimizer_amp_scaling_support_in_step_method():
-    """Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
-    gradient clipping (example: fused Adam)."""
+@pytest.mark.parametrize(
+    ("precision", "scaler", "should_error"),
+    [
+        ("16-mixed", Mock(), True),  # fp16 with scaler: fused optimizer + clip = error
+        ("bf16-mixed", None, False),  # bf16 no scaler: fused optimizer + clip = ok
+    ],
+)
+def test_optimizer_amp_scaling_support_in_step_method(precision, scaler, should_error):
+    """Test that gradient clipping with fused optimizers is only blocked when a scaler is present.
+
+    The `_step_supports_amp_scaling` flag indicates the optimizer handles unscaling internally (e.g., fused Adam).
+    This is incompatible with gradient clipping only when using a GradScaler (16-mixed), since we can't unscale
+    before clipping. With bf16-mixed there's no scaler, so gradient clipping works normally.
+
+    """
     optimizer = Mock(_step_supports_amp_scaling=True)
-    precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
+    plugin = MixedPrecision(precision=precision, device="cuda:0", scaler=scaler)
+    plugin.clip_grad_by_norm = Mock()
 
-    with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
-        precision.clip_gradients(optimizer, clip_val=1.0)
+    if should_error:
+        with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
+            plugin.clip_gradients(optimizer, clip_val=1.0)
+    else:
+        plugin.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
+        plugin.clip_grad_by_norm.assert_called_once()
 
 
 def test_amp_with_no_grad():