From 6f77ec20ed639260c52d5287a8d4f212d95bc759 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 26 Oct 2020 14:08:51 -0700
Subject: [PATCH 1/7] Update ddp_plugin.py

---
 pytorch_lightning/plugins/ddp_plugin.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 27deeeddfdb45..fad36454f5916 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -1,12 +1,16 @@
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
-from pytorch_lightning.core.lightning import LightningModule
 from typing import List
 
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
+
 
 class DDPPlugin(object):
     """
     Plugin to link a custom ddp implementation to any arbitrary accelerator.
 
+    This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
+    which in turn forwards all args to `DistributedDataParallel`.
+
     Example::
 
         class MyDDP(DDPPlugin):
@@ -17,11 +21,16 @@ def configure_ddp(self, model, device_ids):
 
         my_ddp = MyDDP()
         trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
-
     """
 
-    def configure_ddp(self, model: LightningModule, device_ids: List[int]) -> LightningDistributedDataParallel:
+    def __init__(self, **kwargs):
+        self._kwargs: Dict[str, Any] = kwargs
+
+    def configure_ddp(
+        self, model: LightningModule, device_ids: List[int]
+    ) -> LightningDistributedDataParallel:
         """
+        Pass through all customizations from constructor to `LightningDistributedDataParallel`.
         Override to define a custom DDP implementation.
 
         .. note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel
@@ -43,5 +52,7 @@ def configure_ddp(self, model, device_ids):
             the model wrapped in LightningDistributedDataParallel
 
         """
-        model = LightningDistributedDataParallel(model, device_ids=device_ids, find_unused_parameters=True)
+        model = LightningDistributedDataParallel(
+            model, device_ids=device_ids, find_unused_parameters=True, **self._kwargs
+        )
         return model

From 3c3b1155bfd3830b252bd9f25b821b12413087c6 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 26 Oct 2020 15:12:17 -0700
Subject: [PATCH 2/7] Update ddp_plugin.py

---
 pytorch_lightning/plugins/ddp_plugin.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index fad36454f5916..15aeaa6e8c1de 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -52,7 +52,13 @@ def configure_ddp(self, model, device_ids):
             the model wrapped in LightningDistributedDataParallel
 
         """
+        # if unset, default `find_unused_parameters` `True`
+        self._kwargs["find_unused_parameters"] = self._kwargs.get(
+            "find_unused_parameters", True
+        )
         model = LightningDistributedDataParallel(
-            model, device_ids=device_ids, find_unused_parameters=True, **self._kwargs
+            model,
+            device_ids=device_ids,
+            **self._kwargs,
         )
         return model

From e8d855dac20c2b70757d2051ec2a73f21bcac6ab Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 26 Oct 2020 15:16:25 -0700
Subject: [PATCH 3/7] Update ddp_plugin.py

---
 pytorch_lightning/plugins/ddp_plugin.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 15aeaa6e8c1de..654b99bff358b 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -24,7 +24,7 @@ def configure_ddp(self, model, device_ids):
     """
 
     def __init__(self, **kwargs):
-        self._kwargs: Dict[str, Any] = kwargs
+        self._ddp_args: Dict[str, Any] = kwargs
 
     def configure_ddp(
         self, model: LightningModule, device_ids: List[int]
@@ -53,12 +53,12 @@ def configure_ddp(self, model, device_ids):
 
         """
         # if unset, default `find_unused_parameters` `True`
-        self._kwargs["find_unused_parameters"] = self._kwargs.get(
+        self._ddp_args["find_unused_parameters"] = self._ddp_args.get(
             "find_unused_parameters", True
         )
         model = LightningDistributedDataParallel(
             model,
             device_ids=device_ids,
-            **self._kwargs,
+            **self._ddp_args,
         )
         return model

From 1a86f82c4f95e090caa086b7914a99b6c5a50a31 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 26 Oct 2020 15:19:03 -0700
Subject: [PATCH 4/7] Update test_ddp_plugin.py

---
 tests/plugins/test_ddp_plugin.py | 111 ++++++++++++++++++++++---------
 1 file changed, 81 insertions(+), 30 deletions(-)

diff --git a/tests/plugins/test_ddp_plugin.py b/tests/plugins/test_ddp_plugin.py
index b190f34395522..69cd0e3beb7b4 100644
--- a/tests/plugins/test_ddp_plugin.py
+++ b/tests/plugins/test_ddp_plugin.py
@@ -1,25 +1,30 @@
-from pytorch_lightning.callbacks import Callback
-from tests.base.boring_model import BoringModel
-from pytorch_lightning import accelerators, Trainer
-from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
-import pytest
 import os
 from unittest import mock
 
+import pytest
+from pytorch_lightning import Trainer, accelerators
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
+from tests.base.boring_model import BoringModel
 
-@mock.patch.dict(os.environ, {
-    "CUDA_VISIBLE_DEVICES": "0,1",
-    "SLURM_NTASKS": "2",
-    "SLURM_JOB_NAME": "SOME_NAME",
-    "SLURM_NODEID": "0",
-    "LOCAL_RANK": "0",
-    "SLURM_LOCALID": "0"
-})
-@mock.patch('torch.cuda.device_count', return_value=2)
-@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
-                         [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
-def test_ddp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
+@mock.patch.dict(
+    os.environ,
+    {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_LOCALID": "0",
+    },
+)
+@mock.patch("torch.cuda.device_count", return_value=2)
+@pytest.mark.parametrize(
+    ["ddp_backend", "gpus", "num_processes"],
+    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
+)
+def test_ddp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
     class CB(Callback):
         def on_fit_start(self, trainer, pl_module):
             assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPPlugin)
@@ -31,24 +36,29 @@ def on_fit_start(self, trainer, pl_module):
         gpus=gpus,
         num_processes=num_processes,
         distributed_backend=ddp_backend,
-        callbacks=[CB()]
+        callbacks=[CB()],
     )
 
     with pytest.raises(SystemExit):
         trainer.fit(model)
 
 
-@mock.patch.dict(os.environ, {
-    "CUDA_VISIBLE_DEVICES": "0,1",
-    "SLURM_NTASKS": "2",
-    "SLURM_JOB_NAME": "SOME_NAME",
-    "SLURM_NODEID": "0",
-    "LOCAL_RANK": "0",
-    "SLURM_LOCALID": "0"
-})
-@mock.patch('torch.cuda.device_count', return_value=2)
-@pytest.mark.parametrize(['ddp_backend', 'gpus', 'num_processes'],
-                         [('ddp_cpu', None, None), ('ddp', 2, 0), ('ddp2', 2, 0), ('ddp_spawn', 2, 0)])
+@mock.patch.dict(
+    os.environ,
+    {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_LOCALID": "0",
+    },
+)
+@mock.patch("torch.cuda.device_count", return_value=2)
+@pytest.mark.parametrize(
+    ["ddp_backend", "gpus", "num_processes"],
+    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
+)
 def test_ddp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
     class MyDDP(DDPPlugin):
         pass
@@ -65,7 +75,48 @@ def on_fit_start(self, trainer, pl_module):
         num_processes=num_processes,
         distributed_backend=ddp_backend,
         plugins=[MyDDP()],
-        callbacks=[CB()]
+        callbacks=[CB()],
+    )
+
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
+@mock.patch.dict(
+    os.environ,
+    {
+        "CUDA_VISIBLE_DEVICES": "0,1",
+        "SLURM_NTASKS": "2",
+        "SLURM_JOB_NAME": "SOME_NAME",
+        "SLURM_NODEID": "0",
+        "LOCAL_RANK": "0",
+        "SLURM_LOCALID": "0",
+    },
+)
+@mock.patch("torch.cuda.device_count", return_value=2)
+@pytest.mark.parametrize(
+    ["ddp_backend", "gpus", "num_processes"],
+    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
+)
+def test_ddp_choice_custom_ddp_cpu_custom_args(
+    tmpdir, ddp_backend, gpus, num_processes
+):
+    class MyDDP(DDPPlugin):
+        pass
+
+    class CB(Callback):
+        def on_fit_start(self, trainer, pl_module):
+            assert isinstance(trainer.accelerator_backend.ddp_plugin, MyDDP)
+            raise SystemExit()
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        gpus=gpus,
+        num_processes=num_processes,
+        distributed_backend=ddp_backend,
+        plugins=[MyDDP(broadcast_buffers=False, find_unused_parameters=True)],
+        callbacks=[CB()],
     )
 
     with pytest.raises(SystemExit):
         trainer.fit(model)

From 429795aade46128703498bb6c1039a2033e24b08 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 26 Oct 2020 19:49:44 -0700
Subject: [PATCH 5/7] Update pytorch_lightning/plugins/ddp_plugin.py

---
 pytorch_lightning/plugins/ddp_plugin.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 654b99bff358b..f96f4b28e6eda 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -24,7 +24,7 @@ def configure_ddp(self, model, device_ids):
     """
 
     def __init__(self, **kwargs):
-        self._ddp_args: Dict[str, Any] = kwargs
+        self.ddp_kwargs: Dict[str, Any] = kwargs
 
     def configure_ddp(
         self, model: LightningModule, device_ids: List[int]

From 1e435fad690c6db8a5234ddfdd16456d39b189bb Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 26 Oct 2020 19:49:51 -0700
Subject: [PATCH 6/7] Update pytorch_lightning/plugins/ddp_plugin.py

---
 pytorch_lightning/plugins/ddp_plugin.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index f96f4b28e6eda..0f5fd2eeab364 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -53,12 +53,12 @@ def configure_ddp(self, model, device_ids):
 
         """
         # if unset, default `find_unused_parameters` `True`
-        self._ddp_args["find_unused_parameters"] = self._ddp_args.get(
+        self.ddp_kwargs["find_unused_parameters"] = self.ddp_kwargs.get(
             "find_unused_parameters", True
         )
         model = LightningDistributedDataParallel(
             model,
             device_ids=device_ids,
-            **self._ddp_args,
+            **self.ddp_kwargs,
         )
         return model

From f759e14dc15f8203ea3ab992a5b803a0bbcfafe4 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 27 Oct 2020 10:46:27 +0000
Subject: [PATCH 7/7] Fixed imports, make ddp_kwargs protected

---
 pytorch_lightning/plugins/ddp_plugin.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 0f5fd2eeab364..4c4fdc8f0d368 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Dict, Any
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
@@ -24,7 +24,7 @@ def configure_ddp(self, model, device_ids):
     """
 
    def __init__(self, **kwargs):
-        self.ddp_kwargs: Dict[str, Any] = kwargs
+        self._ddp_kwargs: Dict[str, Any] = kwargs
 
     def configure_ddp(
         self, model: LightningModule, device_ids: List[int]
@@ -53,12 +53,12 @@ def configure_ddp(self, model, device_ids):
 
         """
         # if unset, default `find_unused_parameters` `True`
-        self.ddp_kwargs["find_unused_parameters"] = self.ddp_kwargs.get(
+        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", True
         )
         model = LightningDistributedDataParallel(
             model,
             device_ids=device_ids,
-            **self.ddp_kwargs,
+            **self._ddp_kwargs,
         )
         return model
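
Example (usage sketch): the snippet below pieces together what this series enables, based on the
docstring added in PATCH 1/7 and the test added in PATCH 4/7. The `gpus` and `distributed_backend`
values are illustrative placeholders, not taken from the patches::

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

    # Constructor kwargs are stored on the plugin (as `_ddp_kwargs` after PATCH 7/7)
    # and forwarded to LightningDistributedDataParallel, which passes them through to
    # torch's DistributedDataParallel. `find_unused_parameters` defaults to True
    # when it is not supplied.
    my_ddp = DDPPlugin(broadcast_buffers=False, find_unused_parameters=True)

    trainer = Trainer(
        gpus=2,                        # illustrative device count
        distributed_backend="ddp",     # illustrative backend choice
        plugins=[my_ddp],
    )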