Avoid wrapping LightningModule in DDP plugins when not fitting
four4fish committed Aug 27, 2021
1 parent 53885af commit 4a7c7a4
Showing 8 changed files with 154 additions and 23 deletions.
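In short: these plugins now wrap the LightningModule in their data-parallel wrapper only while the trainer is fitting, since gradient synchronization is only needed during training. A minimal sketch of the gating pattern introduced in the diff below (the class is a hypothetical stand-in, not the actual plugin; the real `pre_dispatch`/`new_process` also handle sync batchnorm, process groups, and comm hooks):

    from pytorch_lightning.trainer.states import TrainerFn

    class GatedDDPPluginSketch:  # hypothetical stand-in for DDPPlugin / DDPSpawnPlugin
        def pre_dispatch(self):
            # ... environment, device, and sync-batchnorm setup elided ...
            # Only wrap the model when fitting: validate/test/predict run the
            # bare LightningModule, so no DistributedDataParallel wrapper (and
            # no gradient bucketing/allreduce machinery) is constructed.
            trainer_fn = self.lightning_module.trainer.state.fn
            if trainer_fn == TrainerFn.FITTING:
                self.configure_ddp()
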
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -233,6 +233,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug in the binary search mode of auto batch size scaling where exception was thrown if the first trainer run resulted in OOM ([#8954](https://github.com/PyTorchLightning/pytorch-lightning/pull/8954))


- Fixed `DDPPlugin`, `DDPSpawnPlugin`, `DDPShardedPlugin`, and `DDPSpawnShardedPlugin` to avoid wrapping the `LightningModule` in data-parallel modules when not fitting ([#6977](https://github.com/PyTorchLightning/pytorch-lightning/issues/6977))


## [1.4.3] - 2021-08-17

- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))
27 changes: 18 additions & 9 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -56,6 +56,7 @@
)
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_10:
from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -361,7 +362,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()

def configure_ddp(self):
def configure_ddp(self) -> None:
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -380,7 +381,10 @@ def pre_dispatch(self):
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

self.configure_ddp()
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

# share ddp pids to all processes
self._share_information_to_prevent_deadlock()
@@ -424,17 +428,22 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

def training_step(self, *args, **kwargs):
def training_step(self, *args, **kwargs) -> Optional[Any]:
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
if isinstance(self.model, DistributedDataParallel):
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
return self.lightning_module.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
return self.lightning_module.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> Any:
return self.lightning_module.predict_step(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
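For context, a hypothetical usage sketch of the behavior the `ddp.py` changes above (and the mirrored `ddp_spawn.py` changes below) produce; the Trainer arguments are illustrative of the 1.4-era API, not prescriptive:

    import pytorch_lightning as pl
    from tests.helpers.boring_model import BoringModel

    model = BoringModel()
    trainer = pl.Trainer(accelerator="ddp", gpus=2, fast_dev_run=True)

    trainer.fit(model)       # model is wrapped in DistributedDataParallel for gradient sync
    trainer.validate(model)  # no wrapping; validation_step is dispatched to the LightningModule
    trainer.test(model)      # likewise, test_step goes straight to the LightningModule
    trainer.predict(model)   # and predict_step as well
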
27 changes: 18 additions & 9 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -46,6 +46,7 @@
sync_ddp_if_available,
)
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
@@ -201,7 +202,10 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

self.configure_ddp()
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()

self.barrier()

@@ -254,7 +258,7 @@ def _register_ddp_hooks(self) -> None:
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

def configure_ddp(self):
def configure_ddp(self) -> None:
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -340,17 +344,22 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor

def training_step(self, *args, **kwargs):
def training_step(self, *args, **kwargs) -> Optional[Any]:
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
if isinstance(self.model, DistributedDataParallel):
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
return self.lightning_module.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
return self.lightning_module.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> Any:
return self.lightning_module.predict_step(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/sharded.py
@@ -34,7 +34,7 @@ class DDPShardedPlugin(DDPPlugin):

_REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23 # 8M

def configure_ddp(self):
def configure_ddp(self) -> None:
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model),
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -33,7 +33,7 @@
class DDPSpawnShardedPlugin(DDPSpawnPlugin):
"""Optimizer sharded training provided by FairScale."""

def configure_ddp(self):
def configure_ddp(self) -> None:
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
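Note that the two sharded plugins pick up the same gating without further changes: they only override `configure_ddp`, while `pre_dispatch` / `new_process` (where the `TrainerFn.FITTING` check lives) are inherited from `DDPPlugin` / `DDPSpawnPlugin`. Schematically (class bodies elided, not the full implementations):

    class DDPShardedPlugin(DDPPlugin):            # inherits the FITTING-gated pre_dispatch
        def configure_ddp(self) -> None: ...      # wraps the model in ShardedDataParallel

    class DDPSpawnShardedPlugin(DDPSpawnPlugin):  # inherits the FITTING-gated new_process
        def configure_ddp(self) -> None: ...
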
37 changes: 36 additions & 1 deletion tests/plugins/test_ddp_plugin.py
@@ -18,9 +18,10 @@
import torch
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf

@@ -94,3 +95,37 @@ def creates_children(self):
RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
):
trainer.fit(model)


@RunIf(skip_windows=True)
def test_ddp_module_wrapper():
"""Test that ``DDPPlugin`` wraps the model in ``DistributedDataParallel`` only when the trainer is fitting."""
model = BoringModel()
ddp_plugin = DDPPlugin()
trainer = Trainer(
max_epochs=1,
plugins=[ddp_plugin],
)
# when the trainer fn is FITTING, the model should get wrapped
trainer.state.fn = TrainerFn.FITTING
trainer.accelerator.connect(model)
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer
assert isinstance(trainer.model, LightningModule)
trainer._pre_dispatch()
# DDPPlugin.configure_ddp() has wrapped the model in DistributedDataParallel
assert isinstance(trainer.model, DistributedDataParallel)

trainer = Trainer(
max_epochs=1,
plugins=[ddp_plugin],
)
# when the trainer fn is not FITTING, the model should not get wrapped
trainer.accelerator.connect(model)
trainer.accelerator.setup_environment()
trainer.accelerator.setup(trainer)
trainer.lightning_module.trainer = trainer
trainer._pre_dispatch()
# configure_ddp() was skipped, so the model is still a plain LightningModule
assert isinstance(trainer.model, LightningModule)
38 changes: 37 additions & 1 deletion tests/plugins/test_ddp_spawn_plugin.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf

@@ -77,3 +79,37 @@ def test_ddp_spawn_extra_parameters(tmpdir):
trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert model.test_val == "test_val"


class BoringModelDDP(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as DistributedDataParallel during training stage."""
assert isinstance(self.trainer.model, DistributedDataParallel)

def on_validation_start(self) -> None:
"""Check that the model is wrapped in DistributedDataParallel only when fitting; otherwise it stays a LightningModule."""
if self.trainer.state.fn == TrainerFn.FITTING:
assert isinstance(self.trainer.model, DistributedDataParallel)
else:
assert isinstance(self.trainer.model, LightningModule)

def on_test_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
assert isinstance(self.trainer.model, LightningModule)

def on_predict_start(self) -> None:
"""Check if trainer module remains as LightningModule during prediction stage."""
assert isinstance(self.trainer.model, LightningModule)


@RunIf(skip_windows=True)
def test_ddp_module_wrapper(tmpdir):
"""Test that ``DDPSpawnPlugin`` wraps the model only during fitting, across fit/validate/test/predict."""
trainer = Trainer(default_root_dir=tmpdir, num_processes=2, accelerator="ddp_spawn", fast_dev_run=True)

model = BoringModelDDP()

trainer.fit(model)
trainer.validate(model, dataloaders=model.val_dataloader())
trainer.test(model, dataloaders=model.test_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())
41 changes: 40 additions & 1 deletion tests/plugins/test_sharded_plugin.py
@@ -4,13 +4,18 @@
import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel


@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
@@ -249,3 +254,37 @@ def test_ddp_sharded_plugin_manual_optimization(tmpdir):
model = ManualBoringModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=2, gpus=2)
trainer.fit(model)


class BoringModelSharded(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as ShardedDataParallel during training stage."""
assert isinstance(self.trainer.model, ShardedDataParallel)

def on_test_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
assert isinstance(self.trainer.model, LightningModule)

def on_validation_start(self) -> None:
"""Check that the model is wrapped in ShardedDataParallel only when fitting; otherwise it stays a LightningModule."""
if self.trainer.state.fn == TrainerFn.FITTING:
assert isinstance(self.trainer.model, ShardedDataParallel)
else:
assert isinstance(self.trainer.model, LightningModule)

def on_predict_start(self) -> None:
"""Check if trainer module remains as LightningModule during prediction stage."""
assert isinstance(self.trainer.model, LightningModule)


@RunIf(skip_windows=True, fairscale=True)
def test_sharded_module_wrapper(tmpdir):
"""Test that ``DDPShardedPlugin`` wraps the model only during fitting, across fit/validate/test/predict."""
trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=True)

model = BoringModelSharded()

trainer.fit(model)
trainer.test(model, dataloaders=model.test_dataloader())
trainer.validate(model, dataloaders=model.val_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())
