From b90077eec12606105dffced0608af11c1bb8e2d4 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Wed, 3 Mar 2021 23:30:26 +0000
Subject: [PATCH 1/5] Add branch condition for calling move to device in
 prefetch

---
 .../plugins/training_type/ddp.py | 14 +++++++--
 tests/accelerators/test_ddp.py   | 30 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index 748dcdc9e6b68..cdaf299442ac8 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -43,7 +43,6 @@
     from hydra.core.hydra_config import HydraConfig
     from hydra.utils import get_original_cwd, to_absolute_path

-
 log = logging.getLogger(__name__)


@@ -253,13 +252,22 @@ def pre_dispatch(self):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)

-        # move the model to the correct device
-        self.model_to_device()
+        if self.call_move_to_device_hook_in_pre_dispatch:
+            # move the model to the correct device
+            self.model_to_device()

         self.configure_ddp()

         self.barrier()

+    @property
+    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
+        """
+        Call the ``model_to_device`` function within pre_dispatch if this is set to True.
+        Useful for when plugin would like to call model_to_device at another time, or skip the call.
+        """
+        return True
+
     def post_dispatch(self):
         if "WORLD_SIZE" in os.environ:
             del os.environ["WORLD_SIZE"]
diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py
index 14e73d920af4b..a7fba770433d1 100644
--- a/tests/accelerators/test_ddp.py
+++ b/tests/accelerators/test_ddp.py
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from unittest import mock
 from unittest.mock import patch

 import pytest
 import torch

 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import DDPPlugin
 from tests.accelerators import ddp_model, DDPLauncher
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -91,7 +93,6 @@ def test_torch_distributed_backend_env_variables(tmpdir):
     _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"}
     with patch.dict(os.environ, _environ), \
             patch('torch.cuda.device_count', return_value=2):
-
         with pytest.raises(ValueError, match="Invalid backend: 'undefined'"):
             model = BoringModel()
             trainer = Trainer(
@@ -102,3 +103,30 @@ def test_torch_distributed_backend_env_variables(tmpdir):
                 logger=False,
             )
             trainer.fit(model)
+
+
+@pytest.mark.parametrize('move_to_device_pre_dispatch_enabled', [False, True])
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.model_to_device', autospec=True)
+def test_move_to_device_in_pre_dispatch(mock_model_to_device, move_to_device_pre_dispatch_enabled, tmpdir):
+    """
+    Test if ``call_move_to_device_hook_in_pre_dispatch`` is disabled we do not move to device till later
+    in training.
+    """
+
+    class TestPropertyPlugin(DDPPlugin):
+
+        @property
+        def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
+            return move_to_device_pre_dispatch_enabled
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, fast_dev_run=True, accelerator='ddp', plugins=TestPropertyPlugin(), num_processes=1
+    )
+    trainer.fit(model)
+
+    # Check if mocked device was called. Since we're on CPU, model_to_device does nothing anyway.
+    if move_to_device_pre_dispatch_enabled:
+        mock_model_to_device.assert_called()
+    else:
+        mock_model_to_device.assert_not_called()

From 3cf491154620e36d7464b0d6cd33b8ed09620848 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 4 Mar 2021 10:22:09 +0000
Subject: [PATCH 2/5] Move properties, add test

---
 .../plugins/training_type/ddp.py       | 16 ++++++------
 .../plugins/training_type/ddp_spawn.py | 15 ++++++++---
 tests/accelerators/test_ddp.py         | 25 ++++++++++++++++++-
 3 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index cdaf299442ac8..05736ac3e5c2b 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -260,14 +260,6 @@ def pre_dispatch(self):

         self.barrier()

-    @property
-    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
-        """
-        Call the ``model_to_device`` function within pre_dispatch if this is set to True.
-        Useful for when plugin would like to call model_to_device at another time, or skip the call.
-        """
-        return True
-
     def post_dispatch(self):
         if "WORLD_SIZE" in os.environ:
             del os.environ["WORLD_SIZE"]
@@ -321,3 +313,11 @@ def predict(self, *args, **kwargs):
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True
+
+    @property
+    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
+        """
+        Call the ``model_to_device`` function within pre_dispatch if this is set to True.
+        Useful for when plugin would like to call model_to_device at another time, or skip the call.
+        """
+        return True
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 9ff4bb8cd2749..f7b70eceec2bd 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -59,7 +59,7 @@ def __init__(
         self.sync_batchnorm = sync_batchnorm
         self._ddp_kwargs = kwargs
         self.dist = LightningDistributed()
-        self.num_processes = len(parallel_devices)
+        self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
         self.node_rank = 0
         self.mp_queue = None

@@ -151,8 +151,9 @@ def new_process(self, process_idx, trainer, mp_queue):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)

-        # move the model to the correct device
-        self.model_to_device()
+        if self.call_move_to_device_hook_in_pre_dispatch:
+            # move the model to the correct device
+            self.model_to_device()

         self.configure_ddp()

@@ -290,3 +291,11 @@ def predict(self, *args, **kwargs):
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True
+
+    @property
+    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
+        """
+        Call the ``model_to_device`` function within pre_dispatch if this is set to True.
+        Useful for when plugin would like to call model_to_device at another time, or skip the call.
+        """
+        return True
diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py
index a7fba770433d1..2d36bcfce40c1 100644
--- a/tests/accelerators/test_ddp.py
+++ b/tests/accelerators/test_ddp.py
@@ -19,7 +19,7 @@
 import torch

 from pytorch_lightning import Trainer
-from pytorch_lightning.plugins import DDPPlugin
+from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
 from tests.accelerators import ddp_model, DDPLauncher
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -130,3 +130,26 @@ def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
         mock_model_to_device.assert_called()
     else:
         mock_model_to_device.assert_not_called()
+
+
+@pytest.mark.parametrize('move_to_device_pre_dispatch_enabled', [False, True])
+@mock.patch('pytorch_lightning.plugins.DDPSpawnPlugin.model_to_device', autospec=True)
+def test_move_to_device_in_pre_dispatch(mock_model_to_device, move_to_device_pre_dispatch_enabled, tmpdir):
+    """
+    Test if ``call_move_to_device_hook_in_pre_dispatch`` is disabled we do not move to device till later
+    in training.
+    """
+
+    with mock.patch('pytorch_lightning.plugins.DDPSpawnPlugin.call_move_to_device_hook_in_pre_dispatch',
+                    move_to_device_pre_dispatch_enabled):
+        model = BoringModel()
+        trainer = Trainer(
+            default_root_dir=tmpdir, fast_dev_run=True, accelerator='ddp_spawn', num_processes=1
+        )
+        trainer.fit(model)
+
+        # Check if mocked device was called. Since we're on CPU, model_to_device does nothing anyway.
+        if move_to_device_pre_dispatch_enabled:
+            mock_model_to_device.assert_called()
+        else:
+            mock_model_to_device.assert_not_called()

From 1beb7afaf3b24b86c8f3188cd8f80466ab4f8a57 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 4 Mar 2021 11:10:31 +0000
Subject: [PATCH 3/5] Modify to DDP Test

---
 tests/accelerators/test_ddp.py | 41 ++++++++--------------------
 1 file changed, 9 insertions(+), 32 deletions(-)

diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py
index 2d36bcfce40c1..f079c3274e2d8 100644
--- a/tests/accelerators/test_ddp.py
+++ b/tests/accelerators/test_ddp.py
@@ -105,46 +105,23 @@ def test_torch_distributed_backend_env_variables(tmpdir):
             trainer.fit(model)


-@pytest.mark.parametrize('move_to_device_pre_dispatch_enabled', [False, True])
-@mock.patch('pytorch_lightning.plugins.DDPPlugin.model_to_device', autospec=True)
-def test_move_to_device_in_pre_dispatch(mock_model_to_device, move_to_device_pre_dispatch_enabled, tmpdir):
+@pytest.mark.parametrize('move_to_device_pre_dispatch_enabled', [True, False])
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.model_to_device')
+def test_move_to_device_in_pre_dispatch(mock_model_to_device, tmpdir, move_to_device_pre_dispatch_enabled):
     """
     Test if ``call_move_to_device_hook_in_pre_dispatch`` is disabled we do not move to device till later
     in training.
     """

-    class TestPropertyPlugin(DDPPlugin):
-
-        @property
-        def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
-            return move_to_device_pre_dispatch_enabled
-
-    model = BoringModel()
-    trainer = Trainer(
-        default_root_dir=tmpdir, fast_dev_run=True, accelerator='ddp', plugins=TestPropertyPlugin(), num_processes=1
-    )
-    trainer.fit(model)
-
-    # Check if mocked device was called. Since we're on CPU, model_to_device does nothing anyway.
-    if move_to_device_pre_dispatch_enabled:
-        mock_model_to_device.assert_called()
-    else:
-        mock_model_to_device.assert_not_called()
-
-
-@pytest.mark.parametrize('move_to_device_pre_dispatch_enabled', [False, True])
-@mock.patch('pytorch_lightning.plugins.DDPSpawnPlugin.model_to_device', autospec=True)
-def test_move_to_device_in_pre_dispatch(mock_model_to_device, move_to_device_pre_dispatch_enabled, tmpdir):
-    """
-    Test if ``call_move_to_device_hook_in_pre_dispatch`` is disabled we do not move to device till later
-    in training.
-    """
-
-    with mock.patch('pytorch_lightning.plugins.DDPSpawnPlugin.call_move_to_device_hook_in_pre_dispatch',
+    with mock.patch(f'pytorch_lightning.plugins.DDPPlugin.call_move_to_device_hook_in_pre_dispatch',
                     move_to_device_pre_dispatch_enabled):
         model = BoringModel()
         trainer = Trainer(
-            default_root_dir=tmpdir, fast_dev_run=True, accelerator='ddp_spawn', num_processes=1
+            default_root_dir=tmpdir,
+            fast_dev_run=True,
+            accelerator='ddp',
+            plugins=DDPPlugin(),
+            num_processes=1
         )
         trainer.fit(model)

From 531dfb815b8241035b4f760ecfec5533af0d7d92 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 4 Mar 2021 11:12:06 +0000
Subject: [PATCH 4/5] Format

---
 tests/accelerators/test_ddp.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/tests/accelerators/test_ddp.py b/tests/accelerators/test_ddp.py
index f079c3274e2d8..a45d0d637191b 100644
--- a/tests/accelerators/test_ddp.py
+++ b/tests/accelerators/test_ddp.py
@@ -113,15 +113,13 @@ def test_move_to_device_in_pre_dispatch(mock_model_to_device, tmpdir, move_to_de
     in training.
     """

-    with mock.patch(f'pytorch_lightning.plugins.DDPPlugin.call_move_to_device_hook_in_pre_dispatch',
-                    move_to_device_pre_dispatch_enabled):
+    with mock.patch(
+        f'pytorch_lightning.plugins.DDPPlugin.call_move_to_device_hook_in_pre_dispatch',
+        move_to_device_pre_dispatch_enabled
+    ):
         model = BoringModel()
         trainer = Trainer(
-            default_root_dir=tmpdir,
-            fast_dev_run=True,
-            accelerator='ddp',
-            plugins=DDPPlugin(),
-            num_processes=1
+            default_root_dir=tmpdir, fast_dev_run=True, accelerator='ddp', plugins=DDPPlugin(), num_processes=1
         )
         trainer.fit(model)

From c53873ddb2ca53b3837b24daeb636e4b76454d26 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Sun, 7 Mar 2021 11:24:36 +0000
Subject: [PATCH 5/5] Update pytorch_lightning/plugins/training_type/ddp_spawn.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index f7b70eceec2bd..8479124a66a72 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -59,7 +59,7 @@ def __init__(
         self.sync_batchnorm = sync_batchnorm
         self._ddp_kwargs = kwargs
         self.dist = LightningDistributed()
-        self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
+        self.num_processes = len(parallel_devices) if parallel_devices is not None else None
         self.node_rank = 0
         self.mp_queue = None
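
Usage sketch (not part of the patches above): how a downstream plugin could use the new ``call_move_to_device_hook_in_pre_dispatch`` property to defer device placement. The subclass name and the choice to move the model inside ``configure_ddp`` are illustrative assumptions, not code from this PR.

# Hypothetical downstream plugin, shown for illustration only.
from pytorch_lightning.plugins import DDPPlugin


class DeferredDeviceDDPPlugin(DDPPlugin):

    @property
    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
        # Skip the default move in ``pre_dispatch``; this plugin places the
        # model on the device itself at a later point.
        return False

    def configure_ddp(self):
        # Move the model right before wrapping it in DDP instead.
        self.model_to_device()
        super().configure_ddp()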