Refactor LightningDistributedDataParallel #5185
**`pytorch_lightning/overrides/data_parallel.py`**

```python
@@ -14,9 +14,10 @@
import itertools
import threading
import warnings
from collections.abc import Iterable, Mapping
from itertools import chain
from typing import Optional
from typing import Any, Optional

import torch
from torch import Tensor

@@ -25,6 +26,7 @@
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel._functions import Gather

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.warnings import WarningCache

@@ -150,73 +152,75 @@ def parallel_apply(self, replicas, inputs, kwargs):

class LightningDistributedDataParallel(DistributedDataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively
    """
    PREPARE_FOR_BACKWARDS = True

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def __init__(self, module: LightningModule, *args, **kwargs):
        warnings.warn(
            "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
            " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
            DeprecationWarning
        )
        super().__init__(LightningDistributedModule(module), *args, **kwargs)

    def forward(self, *inputs, **kwargs):  # pragma: no-cover
        self._sync_params()
        self.reducer_reset_hooks()
        fx_called: str = ''

        if self.device_ids:

            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                # --------------
                # LIGHTNING MOD
                # --------------
                # normal
                # output = self.module(*inputs[0], **kwargs[0])
                # lightning
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                    fx_called = 'training_step'
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                    fx_called = 'test_step'
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])
                    fx_called = 'validation_step'
            else:
                outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
                output = self.gather(outputs, self.output_device)
        else:
            # output = self.module(*inputs, **kwargs)
            # normal lightning (ddp_cpu)
            if self.module.training:
                output = self.module.training_step(*inputs, **kwargs)
            elif self.module.testing:
                output = self.module.test_step(*inputs, **kwargs)
            else:
                output = self.module.validation_step(*inputs, **kwargs)

        if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
            self.reducer_prepare_for_backwards(output)


class LightningDistributedModule(torch.nn.Module):

    def __init__(self, pl_module: LightningModule):
        """
        Wraps the user's LightningModule and redirects the forward call to the appropriate
        method, either ``training_step``, ``validation_step`` or ``test_step``.
        This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
        shown in the example.

        if output is None:
            warn_missing_output(f'{fx_called} returned None. Did you forget to return an output')

        Example:

            ddp_model = DistributedDataParallel(
                module=LightningDistributedModule(lightning_module),
                device_ids=[local_rank],
                ...
            )

        Args:
            pl_module: the model to wrap

        """
        super().__init__()
        self.module = pl_module

    def forward(self, *inputs, **kwargs):
        if self.module.training:
            output = self.module.training_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "training_step")
        elif self.module.testing:
            output = self.module.test_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "test_step")
        else:
            output = self.module.validation_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "validation_step")
        return output

    def reducer_prepare_for_backwards(self, output):
```
**Contributor (Author):** I'm not sure where this should go, it requires the reducer from ddp

**Contributor (Author):** @SeanNaren do you remember the conversation we started in #4976 about this? You had an idea there, I'm trying to understand it, maybe you can explain again? :)

**Contributor:** I have been thinking about this one maybe too much ahah, but I didn't find a better way to do it as

**Contributor:** could open a PR in PyTorch to at least move it to a function we can use
```python
        self._reducer_prepared_for_backwards = True
        if torch.is_grad_enabled():
            # We'll return the output object verbatim since it is a freeform
            # object. We need to find any tensors in this object, though,
            # because we need to figure out which parameters were used during
            # this forward pass, to ensure we short circuit reduction for any
            # unused parameters. Only if `find_unused_parameters` is set.
            if self.find_unused_parameters:
                self.reducer.prepare_for_backward(list(_find_tensors(output)))
            else:
                self.reducer.prepare_for_backward([])

    def reducer_reset_hooks(self):
        self._reducer_prepared_for_backwards = False


# In manual_optimization, we need to call reducer prepare_for_backward.
# Note: Keep track of Pytorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
    if torch.is_grad_enabled() and model.require_backward_grad_sync:
        model.require_forward_param_sync = True
        # We'll return the output object verbatim since it is a freeform
        # object. We need to find any tensors in this object, though,
        # because we need to figure out which parameters were used during
        # this forward pass, to ensure we short circuit reduction for any
        # unused parameters. Only if `find_unused_parameters` is set.
        if model.find_unused_parameters:
            model.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            model.reducer.prepare_for_backward([])
    else:
        model.require_forward_param_sync = False


def warn_if_output_is_none(output: Any, method_name: str) -> None:
    if output is None:
        warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')


def warn_missing_output(fx_called):
```
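To make the redirect concrete, here is a small self-contained sketch of the dispatch rule that `LightningDistributedModule.forward` implements. `ToyModule` and `DispatchingWrapper` are made-up stand-ins for illustration, not Lightning classes, and the `testing` flag is set by hand here because in Lightning it is managed by the Trainer.

```python
import torch
from torch import nn


class ToyModule(nn.Module):
    """Stand-in for a LightningModule: it only has the pieces the dispatch needs."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)
        self.testing = False  # in Lightning this flag is managed by the Trainer

    def training_step(self, batch):
        return {"loss": self.layer(batch).sum()}

    def validation_step(self, batch):
        return {"val_loss": self.layer(batch).sum()}

    def test_step(self, batch):
        return {"test_loss": self.layer(batch).sum()}


class DispatchingWrapper(nn.Module):
    """Simplified copy of the forward dispatch in LightningDistributedModule."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *inputs, **kwargs):
        if self.module.training:
            return self.module.training_step(*inputs, **kwargs)
        if self.module.testing:
            return self.module.test_step(*inputs, **kwargs)
        return self.module.validation_step(*inputs, **kwargs)


wrapper = DispatchingWrapper(ToyModule())
batch = torch.randn(3, 4)

wrapper.train()
print(wrapper(batch).keys())  # dict_keys(['loss'])      -> training_step

wrapper.eval()
print(wrapper(batch).keys())  # dict_keys(['val_loss'])  -> validation_step
```

In the refactor, `DistributedDataParallel` holds the wrapper as its inner module, so the same dispatch runs on every rank after DDP has done its scatter and parameter synchronization.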
**`pytorch_lightning/plugins/ddp_plugin.py`**

```python
@@ -16,11 +16,12 @@
from typing import Any, Dict, List, Union

import torch.distributed as torch_distrib
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule, prepare_for_backward
from pytorch_lightning.plugins.plugin import LightningPlugin
from pytorch_lightning.utilities import DeviceType

@@ -29,15 +30,14 @@ class DDPPlugin(LightningPlugin):
    """
    Plugin to link a custom ddp implementation to any arbitrary accelerator.

    This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
    which in turn forwards all args to `DistributedDataParallel`.
    This plugin forwards all constructor arguments to :class:`~torch.nn.parallel.DistributedDataParallel`.

    Example::

        class MyDDP(DDPPlugin):

            def configure_ddp(self, model, device_ids):
                model = MyDDPWrapper(model, device_ids)
                model = MyDDPWrapper(LightningDistributedModule(model), device_ids)
                return model

        my_ddp = MyDDP()

@@ -49,32 +49,40 @@ def __init__(self, **kwargs):

    def configure_ddp(
        self, model: LightningModule, device_ids: List[int]
    ) -> LightningDistributedDataParallel:
    ) -> DistributedDataParallel:
        """
        Pass through all customizations from constructor to `LightningDistributedDataParallel`.
        Pass through all customizations from constructor to :class:`~torch.nn.parallel.DistributedDataParallel`.
        Override to define a custom DDP implementation.

        .. note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel

        .. note:: This requires that your DDP implementation subclasses
            :class:`~torch.nn.parallel.DistributedDataParallel` and that
            the original LightningModule gets wrapped by
            :class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedModule`.

        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = LightningDistributedDataParallel(
                    model, device_ids=device_ids, **self._ddp_kwargs
                model = DistributedDataParallel(
                    LightningDistributedModule(model),
                    device_ids=device_ids,
                    **self._ddp_kwargs,
                )
                return model

        Args:
            model: the lightningModule
            model: the LightningModule
            device_ids: the list of devices available

        Returns:
            the model wrapped in LightningDistributedDataParallel
            the model wrapped in :class:`~torch.nn.parallel.DistributedDataParallel`

        """
        model = LightningDistributedDataParallel(
            model,
        # if unset, default `find_unused_parameters` `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", True
```
**Contributor:** I'm not sure why this is the default. This incurs a perf hit and is different from the DDP default.

**Contributor (Author):** It was added in this PR by you: #4382, I'm not sure if it's necessary.

**Contributor:** Hey @awaelchli. Not sure Will looked into the default for

**Contributor:** It should be false, it is only recommended to do this if necessary: https://pytorch.org/docs/stable/notes/ddp.html#internal-design

**Contributor (Author):** OK, made it default False to be in line with PyTorch DDP: #5435

**Contributor:** In #4382 I was preserving the prior behavior without digging into the full history behind the setting :/ This could be a nice speedup for distributed training jobs. @SeanNaren n00b question: is there a way to estimate the possible gains using the lightning benchmarks?
```python
        )
        model = DistributedDataParallel(
            module=LightningDistributedModule(model),
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
```
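As a usage note on the default debated above: because the plugin forwards its constructor kwargs straight to `DistributedDataParallel`, a user can already opt out per run. A minimal sketch, assuming the Trainer and plugin API of this release (the `Trainer` arguments and the import path are this era's public API, not something introduced by this PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

# Skip DDP's extra traversal of the autograd graph after every backward pass,
# which is only needed when some parameters may not be used in a forward pass.
trainer = Trainer(
    gpus=2,
    accelerator="ddp",
    plugins=[DDPPlugin(find_unused_parameters=False)],
)
```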
```python
@@ -131,7 +139,7 @@ def on_after_setup_optimizers(self, trainer):

    def get_model_from_plugin(
        self,
        model: Union[LightningDistributedDataParallel, LightningModule]
        model: Union[DistributedDataParallel, LightningModule]
    ) -> LightningModule:
        """
        Override to modify returning base :class:`LightningModule`

@@ -147,24 +155,23 @@ def get_model_from_plugin(
        Returns: Reference :class:`LightningModule` within parallel wrapper.

        """
        if isinstance(model, LightningDistributedDataParallel):
            return model.module
        if isinstance(model, DistributedDataParallel):
            model = model.module
        if isinstance(model, LightningDistributedModule):
            model = model.module
        return model

    @contextmanager
    def block_backward_sync(self, model: LightningDistributedDataParallel):
    def block_backward_sync(self, model: DistributedDataParallel):
        """
        Blocks ddp sync gradients behaviour on backwards pass.
        This is useful for skipping sync when accumulating gradients, reducing communication overhead
        Returns: context manager with sync behaviour off
        """
        yield model.no_sync()
```
**Comment on lines 164 to 171**

**Contributor:** do we need

**Contributor (Author):** the only reason I can think of why it is there is so the user can override this method in their own plugin, though there is not much customization they can do to this context manager :)

**Member:** We basically added this, since we did not want anything that is only DDP specific (i.e. any typechecks against the prior
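For reference, the plain PyTorch pattern behind `block_backward_sync` is `DistributedDataParallel.no_sync()`: backward passes inside that context accumulate gradients locally and skip the all-reduce, which is exactly what gradient accumulation wants. A runnable single-process sketch (gloo backend, toy model and data made up for illustration; with one rank the skipped sync is trivial, but the control flow is the same on multiple ranks):

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# Single-process process group so the example runs without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DistributedDataParallel(torch.nn.Linear(4, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(8, 4) for _ in range(4)]

optimizer.zero_grad()
for i, batch in enumerate(batches):
    if i < len(batches) - 1:
        # Accumulate locally: no gradient synchronization for these backwards.
        with model.no_sync():
            model(batch).sum().backward()
    else:
        # Last micro-batch: this backward triggers the all-reduce.
        model(batch).sum().backward()
optimizer.step()

dist.destroy_process_group()
```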
```python
    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
        model.reducer_prepare_for_backwards(output)

    def on_after_manual_backward(self, model: LightningDistributedDataParallel):
        model.reducer_reset_hooks()

    def on_before_manual_backward(self, model: DistributedDataParallel, output: Any):
        prepare_for_backward(model, output)

    def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
        return distributed_sampler_kwargs
```
**Comment:** Any chance we could use the trainer state instead? If the user changes the model state by accident, we won't be calling the right function. I am also thinking about people doing MC Dropout evaluation: `module.training` would be True if they don't set `training()` properly.

**Comment:** I thought about that for quite a while, but I think we should rely on the module attribute, just read from `self` rather than `self.module`. Users would probably only change it based on their LightningModule, and if we rely on `self.training` and `self.testing` (which probably has to be added here then) we should be fine in that regard, without tying it too closely to the trainer, since we try to get rid of all the trainer references everywhere.

**Comment:** I agree with both of you. The `.training` attribute is part of `nn.Module`, however `.testing` is set by the Trainer. This may not be so obvious, because there is no trace of `training` or `testing` attributes anywhere in the LightningModule. One attribute being part of `nn.Module` and one being part of the LightningModule is strange and will not be easy to debug. I'd like to keep this PR a strict refactor of the DDP class and not change the attributes yet. Shall I create an issue so we can follow up on this? Both of your suggestions seem reasonable to me. Justus's idea would require additional logic in the training loop to set the attributes on this wrapper. Thomas's idea would basically be replacing `self.module.training` with `self.module.trainer.training`, right? We could also think about adding read-only properties to the LightningModule, as we did for other attributes that need to reference the trainer, so the user can know about possible name collisions.

**Comment:** Good, I would just be careful about adding too many new attributes to the LightningModule :]
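To make the MC Dropout concern above concrete, here is a tiny made-up sketch in plain PyTorch (not Lightning code) of how routing on `module.training` picks the wrong step method when a user re-enables train mode during evaluation to keep dropout stochastic:

```python
from torch import nn


def pick_step(module: nn.Module) -> str:
    # Same routing rule as the wrapper's forward: dispatch on `module.training`.
    return "training_step" if module.training else "validation_step"


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

# Normal evaluation: eval() sets `training` to False, so validation_step is chosen.
model.eval()
print(pick_step(model))  # -> validation_step

# MC Dropout evaluation: train mode is re-enabled so dropout keeps sampling,
# and the routing now picks training_step even though this is an evaluation pass.
model.train()
print(pick_step(model))  # -> training_step
```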