
Commit

Update setup logic in training type plugins (data-parallel) [3 / n] (#10010)



Co-authored-by: thomas chaton <thomas@grid.ai>
awaelchli and tchaton authored Oct 19, 2021
1 parent 854bdc0 commit 4aaca17
Showing 2 changed files with 7 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -203,6 +203,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
 * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
 * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
+* Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))

### Changed

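For context, the hooks referenced above come from `TrainingTypePlugin` (#9994), and this commit overrides one of them for data-parallel. A minimal sketch of the base hook's likely shape (the body below is an assumption for illustration, not part of this diff):

import torch
from torch.nn import Module


class TrainingTypePlugin:
    # Illustrative stub; the real class lives in pytorch_lightning.plugins.

    def _setup_model(self, model: Module) -> Module:
        # Assumed default: return the model unchanged; subclasses such as
        # DataParallelPlugin override this to wrap the model for their strategy.
        return model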
8 changes: 6 additions & 2 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -14,7 +14,7 @@
 from typing import List, Optional
 
 import torch
-from torch.nn import DataParallel
+from torch.nn import DataParallel, Module
 
 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -54,7 +54,11 @@ def world_size(self) -> int:
     def setup(self) -> None:
         # model needs to be moved to the device before it is wrapped
         self.model_to_device()
-        self._model = DataParallel(LightningParallelModule(self._model), self.parallel_devices)
+        self._model = self._setup_model(LightningParallelModule(self._model))
+
+    def _setup_model(self, model: Module) -> DataParallel:
+        """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
+        return DataParallel(module=model, device_ids=self.parallel_devices)
 
     def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
         """Reduces a collection of tensors from all processes. It can be applied to just a single tensor.
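To see what the new `_setup_model` produces, here is a standalone sketch. It is not the plugin's actual code path: the toy `Linear` module and the device list stand in for `LightningParallelModule(self._model)` and `self.parallel_devices`:

import torch
from torch.nn import DataParallel, Linear

# Toy module standing in for the wrapped LightningModule.
model = Linear(8, 2)

if torch.cuda.is_available():
    # As in setup(): move the model to the source device before wrapping it.
    model = model.to("cuda:0")
    # As in _setup_model(): each forward pass now replicates the module across
    # the listed devices and scatters the input batch along dim 0.
    wrapped = DataParallel(module=model, device_ids=list(range(torch.cuda.device_count())))
    output = wrapped(torch.randn(16, 8, device="cuda:0"))
    print(output.shape)  # torch.Size([16, 2]); outputs are gathered back on cuda:0

Factoring the wrapping into `_setup_model` keeps `setup()` free of strategy-specific construction and matches the interface introduced in #9994.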
