Merged
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -77,7 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `callbacks` argument in `Trainer` to allow `Callback` input ([#5446](https://github.com/PyTorchLightning/pytorch-lightning/pull/5446))


- Changed the default of `find_unused_parameters` to `False` in DDP ([#5435](https://github.com/PyTorchLightning/pytorch-lightning/pull/5435))
- Changed the default of `find_unused_parameters` to `False` in DDP ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))


### Deprecated

4 changes: 0 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -104,10 +104,6 @@ def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):

# once backward has been applied, release graph
closure_loss = closure_loss.detach()

if not automatic_optimization and self.ddp_plugin is not None:
# Manually prepare for reduce as user calling backwards manually
self.ddp_plugin.on_after_manual_backward(self.trainer.model)
return closure_loss

def clip_gradients(self, optimizer, clip_val=None):
3 changes: 1 addition & 2 deletions pytorch_lightning/core/hooks.py
@@ -564,8 +564,7 @@ def transfer_batch_to_device(self, batch, device)
Note:
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
:class:`~torch.nn.parallel.DistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

See Also:
128 changes: 66 additions & 62 deletions pytorch_lightning/overrides/data_parallel.py
@@ -14,9 +14,10 @@

import itertools
import threading
import warnings
from collections.abc import Iterable, Mapping
from itertools import chain
from typing import Optional
from typing import Any, Optional

import torch
from torch import Tensor
@@ -25,6 +26,7 @@
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel._functions import Gather

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.warnings import WarningCache

@@ -150,73 +152,75 @@ def parallel_apply(self, replicas, inputs, kwargs):


class LightningDistributedDataParallel(DistributedDataParallel):
"""
Override the forward call in lightning so it goes to training and validation step respectively
"""
PREPARE_FOR_BACKWARDS = True

def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
def __init__(self, module: LightningModule, *args, **kwargs):
warnings.warn(
"The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
" From now on we recommend to directly sublcass `torch.nn.parallel.DistributedDataParallel`.",
DeprecationWarning
)
super().__init__(LightningDistributedModule(module), *args, **kwargs)

def forward(self, *inputs, **kwargs): # pragma: no-cover
self._sync_params()
self.reducer_reset_hooks()
fx_called: str = ''

if self.device_ids:

inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
# --------------
# LIGHTNING MOD
# --------------
# normal
# output = self.module(*inputs[0], **kwargs[0])
# lightning
if self.module.training:
output = self.module.training_step(*inputs[0], **kwargs[0])
fx_called = 'training_step'
elif self.module.testing:
output = self.module.test_step(*inputs[0], **kwargs[0])
fx_called = 'test_step'
else:
output = self.module.validation_step(*inputs[0], **kwargs[0])
fx_called = 'validation_step'
else:
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
# output = self.module(*inputs, **kwargs)
# normal lightning (ddp_cpu)
if self.module.training:
output = self.module.training_step(*inputs, **kwargs)
elif self.module.testing:
output = self.module.test_step(*inputs, **kwargs)
else:
output = self.module.validation_step(*inputs, **kwargs)

if not self._reducer_prepared_for_backwards and self.PREPARE_FOR_BACKWARDS:
self.reducer_prepare_for_backwards(output)
class LightningDistributedModule(torch.nn.Module):

def __init__(self, pl_module: LightningModule):
"""
Wraps the user's LightningModule and redirects the forward call to the appropriate
method, either ``training_step``, ``validation_step`` or ``test_step``.
This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
shown in the example.

if output is None:
warn_missing_output(f'{fx_called} returned None. Did you forget to return an output')
Example:

ddp_model = DistributedDataParallel(
module=LightningDistributedModule(lightning_module),
device_ids=[local_rank],
...
)

Args:
pl_module: the model to wrap

"""
super().__init__()
self.module = pl_module

def forward(self, *inputs, **kwargs):
if self.module.training:
Contributor:
Any chance we could use the trainer state instead? If the user changes the model state by accident, we won't be calling the right function.

I am also thinking about people doing MC Dropout evaluation. module.training would be True if they don't set train()/eval() properly.

Member:

I thought about that for quite a while, but I think we should rely on the module attribute, though not via self.module but via self.

Users would probably only change it on their LightningModule, and if we rely on self.training and self.testing (the latter probably has to be added here) we should be fine in that regard, without tying it too closely to the trainer, since we are trying to get rid of all the trainer references everywhere.

Contributor Author (@awaelchli, Jan 11, 2021):

I agree with both of you. The .training attribute is part of nn.Module, whereas .testing is set by the Trainer. This may not be so obvious, because there is no trace of training or testing attributes anywhere in the LightningModule.
One attribute being part of nn.Module and one being part of LightningModule is strange and will not be easy to debug.

I'd like to keep this PR a strict refactor of the DDP class and not change the attributes yet.
Shall I create an issue so we can follow up on this?

Both of your suggestions seem reasonable to me.
Justus's idea would require additional logic in the training loop to set the attributes on this wrapper.
Thomas's idea would basically mean replacing self.module.training with self.module.trainer.training, right? We could also think about adding read-only properties to the LightningModule, as we did for other attributes that need to reference the trainer, so users know about possible name collisions.

Collaborator:

Good, I would just be careful about adding too many new attributes to LightningModule :]

output = self.module.training_step(*inputs, **kwargs)
warn_if_output_is_none(output, "training_step")
elif self.module.testing:
output = self.module.test_step(*inputs, **kwargs)
warn_if_output_is_none(output, "test_step")
else:
output = self.module.validation_step(*inputs, **kwargs)
warn_if_output_is_none(output, "validation_step")
return output
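
As context for the thread above: a minimal illustration (hypothetical TinyModule, not part of this PR) of why relying on module.training can misfire. The flag simply tracks nn.Module.train()/.eval(), so a user who re-enables dropout at evaluation time (MC Dropout) would be routed into training_step by this forward:

    import torch

    class TinyModule(torch.nn.Module):
        def forward(self, x):
            return x

    m = TinyModule()
    m.eval()
    assert m.training is False  # forward above would dispatch to validation/test step
    m.train()                   # e.g. re-enabled for MC Dropout during evaluation
    assert m.training is True   # forward above would now dispatch to training_step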

def reducer_prepare_for_backwards(self, output):
Contributor Author:
I'm not sure where this should go; it requires the reducer from DDP.

Contributor Author:
@SeanNaren do you remember the conversation we started in #4976 about this? You had an idea there, I'm trying to understand it, maybe you can explain again? :)

Contributor:
I have been thinking about this one maybe too much, haha, but I didn't find a better way to do it, as backward and optimizer.step are called in training_step and the DDP reducer is called on the training_step output.

Contributor:
We could open a PR in PyTorch to at least move it to a function we can use.

self._reducer_prepared_for_backwards = True
if torch.is_grad_enabled():
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])

def reducer_reset_hooks(self):
self._reducer_prepared_for_backwards = False

# In manual optimization, we need to call the reducer's prepare_for_backward ourselves.
# Note: keep track of PyTorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
if torch.is_grad_enabled() and model.require_backward_grad_sync:
model.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
if model.find_unused_parameters:
model.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
model.reducer.prepare_for_backward([])
else:
model.require_forward_param_sync = False
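
An illustrative call pattern for this helper under manual optimization (ddp_model, batch, and the loss extraction are assumed names, not the exact trainer code path): the reducer must be prepared between the forward pass and the user-triggered backward:

    output = ddp_model(batch)                # runs LightningDistributedModule.forward
    prepare_for_backward(ddp_model, output)  # register reducer hooks for this pass
    loss = output["loss"] if isinstance(output, dict) else output
    loss.backward()                          # gradients are now reduced across ranks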


def warn_if_output_is_none(output: Any, method_name: str) -> None:
if output is None:
warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')


def warn_missing_output(fx_called):
53 changes: 30 additions & 23 deletions pytorch_lightning/plugins/ddp_plugin.py
@@ -16,11 +16,12 @@
from typing import Any, Dict, List, Union

import torch.distributed as torch_distrib
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.overrides.data_parallel import LightningDistributedModule, prepare_for_backward
from pytorch_lightning.plugins.plugin import LightningPlugin
from pytorch_lightning.utilities import DeviceType

@@ -29,15 +30,14 @@ class DDPPlugin(LightningPlugin):
"""
Plugin to link a custom ddp implementation to any arbitrary accelerator.

This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
which in turn forwards all args to `DistributedDataParallel`.
This plugin forwards all constructor arguments to :class:`~torch.nn.parallel.DistributedDataParallel`.

Example::

class MyDDP(DDPPlugin):

def configure_ddp(self, model, device_ids):
model = MyDDPWrapper(model, device_ids)
model = MyDDPWrapper(LightningDistributedModule(model), device_ids)
return model

my_ddp = MyDDP()
@@ -49,32 +49,40 @@ def __init__(self, **kwargs):

def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> LightningDistributedDataParallel:
) -> DistributedDataParallel:
"""
Pass through all customizations from constructor to `LightningDistributedDataParallel`.
Pass through all customizations from constructor to :class:`~torch.nn.parallel.DistributedDataParallel`.
Override to define a custom DDP implementation.

.. note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel

.. note:: This requires that your DDP implementation subclasses
:class:`~torch.nn.parallel.DistributedDataParallel` and that
the original LightningModule gets wrapped by
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedModule`.

The default implementation is::

def configure_ddp(self, model, device_ids):
model = LightningDistributedDataParallel(
model, device_ids=device_ids, **self._ddp_kwargs
model = DistributedDataParallel(
LightningDistributedModule(model),
device_ids=device_ids,
**self._ddp_kwargs,
)
return model

Args:
model: the lightningModule
model: the LightningModule
device_ids: the list of devices available

Returns:
the model wrapped in LightningDistributedDataParallel
the model wrapped in :class:`~torch.nn.parallel.DistributedDataParallel`

"""
model = LightningDistributedDataParallel(
model,
# if unset, default `find_unused_parameters` to `True`
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
"find_unused_parameters", True
Contributor:
I'm not sure why this is the default. It incurs a perf hit and is different from the DDP default.

Contributor Author (@awaelchli, Dec 25, 2020):
It was added in this PR by you: #4382, I'm not sure if it's necessary.

Contributor:
Hey @awaelchli. Not sure Will looked into the default for find_unused_parameters. Let's stick to the PyTorch default, which is False, right?

Contributor:
It should be False; it is only recommended to enable this if necessary: https://pytorch.org/docs/stable/notes/ddp.html#internal-design

Contributor Author (@awaelchli, Jan 8, 2021):
OK, made it default to False to be in line with PyTorch DDP: #5435

Contributor (@ananthsub, Jan 9, 2021):
In #4382 I was preserving the prior behavior without digging into the full history behind the setting :/

This could be a nice speedup for distributed training jobs. @SeanNaren, noob question: is there a way to estimate the possible gains using the Lightning benchmarks?

)
model = DistributedDataParallel(
module=LightningDistributedModule(model),
device_ids=device_ids,
**self._ddp_kwargs,
)
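
Following up on the thread above, a short sketch of how a user can set the flag explicitly through the plugin kwargs, which are forwarded verbatim to DistributedDataParallel, regardless of the chosen default (the Trainer arguments shown are illustrative assumptions):

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

    my_ddp = DDPPlugin(find_unused_parameters=True)
    trainer = Trainer(accelerator="ddp", gpus=2, plugins=[my_ddp])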
@@ -131,7 +139,7 @@ def on_after_setup_optimizers(self, trainer):

def get_model_from_plugin(
self,
model: Union[LightningDistributedDataParallel, LightningModule]
model: Union[DistributedDataParallel, LightningModule]
) -> LightningModule:
"""
Override to modify returning base :class:`LightningModule`
Expand All @@ -147,24 +155,23 @@ def get_model_from_plugin(
Returns: Reference :class:`LightningModule` within parallel wrapper.

"""
if isinstance(model, LightningDistributedDataParallel):
return model.module
if isinstance(model, DistributedDataParallel):
model = model.module
if isinstance(model, LightningDistributedModule):
model = model.module
return model
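
For clarity, a sketch of the nesting this unwrapping handles (names assumed; a default process group is presumed to be initialized): after configure_ddp the LightningModule sits two .module hops deep:

    ddp_model = DistributedDataParallel(
        LightningDistributedModule(pl_module),
        device_ids=[local_rank],
    )
    assert ddp_model.module.module is pl_module  # both isinstance branches above apply in turn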

@contextmanager
def block_backward_sync(self, model: LightningDistributedDataParallel):
def block_backward_sync(self, model: DistributedDataParallel):
"""
Blocks ddp sync gradients behaviour on backwards pass.
This is useful for skipping sync when accumulating gradients, reducing communication overhead
Returns: context manager with sync behaviour off
"""
yield model.no_sync()
Comment on lines 164 to 171:
Contributor:
Do we need block_backward_sync still? Can we directly call model.no_sync()?

Contributor Author:
The only reason I can think of why it is there is so the user can override this method in their own plugin, though there is not much customization they can do to this context manager :)

Member:
We basically added this since we did not want anything that is DDP-specific (i.e. any type checks against the prior LightningDistributedDataParallel) within the trainer/training loop, as this one should be backend agnostic.
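
For reference in this discussion, a minimal sketch of the no_sync behaviour the helper exposes (ddp_model and the batches are assumed names): gradient all-reduce is skipped inside the context and only fires on the next backward outside it:

    with ddp_model.no_sync():        # gradients accumulate locally, no communication
        loss = ddp_model(batch_a).sum()
        loss.backward()
    loss = ddp_model(batch_b).sum()
    loss.backward()                  # this backward triggers the all-reduce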


def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
model.reducer_prepare_for_backwards(output)

def on_after_manual_backward(self, model: LightningDistributedDataParallel):
model.reducer_reset_hooks()
def on_before_manual_backward(self, model: DistributedDataParallel, output: Any):
prepare_for_backward(model, output)

def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
return distributed_sampler_kwargs
9 changes: 4 additions & 5 deletions pytorch_lightning/plugins/ddp_sequential_plugin.py
@@ -21,7 +21,6 @@

from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -137,7 +136,7 @@ def init_ddp_connection(
self._infer_model_balance(trainer)
self._assert_valid_model_balance(trainer)

def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
def on_before_manual_backward(self, model: DistributedDataParallel, output: Any):
pass

def _infer_model_balance(self, trainer):
@@ -267,10 +266,10 @@ def _check_arguments(self, trainer):
def configure_ddp(
self,
model: LightningModule, device_ids: List[int]) -> DistributedDataParallel:
ddp_plugin = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
model = RPCPlugin(process_group=mpu.get_data_parallel_group()).configure_ddp(model, device_ids)
# Plugin handle backwards across processes. Currently not supported for DDP + pipe parallel
ddp_plugin.PREPARE_FOR_BACKWARDS = False
return ddp_plugin
model.require_backward_grad_sync = False
return model
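
A short sketch of why this is sufficient (relying on prepare_for_backward from overrides/data_parallel.py shown above; model and output are assumed names): with the sync flag disabled, the reducer preparation becomes a no-op, which is what pipe parallelism needs here:

    model.require_backward_grad_sync = False
    prepare_for_backward(model, output)  # takes the else branch: no reducer call,
                                         # require_forward_param_sync is set to False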

@rank_zero_only
def rpc_save_model(
5 changes: 1 addition & 4 deletions pytorch_lightning/plugins/sharded_plugin.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union, Any
from typing import Any, List, Optional, Union

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
@@ -97,6 +97,3 @@ def required_plugins(self, amp_backend: AMPType, trainer) -> list:

def on_before_manual_backward(self, model: 'LightningShardedDataParallel', output: Any):
pass

def on_after_manual_backward(self, model: 'LightningShardedDataParallel'):
pass
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/properties.py
@@ -188,7 +188,7 @@ def progress_bar_callback(self):
@property
def progress_bar_dict(self) -> dict:
""" Read-only for progress bar metrics. """
ref_model = self.model if not self.data_parallel else self.model.module
ref_model = self.get_model()
ref_model = cast(LightningModule, ref_model)
return dict(**ref_model.get_progress_bar_dict(), **self.logger_connector.progress_bar_metrics)

4 changes: 1 addition & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -137,9 +137,7 @@ def setup_training(self, model: LightningModule):
# --------------------------
# Setup??
# --------------------------
ref_model = model
if self.trainer.data_parallel:
ref_model = model.module
ref_model = self.trainer.get_model()

# set the ranks and devices
self.trainer.accelerator_backend.dist.rank = self.trainer.global_rank