Update non-training steps in ddp and ddp_spawn
ninginthecloud committed Aug 2, 2021
1 parent 94d155d commit 5f081bc
Showing 4 changed files with 52 additions and 40 deletions.
35 changes: 21 additions & 14 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -290,16 +290,6 @@ def pre_configure_ddp(self):
             self._ddp_kwargs["find_unused_parameters"] = True
 
     def _wrap_model(self) -> None:
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn != TrainerFn.FITTING:
-            self._model = (
-                LightningDistributedModule(self.model)
-                if not isinstance(self.model, (LightningDistributedModule))
-                else self.model
-            )
-            rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with DistributedDataParallel")
-            return
         self._model = DistributedDataParallel(
             LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
         )
@@ -318,6 +308,11 @@ def _register_ddp_hooks(self) -> None:
             )
 
     def configure_ddp(self):
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn != TrainerFn.FITTING:
+            rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with DistributedDataParallel")
+            return
         self.pre_configure_ddp()
         self._wrap_model()
         self._register_ddp_hooks()
@@ -398,16 +393,28 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
         return tensor
 
     def training_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+        if isinstance(self.model, DistributedDataParallel):
+            return self.model(*args, **kwargs)
+        else:
+            return self.model.training_step(*args, **kwargs)
 
     def validation_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+        if isinstance(self.model, DistributedDataParallel):
+            return self.model(*args, **kwargs)
+        else:
+            return self.model.validation_step(*args, **kwargs)
 
     def test_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+        if isinstance(self.model, DistributedDataParallel):
+            return self.model(*args, **kwargs)
+        else:
+            return self.model.test_step(*args, **kwargs)
 
     def predict_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+        if isinstance(self.model, DistributedDataParallel):
+            return self.model(*args, **kwargs)
+        else:
+            return self.model.predict_step(*args, **kwargs)
 
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
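The four *_step overrides above share one dispatch rule: when the model is wrapped in DistributedDataParallel (the fitting case), the call goes through the wrapper's forward so DDP's gradient-synchronization hooks run; when it is left unwrapped (validate/test/predict after this change), the corresponding LightningModule hook is invoked directly. The minimal sketch below illustrates only that dispatch rule; ToyModule and ToyDDPWrapper are invented stand-ins for LightningModule and DistributedDataParallel, not code from this commit.

# Minimal, self-contained sketch of the dispatch used in training_step/validation_step/
# test_step/predict_step above. ToyModule and ToyDDPWrapper are stand-ins for
# LightningModule and torch.nn.parallel.DistributedDataParallel, not the real classes.

class ToyModule:
    def validation_step(self, batch):
        return f"validation_step({batch!r})"

    def __call__(self, batch):
        # Stand-in for the forward routing that the real wrapped module performs.
        return self.validation_step(batch)


class ToyDDPWrapper:
    """Stand-in for DistributedDataParallel: forwards the call (and would sync gradients)."""

    def __init__(self, module):
        self.module = module

    def __call__(self, batch):
        return self.module(batch)


def validation_step(model, batch):
    # Mirrors the plugin logic above: go through the wrapper only when it is present.
    if isinstance(model, ToyDDPWrapper):
        return model(batch)  # wrapped (fitting): call forward so DDP hooks run
    return model.validation_step(batch)  # unwrapped (validate/test/predict): call the hook directly


print(validation_step(ToyModule(), "batch0"))                 # direct hook call
print(validation_step(ToyDDPWrapper(ToyModule()), "batch0"))  # call through the wrapper
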
35 changes: 21 additions & 14 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -235,16 +235,6 @@ def pre_configure_ddp(self):
             self._ddp_kwargs["find_unused_parameters"] = True
 
     def _wrap_model(self) -> None:
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn != TrainerFn.FITTING:
-            self._model = (
-                LightningDistributedModule(self.model)
-                if not isinstance(self.model, (LightningDistributedModule))
-                else self.model
-            )
-            rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with DistributedDataParallel")
-            return
         self._model = DistributedDataParallel(
             LightningDistributedModule(self.model), device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs
         )
@@ -261,6 +251,11 @@ def _register_ddp_hooks(self) -> None:
             )
 
     def configure_ddp(self):
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn != TrainerFn.FITTING:
+            rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with DistributedDataParallel")
+            return
         self.pre_configure_ddp()
         self._wrap_model()
         self._register_ddp_hooks()
@@ -370,16 +365,28 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
         return tensor
 
     def training_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+        if isinstance(self.model, DistributedDataParallel):
+            return self.model(*args, **kwargs)
+        else:
+            return self.model.training_step(*args, **kwargs)
 
     def validation_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+        if isinstance(self.model, DistributedDataParallel):
+            return self.model(*args, **kwargs)
+        else:
+            return self.model.validation_step(*args, **kwargs)
 
     def test_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+        if isinstance(self.model, DistributedDataParallel):
+            return self.model(*args, **kwargs)
+        else:
+            return self.model.test_step(*args, **kwargs)
 
     def predict_step(self, *args, **kwargs):
-        return self.model(*args, **kwargs)
+        if isinstance(self.model, DistributedDataParallel):
+            return self.model(*args, **kwargs)
+        else:
+            return self.model.predict_step(*args, **kwargs)
 
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
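In both ddp.py and ddp_spawn.py the wrap-or-skip decision now happens once, at the top of configure_ddp, keyed off the trainer's current entry point. The toy sketch below shows that guard in isolation; TrainerFn here is a local stand-in for Lightning's TrainerFn enum, and its string values are assumptions made for this example.

# Toy sketch of the "only wrap while fitting" guard added to configure_ddp above.
# This TrainerFn enum is a stand-in defined for the example, not Lightning's own.
from enum import Enum


class TrainerFn(Enum):
    FITTING = "fit"
    VALIDATING = "validate"
    TESTING = "test"
    PREDICTING = "predict"


def configure_ddp(model, trainer_fn):
    if trainer_fn != TrainerFn.FITTING:
        # No gradients are exchanged outside fit(), so wrapping is skipped entirely.
        print(f"In {trainer_fn} stage: skipping DistributedDataParallel wrapping")
        return model
    # During fit() the real plugin wraps the model; a plain tuple stands in for the
    # DistributedDataParallel wrapper to keep this sketch dependency-free.
    print("Fitting: wrapping model for gradient synchronization")
    return ("DDP-wrapped", model)


configure_ddp("my_model", TrainerFn.TESTING)  # guard triggers, model stays unwrapped
configure_ddp("my_model", TrainerFn.FITTING)  # model gets wrapped
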
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/training_type/sharded.py
@@ -36,6 +36,11 @@ class DDPShardedPlugin(DDPPlugin):
     _REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23  # 8M
 
     def configure_ddp(self):
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn != TrainerFn.FITTING:
+            rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with ShardedDataParallel")
+            return
         self._wrap_optimizers()
         self._wrap_model()
         setattr(self._model, "require_backward_grad_sync", False)
@@ -67,12 +72,6 @@ def _wrap_optimizers(self):
         self._reinit_optimizers_with_oss()
 
     def _wrap_model(self) -> None:
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn != TrainerFn.FITTING:
-            self._model = self.model
-            rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with ShardedDataParallel")
-            return
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model),
             sharded_optimizer=self.lightning_module.trainer.optimizers,
11 changes: 5 additions & 6 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -35,6 +35,11 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     """Optimizer sharded training provided by FairScale."""
 
     def configure_ddp(self):
+        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
+        trainer_fn = self.lightning_module.trainer.state.fn
+        if trainer_fn != TrainerFn.FITTING:
+            rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with ShardedDataParallel")
+            return
         self._wrap_optimizers()
         self._wrap_model()
         setattr(self._model, "require_backward_grad_sync", False)
@@ -56,12 +61,6 @@ def _wrap_optimizers(self):
         self._reinit_optimizers_with_oss()
 
     def _wrap_model(self):
-        # skip wrapping the model if we are not fitting as no gradients need to be exchanged
-        trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn != TrainerFn.FITTING:
-            self._model = self.model
-            rank_zero_debug(f"In {trainer_fn} stage: Skipping wrapping the model with ShardedDataParallel")
-            return
         self._model = ShardedDataParallel(
             LightningShardedDataParallel(self.model), sharded_optimizer=self.lightning_module.trainer.optimizers
         )
