Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update setup logic in training type plugins (sharded) [4 / 4] #10028

Merged
merged 8 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
* Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
* Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
* Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
* Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))


Expand Down
63 changes: 44 additions & 19 deletions pytorch_lightning/plugins/training_type/sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, Optional
from typing import Dict, Generator, List, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
Expand All @@ -33,47 +35,70 @@
class DDPShardedPlugin(DDPPlugin):
"""Optimizer and gradient sharded training provided by FairScale."""

_REDUCE_BUFFER_SIZE_DEFAULT = 2 ** 23 # 8M
_REDUCE_BUFFER_SIZE_DEFAULT: int = 2 ** 23 # 8M

def configure_ddp(self) -> None:
self._wrap_optimizers()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._precision = None

def configure_ddp(self) -> None:
trainer = self.lightning_module.trainer
if "reduce_buffer_size" not in self._ddp_kwargs:
# For multi-node training, enabling bucketing will improve performance.
self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model),
sharded_optimizer=self.lightning_module.trainer.optimizers,
**self._ddp_kwargs
[self._model], optimizers = self._setup_models_and_optimizers(
models=[LightningShardedDataParallel(self.model)],
optimizers=trainer.optimizers,
)
setattr(self._model, "require_backward_grad_sync", False)
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()

def _setup_models_and_optimizers(
self, models: List[Module], optimizers: List[Optimizer]
) -> Tuple[List[Module], List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.

def _reinit_optimizers_with_oss(self):
optimizers = self.lightning_module.trainer.optimizers
Currently only one model can be setup at once.

Return:
A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
"""
if len(models) > 1:
raise ValueError(
"DDPSharded only supports setting up a single model with one or several optimizers."
f" Got {len(models)} models."
)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

optimizers = self._wrap_optimizers(optimizers)
model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
setattr(model, "require_backward_grad_sync", False) # TODO: needed?
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
return [model], optimizers

def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
for x, optimizer in enumerate(optimizers):
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
precision = self.lightning_module.trainer.precision
precision = self._precision or self.lightning_module.trainer.precision
is_fp16 = precision in ("mixed", 16)
# For multi-node training, compressing the model shards in fp16 before broadcasting
# improves performance. When using PyTorch AMP, it will not degrade
# the model performance.
zero_optimizer.broadcast_fp16 = is_fp16 and self.num_nodes > 1
optimizers[x] = zero_optimizer
del optimizer
trainer = self.lightning_module.trainer
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()
return optimizers

def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
return optimizers

def _wrap_optimizers(self):
if self.model.trainer.state.fn != TrainerFn.FITTING:
return
self._reinit_optimizers_with_oss()
return self._reinit_optimizers_with_oss(optimizers)

def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
if isinstance(optimizer, LightningOptimizer):
Expand Down
52 changes: 37 additions & 15 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.
from contextlib import contextmanager
from multiprocessing.queues import SimpleQueue
from typing import Dict, Generator, Optional
from typing import Dict, Generator, List, Optional, Tuple

import torch
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
Expand All @@ -36,29 +38,49 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
"""Optimizer sharded training provided by FairScale."""

def configure_ddp(self) -> None:
self._wrap_optimizers()
self._model = ShardedDataParallel(
LightningShardedDataParallel(self.model),
sharded_optimizer=self.lightning_module.trainer.optimizers,
**self._ddp_kwargs
trainer = self.lightning_module.trainer
[self._model], optimizers = self._setup_models_and_optimizers(
models=[LightningShardedDataParallel(self.model)],
optimizers=trainer.optimizers,
)
setattr(self._model, "require_backward_grad_sync", False)
trainer.optimizers = optimizers

def _setup_models_and_optimizers(
self, models: List[Module], optimizers: List[Optimizer]
) -> Tuple[List[Module], List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.

def _reinit_optimizers_with_oss(self):
optimizers = self.lightning_module.trainer.optimizers
Currently only one model can be setup at once.

Return:
A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
"""
if len(models) > 1:
raise ValueError(
f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
f" Got {len(models)} models."
)

optimizers = self._wrap_optimizers(optimizers)
model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
setattr(model, "require_backward_grad_sync", False) # TODO: needed?
return [model], optimizers
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved

def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
for x, optimizer in enumerate(optimizers):
if not isinstance(optimizer, OSS):
optim_class = type(optimizer)
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
optimizers[x] = zero_optimizer
del optimizer
trainer = self.lightning_module.trainer
trainer.optimizers = optimizers
return optimizers

def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
if self.model is not None and self.model.trainer.state.fn != TrainerFn.FITTING:
return optimizers

def _wrap_optimizers(self):
if self.model.trainer.state.fn != TrainerFn.FITTING:
return
self._reinit_optimizers_with_oss()
return self._reinit_optimizers_with_oss(optimizers)

def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
if isinstance(optimizer, OSS):
Expand Down