Avoid wrapping LightningModule in *DataParallel overrides when not fitting #8632

Closed
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -177,6 +177,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed truncated backprop through time enablement when set as a property on the LightningModule and not the Trainer ([#8804](https://github.com/PyTorchLightning/pytorch-lightning/pull/8804/))

- Fixed `LightningModule` being wrapped in data-parallel modules when not fitting in `DDPPlugin`, `DDPSpawnPlugin`, `DDPShardedPlugin`, `DDPSpawnShardedPlugin` ([#6977](https://github.com/PyTorchLightning/pytorch-lightning/issues/6977))


- Fixed plateau scheduler stepping on incomplete epoch ([#8861](https://github.com/PyTorchLightning/pytorch-lightning/pull/8861))

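A minimal sketch of the user-visible effect described by this entry (the trainer/model setup below is an assumption, mirroring the new tests added further down; it is not part of the diff):

```python
# Sketch only: with this change, the data-parallel wrapper is applied while
# fitting but skipped for validate/test/predict.
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel

model = BoringModel()
trainer = Trainer(num_processes=2, accelerator="ddp_cpu", fast_dev_run=True)

trainer.fit(model)       # model is wrapped in DistributedDataParallel here
trainer.validate(model)  # no wrapper: the plain LightningModule runs the hooks
trainer.test(model)      # same for test and for trainer.predict(...)
```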
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress.py
@@ -78,7 +78,7 @@ def disable(self):
self.enable = False

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) # important :-)
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) # important
percent = (self.train_batch_idx / self.total_train_batches) * 100
sys.stdout.flush()
sys.stdout.write(f'{percent:.01f} percent complete \r')
18 changes: 14 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -34,6 +34,7 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (
_HYDRA_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
@@ -44,6 +45,7 @@
)
from pytorch_lightning.utilities.distributed import (
distributed_available,
rank_zero_debug,
rank_zero_info,
rank_zero_only,
ReduceOp,
@@ -303,7 +305,12 @@ def _register_ddp_hooks(self) -> None:
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

def configure_ddp(self):
def configure_ddp(self) -> None:
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with DistributedDataParallel")
return
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -389,13 +396,16 @@ def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
if isinstance(self.model, DistributedDataParallel):
return self.model(*args, **kwargs)
else:
return self.lightning_module.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
return self.lightning_module.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
return self.lightning_module.predict_step(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
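A note on the asymmetry above: `validation_step` keeps an `isinstance` check because validation also runs inside `trainer.fit()` (sanity checking and validation epochs), where the module is DDP-wrapped, while test and predict only run outside fitting, where it never is. A simplified sketch of the dispatch (the `run_step` helper is hypothetical, not part of the plugin API):

```python
# Simplified sketch of the step dispatch used above (assumed helper).
from torch.nn.parallel import DistributedDataParallel


def run_step(plugin, hook_name: str, *args, **kwargs):
    if isinstance(plugin.model, DistributedDataParallel):
        # Fitting: forward through the wrapper so LightningDistributedModule
        # routes the call to the matching *_step hook.
        return plugin.model(*args, **kwargs)
    # Not fitting: the module was never wrapped, so call the hook directly.
    return getattr(plugin.lightning_module, hook_name)(*args, **kwargs)
```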
17 changes: 13 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -38,6 +38,7 @@
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
distributed_available,
rank_zero_debug,
rank_zero_info,
rank_zero_only,
ReduceOp,
@@ -253,7 +254,12 @@ def _register_ddp_hooks(self) -> None:
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

def configure_ddp(self):
def configure_ddp(self) -> None:
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with DistributedDataParallel")
return
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
@@ -368,13 +374,16 @@ def training_step(self, *args, **kwargs):
return self.model(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
if isinstance(self.model, DistributedDataParallel):
return self.model(*args, **kwargs)
else:
return self.lightning_module.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
return self.lightning_module.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model(*args, **kwargs)
return self.lightning_module.predict_step(*args, **kwargs)

def post_training_step(self):
if not self.lightning_module.automatic_optimization:
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/training_type/sharded.py
@@ -20,6 +20,7 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_debug
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
@@ -34,7 +35,12 @@ class DDPShardedPlugin(DDPPlugin):

_REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23 # 8M

def configure_ddp(self):
def configure_ddp(self) -> None:
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with ShardedDataParallel")
return
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model),
8 changes: 7 additions & 1 deletion pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -20,6 +20,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_debug
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
@@ -33,7 +34,12 @@
class DDPSpawnShardedPlugin(DDPSpawnPlugin):
"""Optimizer sharded training provided by FairScale."""

def configure_ddp(self):
def configure_ddp(self) -> None:
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn != TrainerFn.FITTING:
rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with ShardedDataParallel")
return
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
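All four plugins share the same guard; for the sharded variants the early return also skips `self._wrap_optimizers()`, which is fine because the optimizers are only stepped while fitting. A condensed sketch of the shared pattern (not the literal code of any one plugin):

```python
# Condensed sketch of the guard shared by the four configure_ddp overrides
# (method body shown standalone for brevity).
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import rank_zero_debug


def configure_ddp(self) -> None:
    trainer_fn = self.lightning_module.trainer.state.fn
    if trainer_fn != TrainerFn.FITTING:
        # validate/test/predict exchange no gradients, so the LightningModule
        # (and, for the sharded plugins, the optimizers) stays unwrapped.
        rank_zero_debug(f"In {trainer_fn} stage: skipping the data-parallel wrapper")
        return
    ...  # wrap the model (and optimizers, where applicable) as before
```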
28 changes: 27 additions & 1 deletion tests/plugins/test_ddp_plugin.py
@@ -16,7 +16,7 @@
import torch
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPPlugin
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@@ -69,3 +69,29 @@ def test_ddp_barrier_non_consecutive_device_ids(barrier_mock, tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, gpus=gpus, accelerator="ddp")
trainer.fit(model)
barrier_mock.assert_any_call(device_ids=[gpus[trainer.local_rank]])


class BoringModelDDP(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as DistributedDataParallel during training stage."""
assert isinstance(self.trainer.model, DistributedDataParallel)

def on_test_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
assert isinstance(self.trainer.model, LightningModule)

def on_predict_start(self) -> None:
"""Check if trainer module remains as LightningModule during prediction stage."""
assert isinstance(self.trainer.model, LightningModule)


@RunIf(skip_windows=True)
def test_ddp_module_wrapper():
"""Tests with ddp plugin."""
trainer = Trainer(num_processes=2, accelerator="ddp_cpu", fast_dev_run=True)

model = BoringModelDDP()

trainer.fit(model)
trainer.test(model, dataloaders=model.test_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())
29 changes: 28 additions & 1 deletion tests/plugins/test_ddp_spawn_plugin.py
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.nn.parallel.distributed import DistributedDataParallel

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
@@ -77,3 +78,29 @@ def test_ddp_spawn_extra_parameters(tmpdir):
trainer.fit(model, datamodule=dm)
assert trainer.callback_metrics[val_name] == torch.tensor(val)
assert model.test_val == "test_val"


class BoringModelDDP(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as DistributedDataParallel during training stage."""
assert isinstance(self.trainer.model, DistributedDataParallel)

def on_test_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
assert isinstance(self.trainer.model, LightningModule)

def on_predict_start(self) -> None:
"""Check if trainer module remains as LightningModule during prediction stage."""
assert isinstance(self.trainer.model, LightningModule)


@RunIf(skip_windows=True)
def test_ddp_module_wrapper():
"""Tests with ddp spawn plugin."""
trainer = Trainer(num_processes=2, accelerator="ddp_spawn", fast_dev_run=True)

model = BoringModelDDP()

trainer.fit(model)
trainer.test(model, dataloaders=model.test_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())
32 changes: 31 additions & 1 deletion tests/plugins/test_sharded_plugin.py
@@ -4,13 +4,17 @@
import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDPShardedPlugin, DDPSpawnShardedPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel


@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(min_gpus=1, skip_windows=True, amp_native=True, fairscale=True)
@@ -249,3 +253,29 @@ def test_ddp_sharded_plugin_manual_optimization(tmpdir):
model = ManualBoringModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator="ddp_sharded", fast_dev_run=2, gpus=2)
trainer.fit(model)


class BoringModelSharded(BoringModel):
def on_train_start(self) -> None:
"""Check if trainer module is wrapped as ShardedDataParallel during training stage."""
assert isinstance(self.trainer.model, ShardedDataParallel)

def on_test_start(self) -> None:
"""Check if trainer module remains as LightningModule during test stage."""
assert isinstance(self.trainer.model, LightningModule)

def on_predict_start(self) -> None:
"""Check if trainer module remains as LightningModule during prediction stage."""
assert isinstance(self.trainer.model, LightningModule)


@RunIf(skip_windows=True, fairscale=True)
def test_sharded_module_wrapper():
"""Tests with ddp sharded plugin."""
trainer = Trainer(accelerator="ddp_sharded", fast_dev_run=True)

model = BoringModelSharded()

trainer.fit(model)
trainer.test(model, dataloaders=model.test_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())