diff --git a/CHANGELOG.md b/CHANGELOG.md index b33a492a3ed04..35ff8e275ed91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,7 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `callbacks` argument in `Trainer` to allow `Callback` input ([#5446](https://github.com/PyTorchLightning/pytorch-lightning/pull/5446)) -- Changed the default of `find_unused_parameters` to `False` in DDP ([#5435](https://github.com/PyTorchLightning/pytorch-lightning/pull/5435)) +- Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185)) + ### Deprecated diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 9b56119a04c3e..1b3ae6f23058a 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -104,10 +104,6 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs): # once backward has been applied, release graph closure_loss = closure_loss.detach() - - if not automatic_optimization and self.ddp_plugin is not None: - # Manually prepare for reduce as user calling backwards manually - self.ddp_plugin.on_after_manual_backward(self.trainer.model) return closure_loss def clip_gradients(self, optimizer, clip_val=None): diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 0b97d328904ac..f27c18513831f 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -564,8 +564,7 @@ def transfer_batch_to_device(self, batch, device) Note: This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support for your custom batch objects, you need to define your custom - :class:`~torch.nn.parallel.DistributedDataParallel` or - :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and + :class:`~torch.nn.parallel.DistributedDataParallel` and override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`. See Also: diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index f6f045134f2f9..69676cf77e079 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -14,9 +14,10 @@ import itertools import threading +import warnings from collections.abc import Iterable, Mapping from itertools import chain -from typing import Optional +from typing import Any, Optional import torch from torch import Tensor @@ -25,6 +26,7 @@ from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel._functions import Gather +from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities.warnings import WarningCache @@ -150,73 +152,75 @@ def parallel_apply(self, replicas, inputs, kwargs): class LightningDistributedDataParallel(DistributedDataParallel): - """ - Override the forward call in lightning so it goes to training and validation step respectively - """ - PREPARE_FOR_BACKWARDS = True - def parallel_apply(self, replicas, inputs, kwargs): - return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) + def __init__(self, module: LightningModule, *args, **kwargs): + warnings.warn( + "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4." 
+            " From now on we recommend directly subclassing `torch.nn.parallel.DistributedDataParallel`.",
+            DeprecationWarning
+        )
+        super().__init__(LightningDistributedModule(module), *args, **kwargs)
 
-    def forward(self, *inputs, **kwargs):  # pragma: no-cover
-        self._sync_params()
-        self.reducer_reset_hooks()
-        fx_called: str = ''
-
-        if self.device_ids:
-
-            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
-            if len(self.device_ids) == 1:
-                # --------------
-                # LIGHTNING MOD
-                # --------------
-                # normal
-                # output = self.module(*inputs[0], **kwargs[0])
-                # lightning
-                if self.module.training:
-                    output = self.module.training_step(*inputs[0], **kwargs[0])
-                    fx_called = 'training_step'
-                elif self.module.testing:
-                    output = self.module.test_step(*inputs[0], **kwargs[0])
-                    fx_called = 'test_step'
-                else:
-                    output = self.module.validation_step(*inputs[0], **kwargs[0])
-                    fx_called = 'validation_step'
-            else:
-                outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
-                output = self.gather(outputs, self.output_device)
-        else:
-            # output = self.module(*inputs, **kwargs)
-            # normal lightning (ddp_cpu)
-            if self.module.training:
-                output = self.module.training_step(*inputs, **kwargs)
-            elif self.module.testing:
-                output = self.module.test_step(*inputs, **kwargs)
-            else:
-                output = self.module.validation_step(*inputs, **kwargs)
-
-        if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
-            self.reducer_prepare_for_backwards(output)
 
+class LightningDistributedModule(torch.nn.Module):
+
+    def __init__(self, pl_module: LightningModule):
+        """
+        Wraps the user's LightningModule and redirects the forward call to the appropriate
+        method, either ``training_step``, ``validation_step`` or ``test_step``.
+        This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
+        shown in the example.
 
-        if output is None:
-            warn_missing_output(f'{fx_called} returned None. Did you forget to return an output')
+        Example:
+
+            ddp_model = DistributedDataParallel(
+                module=LightningDistributedModule(lightning_module),
+                device_ids=[local_rank],
+                ...
+            )
+
+        Args:
+            pl_module: the model to wrap
+
+        """
+        super().__init__()
+        self.module = pl_module
+
+    def forward(self, *inputs, **kwargs):
+        if self.module.training:
+            output = self.module.training_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "training_step")
+        elif self.module.testing:
+            output = self.module.test_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "test_step")
+        else:
+            output = self.module.validation_step(*inputs, **kwargs)
+            warn_if_output_is_none(output, "validation_step")
         return output
 
-    def reducer_prepare_for_backwards(self, output):
-        self._reducer_prepared_for_backwards = True
-        if torch.is_grad_enabled():
-            # We'll return the output object verbatim since it is a freeform
-            # object. We need to find any tensors in this object, though,
-            # because we need to figure out which parameters were used during
-            # this forward pass, to ensure we short circuit reduction for any
-            # unused parameters. Only if `find_unused_parameters` is set.
-            if self.find_unused_parameters:
-                self.reducer.prepare_for_backward(list(_find_tensors(output)))
-            else:
-                self.reducer.prepare_for_backward([])
-
-    def reducer_reset_hooks(self):
-        self._reducer_prepared_for_backwards = False
+
+# In manual_optimization, we need to call the reducer's prepare_for_backward.
+# Note: keep in sync with PyTorch DDP; update if the upstream implementation changes
+# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
+def prepare_for_backward(model: DistributedDataParallel, output: Any):
+    if torch.is_grad_enabled() and model.require_backward_grad_sync:
+        model.require_forward_param_sync = True
+        # We'll return the output object verbatim since it is a freeform
+        # object. We need to find any tensors in this object, though,
+        # because we need to figure out which parameters were used during
+        # this forward pass, to ensure we short circuit reduction for any
+        # unused parameters. Only if `find_unused_parameters` is set.
+        if model.find_unused_parameters:
+            model.reducer.prepare_for_backward(list(_find_tensors(output)))
+        else:
+            model.reducer.prepare_for_backward([])
+    else:
+        model.require_forward_param_sync = False
+
+
+def warn_if_output_is_none(output: Any, method_name: str) -> None:
+    if output is None:
+        warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
 
 
 def warn_missing_output(fx_called):
diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index ced9958a247b2..f0da9e5ff1a2d 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -16,11 +16,12 @@
 from typing import Any, Dict, List, Union
 
 import torch.distributed as torch_distrib
+from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
+from pytorch_lightning.overrides.data_parallel import LightningDistributedModule, prepare_for_backward
 from pytorch_lightning.plugins.plugin import LightningPlugin
 from pytorch_lightning.utilities import DeviceType
 
@@ -29,15 +30,14 @@ class DDPPlugin(LightningPlugin):
     """
     Plugin to link a custom ddp implementation to any arbitrary accelerator.
 
-    This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
-    which in turn forwards all args to `DistributedDataParallel`.
+    This plugin forwards all constructor arguments to :class:`~torch.nn.parallel.DistributedDataParallel`.
 
     Example::
 
         class MyDDP(DDPPlugin):
 
             def configure_ddp(self, model, device_ids):
-                model = MyDDPWrapper(model, device_ids)
+                model = MyDDPWrapper(LightningDistributedModule(model), device_ids)
                 return model
 
         my_ddp = MyDDP()
@@ -49,32 +49,40 @@ def __init__(self, **kwargs):
     def configure_ddp(
         self, model: LightningModule, device_ids: List[int]
-    ) -> LightningDistributedDataParallel:
+    ) -> DistributedDataParallel:
         """
-        Pass through all customizations from constructor to `LightningDistributedDataParallel`.
+        Pass through all customizations from constructor to :class:`~torch.nn.parallel.DistributedDataParallel`.
         Override to define a custom DDP implementation.
 
-        .. note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel
-
+        .. note:: This requires that your DDP implementation subclasses
+            :class:`~torch.nn.parallel.DistributedDataParallel` and that
+            the original LightningModule gets wrapped by
+            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedModule`.
         The default implementation is::
 
             def configure_ddp(self, model, device_ids):
-                model = LightningDistributedDataParallel(
-                    model, device_ids=device_ids, **self._ddp_kwargs
+                model = DistributedDataParallel(
+                    LightningDistributedModule(model),
+                    device_ids=device_ids,
+                    **self._ddp_kwargs,
                 )
                 return model
 
         Args:
-            model: the lightningModule
+            model: the LightningModule
             device_ids: the list of devices available
 
         Returns:
-            the model wrapped in LightningDistributedDataParallel
+            the model wrapped in :class:`~torch.nn.parallel.DistributedDataParallel`
 
         """
-        model = LightningDistributedDataParallel(
-            model,
+        # if unset, default `find_unused_parameters` to `True`
+        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
+            "find_unused_parameters", True
+        )
+        model = DistributedDataParallel(
+            module=LightningDistributedModule(model),
             device_ids=device_ids,
             **self._ddp_kwargs,
         )
@@ -131,7 +139,7 @@ def on_after_setup_optimizers(self, trainer):
     def get_model_from_plugin(
         self,
-        model: Union[LightningDistributedDataParallel, LightningModule]
+        model: Union[DistributedDataParallel, LightningModule]
     ) -> LightningModule:
         """
         Override to modify returning base :class:`LightningModule`
@@ -147,12 +155,14 @@ def get_model_from_plugin(
         Returns:
             Reference :class:`LightningModule` within parallel wrapper.
         """
-        if isinstance(model, LightningDistributedDataParallel):
-            return model.module
+        if isinstance(model, DistributedDataParallel):
+            model = model.module
+        if isinstance(model, LightningDistributedModule):
+            model = model.module
         return model
 
     @contextmanager
-    def block_backward_sync(self, model: LightningDistributedDataParallel):
+    def block_backward_sync(self, model: DistributedDataParallel):
         """
         Blocks ddp sync gradients behaviour on backwards pass.
         This is useful for skipping sync when accumulating gradients, reducing communication overhead
@@ -160,11 +170,8 @@ def block_backward_sync(self, model: LightningDistributedDataParallel):
         """
         yield model.no_sync()
 
-    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
-        model.reducer_prepare_for_backwards(output)
-
-    def on_after_manual_backward(self, model: LightningDistributedDataParallel):
-        model.reducer_reset_hooks()
+    def on_before_manual_backward(self, model: DistributedDataParallel, output: Any):
+        prepare_for_backward(model, output)
 
     def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
         return distributed_sampler_kwargs
diff --git a/pytorch_lightning/plugins/ddp_sequential_plugin.py b/pytorch_lightning/plugins/ddp_sequential_plugin.py
index 4898e371edb30..82250d1ed9fdd 100644
--- a/pytorch_lightning/plugins/ddp_sequential_plugin.py
+++ b/pytorch_lightning/plugins/ddp_sequential_plugin.py
@@ -21,7 +21,6 @@
 from pytorch_lightning import LightningModule
 from pytorch_lightning import _logger as log
-from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -137,7 +136,7 @@ def init_ddp_connection(
         self._infer_model_balance(trainer)
         self._assert_valid_model_balance(trainer)
 
-    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
+    def on_before_manual_backward(self, model: DistributedDataParallel, output: Any):
         pass
 
     def _infer_model_balance(self, trainer):
@@ -267,10 +266,10 @@ def _check_arguments(self, trainer):
def configure_ddp( self, model: LightningModule, device_ids: List[int]) -> DistributedDataParallel: - ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids) + model = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids) # Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel - ddp_plugin.PREPARE_FOR_BACKWARDS = False - return ddp_plugin + model.require_backward_grad_sync = False + return model @rank_zero_only def rpc_save_model( diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 83395d4826a3a..510a44ad1bddf 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union, Any +from typing import Any, List, Optional, Union from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.optimizer import is_lightning_optimizer @@ -97,6 +97,3 @@ def required_plugins(self, amp_backend: AMPType, trainer) -> list: def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any): pass - - def on_after_manual_backward(self, model: 'LightningShardedDataParallel'): - pass diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py index 786e775668ca2..ceff14897dc39 100644 --- a/pytorch_lightning/trainer/properties.py +++ b/pytorch_lightning/trainer/properties.py @@ -188,7 +188,7 @@ def progress_bar_callback(self): @property def progress_bar_dict(self) -> dict: """ Read-only for progress bar metrics. """ - ref_model = self.model if not self.data_parallel else self.model.module + ref_model = self.get_model() ref_model = cast(LightningModule, ref_model) return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 714a4592d984c..6b49dc63f52b4 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -137,9 +137,7 @@ def setup_training(self, model: LightningModule): # -------------------------- # Setup?? # -------------------------- - ref_model = model - if self.trainer.data_parallel: - ref_model = model.module + ref_model = self.trainer.get_model() # set the ranks and devices self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 045732f97f9eb..a71f1a1c6cff6 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -12,10 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test deprecated functionality which will be removed in vX.Y.Z""" +import sys + import pytest import torch from pytorch_lightning import Trainer +from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel +from pytorch_lightning.plugins.ddp_plugin import DDPPlugin +from tests.base import BoringModel from tests.deprecated_api import _soft_unimport_module @@ -109,3 +114,32 @@ def test_v1_4_0_deprecated_metrics(): with pytest.deprecated_call(match='will be removed in v1.4'): iou(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) + + +class CustomDDPPlugin(DDPPlugin): + + def configure_ddp(self, model, device_ids): + # old, deprecated implementation + with pytest.deprecated_call( + match='`LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4.' + ): + model = LightningDistributedDataParallel( + module=model, + device_ids=device_ids, + **self._ddp_kwargs, + ) + return model + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows") +def test_v1_4_0_deprecated_lightning_distributed_data_parallel(tmpdir): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + gpus=2, + accelerator="ddp_spawn", + plugins=[CustomDDPPlugin()] + ) + trainer.fit(model) diff --git a/tests/overrides/__init__.py b/tests/overrides/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py new file mode 100644 index 0000000000000..8c8f1649e73c7 --- /dev/null +++ b/tests/overrides/test_data_parallel.py @@ -0,0 +1,55 @@ +from unittest.mock import MagicMock + +import pytest +import torch + +from pytorch_lightning.overrides.data_parallel import LightningDistributedModule + + +def test_lightning_distributed_module_methods(): + """ Test that the LightningDistributedModule redirects .forward() to the LightningModule methods. """ + pl_module = MagicMock() + dist_module = LightningDistributedModule(pl_module) + + batch = torch.rand(5) + batch_idx = 3 + + pl_module.training = True + pl_module.testing = False + dist_module(batch, batch_idx) + pl_module.training_step.assert_called_with(batch, batch_idx) + + pl_module.training = False + pl_module.testing = True + dist_module(batch, batch_idx) + pl_module.test_step.assert_called_with(batch, batch_idx) + + pl_module.training = False + pl_module.testing = False + dist_module(batch, batch_idx) + pl_module.validation_step.assert_called_with(batch, batch_idx) + + +def test_lightning_distributed_module_warn_none_output(): + """ Test that the LightningDistributedModule warns about forgotten return statement. """ + pl_module = MagicMock() + dist_module = LightningDistributedModule(pl_module) + + pl_module.training_step.return_value = None + pl_module.validation_step.return_value = None + pl_module.test_step.return_value = None + + with pytest.warns(UserWarning, match="Your training_step returned None"): + pl_module.training = True + pl_module.testing = False + dist_module() + + with pytest.warns(UserWarning, match="Your test_step returned None"): + pl_module.training = False + pl_module.testing = True + dist_module() + + with pytest.warns(UserWarning, match="Your validation_step returned None"): + pl_module.training = False + pl_module.testing = False + dist_module()