Avoid wrapping LightningModule in DDP plugins when not fitting
four4fish committed Aug 25, 2021
1 parent e9f4bff commit 1a2245a
Showing 8 changed files with 141 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -221,6 +221,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug in the binary search mode of auto batch size scaling where exception was thrown if the first trainer run resulted in OOM ([#8954](https://github.com/PyTorchLightning/pytorch-lightning/pull/8954))


- Fixed `DDPPlugin`, `DDPSpawnPlugin`, `DDPShardedPlugin`, and `DDPSpawnShardedPlugin` to avoid wrapping the `LightningModule` with data-parallel modules when not fitting ([#6977](https://github.com/PyTorchLightning/pytorch-lightning/issues/6977))


## [1.4.3] - 2021-08-17

- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
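For context, a minimal usage sketch of the resulting behavior (not part of this commit; assumes a single-GPU machine and the BoringModel test helper): fitting still wraps the module in DistributedDataParallel, while evaluation runs leave it unwrapped.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from tests.helpers.boring_model import BoringModel

model = BoringModel()
trainer = Trainer(max_epochs=1, gpus=1, plugins=[DDPPlugin()])
trainer.fit(model)       # fitting: trainer.model is wrapped in DistributedDataParallel
trainer.validate(model)  # not fitting: trainer.model stays a plain LightningModule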
18 changes: 14 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -35,6 +35,7 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
_HYDRA_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
@@ -46,6 +47,7 @@
from pytorch_lightning.utilities.distributed import (
distributed_available,
init_ddp_connection,
rank_zero_debug,
rank_zero_only,
ReduceOp,
sync_ddp_if_available,
@@ -302,7 +304,12 @@ def _register_ddp_hooks(self) -> None:
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

def configure_ddp(self):
def configure_ddp(self) -> None:
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with DistributedDataParallel")
return
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -369,13 +376,16 @@ def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
if isinstance(self.model, DistributedDataParallel):
return self.model(*args, **kwargs)
else:
return self.lightning_module.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
return self.lightning_module.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
return self.lightning_module.predict_step(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
17 changes: 13 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -41,6 +41,7 @@
from pytorch_lightning.utilities.distributed import (
distributed_available,
init_ddp_connection,
rank_zero_debug,
rank_zero_only,
ReduceOp,
sync_ddp_if_available,
@@ -254,7 +255,12 @@ def _register_ddp_hooks(self) -> None:
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

def configure_ddp(self):
def configure_ddp(self) -> None:
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with DistributedDataParallel")
return
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -344,13 +350,16 @@ def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
if isinstance(self.model, DistributedDataParallel):
return self.model(*args, **kwargs)
else:
return self.lightning_module.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
return self.lightning_module.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
return self.lightning_module.predict_step(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/training_type/sharded.py
@@ -20,6 +20,7 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_debug
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
@@ -34,7 +35,12 @@ class DDPShardedPlugin(DDPPlugin):

_REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23 # 8M

def configure_ddp(self):
def configure_ddp(self) -> None:
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with ShardedDataParallel")
return
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model),
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -20,6 +20,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_debug
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
@@ -33,7 +34,12 @@
class DDPSpawnShardedPlugin(DDPSpawnPlugin):
"""Optimizer sharded training provided by FairScale."""

def configure_ddp(self):
def configure_ddp(self) -> None:
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with ShardedDataParallel")
return
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
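All four plugins gain the same early-return guard in configure_ddp(); a condensed, standalone sketch of that shared check follows (the helper name is hypothetical; in the real code the check sits at the top of each plugin's configure_ddp() and is followed by the DistributedDataParallel or ShardedDataParallel wrapping shown above).

from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import rank_zero_debug


def should_wrap_model(trainer_fn: TrainerFn) -> bool:
    # hypothetical helper: gradients are only exchanged while fitting,
    # so every other trainer stage skips the data-parallel wrapper
    if trainer_fn != TrainerFn.FITTING:
        rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model")
        return False
    return True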
39 changes: 38 additions & 1 deletion tests/plugins/test_ddp_plugin.py
@@ -18,9 +18,10 @@
import torch
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf

@@ -94,3 +95,39 @@ def creates_children(self):
RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
):
trainer.fit(model)


@RunIf(min_gpus=1)
def test_ddp_module_wrapper():
"""Tests with ddp plugin."""
model = BoringModel()
ddp_plugin = DDPPlugin()
trainer = Trainer(
max_epochs=1,
gpus=1,
plugins=[ddp_plugin],
)
# test that the model is wrapped when the trainer state is fitting
trainer.state.fn = TrainerFn.FITTING
trainer.accelerator.connect(model)
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer
assert isinstance(trainer.model, LightningModule)
trainer._pre_dispatch()
# in DDPPlugin.configure_ddp(), the model is wrapped by DistributedDataParallel
assert isinstance(trainer.model, DistributedDataParallel)

trainer = Trainer(
max_epochs=1,
gpus=1,
plugins=[ddp_plugin],
)
# test that the model is not wrapped when the trainer state is not fitting
trainer.accelerator.connect(model)
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer
trainer._pre_dispatch()
# in DDPPlugin.configure_ddp(), the model remains a LightningModule
assert isinstance(trainer.model, LightningModule)
29 changes: 28 additions & 1 deletion tests/plugins/test_ddp_spawn_plugin.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
@@ -77,3 +78,29 @@ def test_ddp_spawn_extra_parameters(tmpdir):
trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert model.test_val == "test_val"


class BoringModelDDP(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as DistributedDataParallel during training stage."""
assert isinstance(self.trainer.model, DistributedDataParallel)

def on_test_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
assert isinstance(self.trainer.model, LightningModule)

def on_predict_start(self) -> None:
"""Check if trainer module remains as LightningModule during prediction stage."""
assert isinstance(self.trainer.model, LightningModule)


@RunIf(skip_windows=True)
def test_ddp_module_wrapper(tmpdir):
"""Tests with ddp spawn plugin."""
trainer = Trainer(default_root_dir=tmpdir, num_processes=2, accelerator="ddp_spawn", fast_dev_run=True)

model = BoringModelDDP()

trainer.fit(model)
trainer.test(model, dataloaders=model.test_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())
32 changes: 31 additions & 1 deletion tests/plugins/test_sharded_plugin.py
@@ -4,13 +4,17 @@
import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel


@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
@@ -249,3 +253,29 @@ def test_ddp_sharded_plugin_manual_optimization(tmpdir):
model = ManualBoringModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=2, gpus=2)
trainer.fit(model)


class BoringModelSharded(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as ShardedDataParallel during training stage."""
assert isinstance(self.trainer.model, ShardedDataParallel)

def on_test_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
assert isinstance(self.trainer.model, LightningModule)

def on_predict_start(self) -> None:
"""Check if trainer module remains as LightningModule during prediction stage."""
assert isinstance(self.trainer.model, LightningModule)


@RunIf(skip_windows=True, fairscale=True)
def test_sharded_module_wrapper(tmpdir):
"""Tests with ddp sharded plugin."""
trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=True)

model = BoringModelSharded()

trainer.fit(model)
trainer.test(model, dataloaders=model.test_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())
