From 6b8ea895115e54dbe582073079f1865131263dc5 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 3 Mar 2021 19:14:25 +0000
Subject: [PATCH 1/4] Call clip gradients if clip val greater than 0

---
 .../plugins/precision/sharded_native_amp.py |  3 ++
 tests/plugins/test_sharded_plugin.py        | 28 +++++++++++++++++--
 2 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py
index 12ae5d0bc6be3..39dc01f97df11 100644
--- a/pytorch_lightning/plugins/precision/sharded_native_amp.py
+++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -33,5 +33,8 @@ def __init__(self) -> None:
         self.scaler = ShardedGradScaler()
 
     def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
+        if clip_val <= 0:
+            return
+
         optimizer = cast(OSS, optimizer)
         optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index 9623aa4d0265c..08fcf911ef8ed 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -1,8 +1,8 @@
 import os
+from unittest import mock
 
 import pytest
 import torch
-
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
@@ -12,7 +12,7 @@
 
 
 @RunIf(fairscale=True)
-@pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
+@pytest.mark.parametrize(["accelerator"], [("ddp_sharded",), ("ddp_sharded_spawn",)])
 def test_sharded_ddp_choice(tmpdir, accelerator):
     """
         Test to ensure that plugin is correctly chosen
@@ -57,7 +57,7 @@ def test_invalid_apex_sharded(tmpdir):
 
 
 @RunIf(min_gpus=2, amp_native=True, fairscale=True)
-@pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
+@pytest.mark.parametrize(["accelerator"], [("ddp_sharded",), ("ddp_sharded_spawn",)])
 def test_ddp_choice_sharded_amp(tmpdir, accelerator):
     """
         Test to ensure that plugin native amp plugin is correctly chosen when using sharded
@@ -269,3 +269,25 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir):
     )
 
     trainer.test(model)
+
+
+@pytest.mark.parametrize("clip_val", [0, 10])
+@RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
+@mock.patch('fairscale.optim.oss.OSS.clip_grad_norm')
+def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir):
+    """
+    Ensure that clip gradients is only called if the value is greater than 0.
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        accelerator='ddp_sharded',
+        gpus=1,
+        precision=16,
+        fast_dev_run=True,
+        gradient_clip_val=clip_val
+    )
+    trainer.fit(model)
+    if clip_val > 0:
+        mock_oss_clip_grad_norm.assert_called()
+    else:
+        mock_oss_clip_grad_norm.assert_not_called()

From f03b427ccd1920ef021389142b9b34cabb61d32b Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 3 Mar 2021 19:16:30 +0000
Subject: [PATCH 2/4] format

From d7bf9ef925c7744d9a6240aea23143c75ecd210e Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 3 Mar 2021 19:17:24 +0000
Subject: [PATCH 3/4] Format

---
 tests/plugins/test_sharded_plugin.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index 08fcf911ef8ed..9404802d14dad 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -3,6 +3,7 @@
 
 import pytest
 import torch
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
@@ -12,7 +13,7 @@
 
 
 @RunIf(fairscale=True)
-@pytest.mark.parametrize(["accelerator"], [("ddp_sharded",), ("ddp_sharded_spawn",)])
+@pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
 def test_sharded_ddp_choice(tmpdir, accelerator):
     """
         Test to ensure that plugin is correctly chosen
@@ -57,7 +58,7 @@ def test_invalid_apex_sharded(tmpdir):
 
 
 @RunIf(min_gpus=2, amp_native=True, fairscale=True)
-@pytest.mark.parametrize(["accelerator"], [("ddp_sharded",), ("ddp_sharded_spawn",)])
+@pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
 def test_ddp_choice_sharded_amp(tmpdir, accelerator):
     """
         Test to ensure that plugin native amp plugin is correctly chosen when using sharded
@@ -279,13 +280,7 @@ def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_v
     Ensure that clip gradients is only called if the value is greater than 0.
     """
     model = BoringModel()
-    trainer = Trainer(
-        accelerator='ddp_sharded',
-        gpus=1,
-        precision=16,
-        fast_dev_run=True,
-        gradient_clip_val=clip_val
-    )
+    trainer = Trainer(accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=clip_val)
     trainer.fit(model)
     if clip_val > 0:
         mock_oss_clip_grad_norm.assert_called()

From 3f1aeb6748b2731f1cacbc2e408698f6f80e988b Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 4 Mar 2021 10:23:04 +0000
Subject: [PATCH 4/4] Move to top of file

---
 tests/plugins/test_sharded_plugin.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index 9404802d14dad..b59563f70e4aa 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -12,6 +12,22 @@
 from tests.helpers.runif import RunIf
 
 
+@pytest.mark.parametrize("clip_val", [0, 10])
+@RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
+@mock.patch('fairscale.optim.oss.OSS.clip_grad_norm')
+def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir):
+    """
+    Ensure that clip gradients is only called if the value is greater than 0.
+    """
+    model = BoringModel()
+    trainer = Trainer(accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=clip_val)
+    trainer.fit(model)
+    if clip_val > 0:
+        mock_oss_clip_grad_norm.assert_called()
+    else:
+        mock_oss_clip_grad_norm.assert_not_called()
+
+
 @RunIf(fairscale=True)
 @pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )])
 def test_sharded_ddp_choice(tmpdir, accelerator):
@@ -270,19 +286,3 @@ def test_ddp_sharded_plugin_test_multigpu(tmpdir):
     )
 
     trainer.test(model)
-
-
-@pytest.mark.parametrize("clip_val", [0, 10])
-@RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
-@mock.patch('fairscale.optim.oss.OSS.clip_grad_norm')
-def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir):
-    """
-    Ensure that clip gradients is only called if the value is greater than 0.
-    """
-    model = BoringModel()
-    trainer = Trainer(accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=clip_val)
-    trainer.fit(model)
-    if clip_val > 0:
-        mock_oss_clip_grad_norm.assert_called()
-    else:
-        mock_oss_clip_grad_norm.assert_not_called()