FSDP with full state dict (#7487)
* Fix some test errors
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* checkpoint consolidation

* Update ddp_spawn.py

* Update test_metric_result_integration.py

* Update test_results.py

* Update utils.py

* Update utils.py

* Update test_all_gather_grad.py

* Update test_all_gather_grad.py

* Update test_results.py

* Revert "Update test_results.py"

This reverts commit 9d4a2b8.

* Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate"

This reverts commit c5053da, reversing
changes made to 0d23d75.

* Revert "Update test_all_gather_grad.py"

This reverts commit 0d23d75.

* Revert "Update utils.py"

This reverts commit 70fe5da.

* Revert "Update utils.py"

This reverts commit a9aae99.

* Revert "Update test_results.py"

This reverts commit ea74906.

* Revert "Update test_metric_result_integration.py"

This reverts commit bf70e43.

* Revert "Update ddp_spawn.py"

This reverts commit f172101.

* Revert "checkpoint consolidation"

This reverts commit 536c132.

* Revert "Revert "checkpoint consolidation""

This reverts commit 3a9fde9.

* Revert "Revert "Revert "checkpoint consolidation"""

This reverts commit 7a369f4.

* Revert "Revert "Update ddp_spawn.py""

This reverts commit 8222dc9.

* Revert "Revert "Update test_metric_result_integration.py""

This reverts commit 6c095b2.

* Revert "Revert "Update test_results.py""

This reverts commit 250d0aa.

* Revert "Revert "Update utils.py""

This reverts commit 8651d54.

* Revert "Revert "Update test_all_gather_grad.py""

This reverts commit dcdcd29.

* modify distributed environment to make test pass

* fix version for ddp plugin test

* fix

* fix

* changelog

* Update CHANGELOG.md

* fsdp with full state dict

* fix missing import

* modify unit test

* fix

* fix

* fix typo

* modify test and add changelog

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* limit max_epoch to 1 for testing

* test

* fix

* update

* testing remove special for multi gpu

* assert gpu

* add assertion for gpu

* fix

* Re-enable special test, use ModelCheckpoint

* Fix paths

* Fix path passing

* test

* test

* fix test

* fix

* pre-commit format

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: SeanNaren <sean@grid.ai>
4 people authored May 24, 2021
1 parent 01109cd commit 299f2c4
Showing 12 changed files with 522 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added correct `dataloader_idx` to batch transfer hooks ([#6241](https://github.com/PyTorchLightning/pytorch-lightning/pull/6241))


- Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))


### Changed

- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
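A minimal usage sketch (not part of this diff) of how the new plugin could be selected: pass a `DDPFullyShardedPlugin` instance to the Trainer's `plugins` argument. The registered string alias "fsdp" (see `register_plugins` in the new plugin file below) should also work; both spellings are assumptions based on the code in this PR rather than on its tests.

import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPFullyShardedPlugin

trainer = pl.Trainer(
    gpus=2,          # FSDP requires GPUs, see `setup_distributed` in the new plugin
    precision=16,    # mixed precision; pairing with the FSDP precision plugin is an assumption
    plugins=DDPFullyShardedPlugin(cpu_offload=False),
    max_epochs=1,
)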
20 changes: 13 additions & 7 deletions pytorch_lightning/plugins/__init__.py
@@ -6,6 +6,9 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
@@ -15,6 +18,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
@@ -32,24 +36,26 @@
"DDP2Plugin",
"DDPPlugin",
"DDPSpawnPlugin",
"DDPFullyShardedPlugin",
"DeepSpeedPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"HorovodPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
"FullyShardedNativeMixedPrecisionPlugin"
"SingleDevicePlugin",
"SingleTPUPlugin",
"TPUHalfPrecisionPlugin",
"TPUSpawnPlugin",
'RPCPlugin',
'RPCSequentialPlugin',
'TrainingTypePlugin',
'ParallelPlugin',
'Plugin',
'DDPShardedPlugin',
'DDPSpawnShardedPlugin',
"RPCPlugin",
"RPCSequentialPlugin",
"TrainingTypePlugin",
"ParallelPlugin",
"Plugin",
"DDPShardedPlugin",
"DDPSpawnShardedPlugin",
]

from pathlib import Path
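With the updated `__all__`, both new classes can be imported directly from the top-level plugins namespace; a quick sketch:

from pytorch_lightning.plugins import (
    DDPFullyShardedPlugin,
    FullyShardedNativeMixedPrecisionPlugin,
)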
3 changes: 3 additions & 0 deletions pytorch_lightning/plugins/precision/__init__.py
@@ -1,6 +1,9 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
46 changes: 46 additions & 0 deletions pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
@@ -0,0 +1,46 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
"""Mixed Precision for Full Sharded Training"""

precision = "mixed"

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.VALUE,
model: Optional[Module] = None,
) -> None:
clip_val = float(clip_val)
if clip_val <= 0:
return
# see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html,
# section `Gradient Clipping`: using `torch.nn.utils.clip_grad_norm_` is incorrect
# for an FSDP module; one would need to call `sharded_module.clip_grad_norm_(clip_val)` instead.
# However, since we rely on the LightningModule's `configure_sharded_model` to wrap with FSDP,
# it would be hard to trace back the root FSDP module. For now we only support clipping by value.
assert (
gradient_clip_algorithm == GradClipAlgorithmType.VALUE
), "`gradient_clip_algorithm`: `norm` is currently not supported for `FullyShardedNativeMixedPrecisionPlugin`"
self.clip_grad_by_value(optimizer, clip_val)
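Because clipping by norm is rejected above, a Trainer using this plugin has to be configured for value-based clipping. A hedged sketch follows, assuming the `gradient_clip_algorithm` Trainer argument and the "fsdp" alias registered by the new training type plugin behave as in this release:

import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=2,
    precision=16,                     # mixed precision
    plugins="fsdp",                   # alias registered via `register_plugins` below (assumption)
    gradient_clip_val=0.5,
    gradient_clip_algorithm="value",  # "norm" would trip the assertion in clip_gradients
)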
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
@@ -3,6 +3,7 @@
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
208 changes: 208 additions & 0 deletions pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -0,0 +1,208 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Dict, Generator, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
from fairscale.nn import default_auto_wrap_policy, enable_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel


class DDPFullyShardedPlugin(DDPPlugin):

def __init__(
self,
cpu_offload: bool = False,
flatten_parameters: bool = True,
reshard_after_forward: bool = True,
move_grads_to_cpu: Optional[bool] = None,
fp32_reduce_scatter: Optional[bool] = None,
compute_dtype: Optional[torch.dtype] = None,
bucket_cap_mb: int = 25,
min_num_params: int = 1e8,
state_dict_to_cpu: bool = True,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: ClusterEnvironment = None,
):
"""
Plugin for Fully Sharded Data Parallel provided by FairScale.
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain
at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar
to ZeRO-Stage 3 but has been built for upstreaming to PyTorch.
`For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`.
.. warning:: ``DDPFullyShardedPlugin`` is in beta and subject to change.
Defaults have been set and options have been exposed, but may require configuration
based on your level of memory/speed efficiency. We suggest having a look at this PR for more information.
`https://github.com/facebookresearch/fairscale/pull/413`
Many of the helpful doc strings below came from the original FairScale documentation:
`https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`
Arguments:
cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode.
(Default: False).
move_grads_to_cpu: Moves gradient shards to CPU after reduction.
Only disable if using CPU-based optimizers
(Defaults to ``cpu_offload``).
flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency
(Default: True).
reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows
down training. This is only relevant when resharding individual layers.
(Default: True).
fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision
(Default: None).
compute_dtype: dtype for full parameters for computation. Default to torch.float32,
unless using mixed precision, in which case defaults to torch.float16.
(Default: None).
bucket_cap_mb: bucket parameters so that gradient reduction
can potentially overlap with backward computation.
bucket_cap_mb controls the bucket size in MegaBytes (MB).
Buckets are sub-divided based on world_size,
so the max shard size is roughly bucket_cap_mb / world_size.
Values <= 0 disable bucketing.
(Default: 25).
min_num_params: Number of parameters to wrap when using FairScale ``auto_wrap``.
(Default: 1e8)
state_dict_to_cpu: Whether to return parameters (returned by :func:`state_dict`) on CPU device.
If ``False``, this will default to ``compute_device``.
(Default: True).
"""

super().__init__(
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
)
self.cpu_offload = cpu_offload
self.move_grads_to_cpu = move_grads_to_cpu
self.flatten_parameters = flatten_parameters
self.reshard_after_forward = reshard_after_forward
self.fp32_reduce_scatter = fp32_reduce_scatter
self.compute_dtype = compute_dtype
self.bucket_cap_mb = bucket_cap_mb
self.min_num_params = min_num_params
self.state_dict_device = torch.device("cpu") if state_dict_to_cpu else None
self._process_group = None

@property
def process_group(self):
if self._process_group is None:
self._process_group = torch.distributed.new_group()
return self._process_group

def setup_distributed(self) -> None:
if not self.on_gpu:
raise MisconfigurationException(
"You selected accelerator to be `ddp_fully_sharded`, but GPU is not available."
)
super().setup_distributed()
torch.cuda.set_device(self.root_device)

@contextlib.contextmanager
def model_sharded_context(self) -> Generator:
precision = self.lightning_module.trainer.precision

def wrap_policy(*args, **kwargs):
return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

with enable_wrap(
wrapper_cls=FullyShardedDataParallel,
auto_wrap_policy=wrap_policy,
process_group=self.process_group,
cpu_offload=self.cpu_offload,
move_grads_to_cpu=self.move_grads_to_cpu,
flatten_parameters=self.flatten_parameters,
mixed_precision=precision == "mixed",
reshard_after_forward=self.reshard_after_forward,
fp32_reduce_scatter=self.fp32_reduce_scatter,
compute_dtype=self.compute_dtype,
bucket_cap_mb=self.bucket_cap_mb,
state_dict_device=self.state_dict_device,
):
yield

def connect(self, model: Module) -> None:
super().connect(model)
model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
if not model_call_configure_sharded_model_hook:
# if the model has not already called `configure_sharded_model`, reset the
# training type plugin's `call_configure_sharded_model_hook`
# to give the trainer a chance to call it.
self.call_configure_sharded_model_hook = True

def configure_ddp(self) -> None:
if not self.cpu_offload:
# When using CPU Offload, FSDP will manage the CUDA movement for us.
# Note: this would be problematic for a large model (one that cannot fit on a single GPU),
# as the FSDP module's `.to(device)` would first summon all parameters
# (TODO: need to figure out a solution)
self.model_to_device()

# setup optimizers after fully sharded has wrapped the lightning module
self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer)

def pre_dispatch(self) -> None:
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)
self.configure_ddp()
self.barrier()

def model_to_device(self) -> None:
# ensure we update the device type in the lightning module
self.lightning_module.to(self.root_device)

def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
# Currently this is the same as the default TrainingTypePlugin behaviour, i.e. return
# the full state dict for FSDP; in the future, we will provide a sharded
# state dict.
return super().lightning_module_state_dict()

@property
def setup_optimizers_in_pre_dispatch(self) -> bool:
# Setup optimizers after the Fully Sharded Model has been made
return True

def training_step(self, *args, **kwargs):
return self.model.training_step(*args, **kwargs)

def validation_step(self, *args, **kwargs):
return self.model.validation_step(*args, **kwargs)

def test_step(self, *args, **kwargs):
return self.model.test_step(*args, **kwargs)

def predict_step(self, *args, **kwargs):
return self.model.predict_step(*args, **kwargs)

def post_training_step(self):
pass

@classmethod
def register_plugins(cls, plugin_registry: Dict):
plugin_registry.register(
"fsdp",
cls,
description="Fully sharded training with checkpointing the full state dict.",
)
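To show how the pieces above fit together, here is a sketch of a LightningModule written for this plugin; it assumes the `configure_sharded_model` hook and fairscale's `wrap` utility behave as in this release. Layers built inside the hook are created while `model_sharded_context()` is active, so `wrap()` picks up the `FullyShardedDataParallel` settings from the surrounding `enable_wrap(...)`.

import torch
import torch.nn as nn
import pytorch_lightning as pl
from fairscale.nn import wrap


class ShardedToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = None  # built lazily in configure_sharded_model

    def configure_sharded_model(self) -> None:
        # wrap() is a no-op outside of an enable_wrap context, so this module
        # also runs unsharded, e.g. in CPU-only tests.
        self.layer = wrap(nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2)))

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self.forward(batch).sum()

    def configure_optimizers(self):
        # Optimizers are set up after wrapping (see setup_optimizers_in_pre_dispatch),
        # so the parameters seen here are the flattened, sharded ones.
        return torch.optim.SGD(self.parameters(), lr=0.1)

Checkpoints saved from such a setup contain the full, consolidated state dict, matching `lightning_module_state_dict` above.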