Commit
* Fix some test errors
  Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
* checkpoint consolidation
* Update ddp_spawn.py
* Update test_metric_result_integration.py
* Update test_results.py
* Update utils.py
* Update utils.py
* Update test_all_gather_grad.py
* Update test_all_gather_grad.py
* Update test_results.py
* Revert "Update test_results.py" — This reverts commit 9d4a2b8.
* Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkpoint_consolidate" — This reverts commit c5053da, reversing changes made to 0d23d75.
* Revert "Update test_all_gather_grad.py" — This reverts commit 0d23d75.
* Revert "Update utils.py" — This reverts commit 70fe5da.
* Revert "Update utils.py" — This reverts commit a9aae99.
* Revert "Update test_results.py" — This reverts commit ea74906.
* Revert "Update test_metric_result_integration.py" — This reverts commit bf70e43.
* Revert "Update ddp_spawn.py" — This reverts commit f172101.
* Revert "checkpoint consolidation" — This reverts commit 536c132.
* Revert "Revert "checkpoint consolidation"" — This reverts commit 3a9fde9.
* Revert "Revert "Revert "checkpoint consolidation""" — This reverts commit 7a369f4.
* Revert "Revert "Update ddp_spawn.py"" — This reverts commit 8222dc9.
* Revert "Revert "Update test_metric_result_integration.py"" — This reverts commit 6c095b2.
* Revert "Revert "Update test_results.py"" — This reverts commit 250d0aa.
* Revert "Revert "Update utils.py"" — This reverts commit 8651d54.
* Revert "Revert "Update test_all_gather_grad.py"" — This reverts commit dcdcd29.
* modify distributed environment to make test pass
* fix version for ddp plugin test
* fix
* fix
* changelog
* Update CHANGELOG.md
* fsdp with full state dict
* fix missing import
* modify unitest
* fix
* fix
* fix typo
* modify test and add changelog
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* limit max_epoch to 1 for testing
* test
* fix
* update
* testing remove special for multi gpu
* assert gpu
* add assertion for gpu
* fix
* Re-enable special test, use ModelCheckpoint
* Fix paths
* Fix path passing
* test
* test
* fix test
* fix
* pre-commit format
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: SeanNaren <sean@grid.ai>
1 parent 01109cd · commit 299f2c4
Showing 12 changed files with 522 additions and 13 deletions.
pytorch_lightning/plugins/precision/fully_sharded_native_amp.py: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Mixed Precision for Fully Sharded Training."""

    precision = "mixed"

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.VALUE,
        model: Optional[Module] = None,
    ) -> None:
        clip_val = float(clip_val)
        if clip_val <= 0:
            return
        # See https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html,
        # section `Gradient Clipping`: using `torch.nn.utils.clip_grad_norm_` is incorrect
        # for an FSDP module. To overcome this, one needs to call `sharded_module.clip_grad_norm(clip_val)`;
        # however, we rely on the LightningModule's `configure_sharded_model` to wrap with FSDP, so it would
        # be hard to trace back the root FSDP module. For now we only support clipping by value.
        assert (
            gradient_clip_algorithm == GradClipAlgorithmType.VALUE
        ), "`gradient_clip_algorithm`: `norm` is currently not supported for `FullyShardedNativeMixedPrecisionPlugin`"
        self.clip_grad_by_value(optimizer, clip_val)
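For context, a minimal usage sketch (not part of this diff, and assuming the Trainer flags available in this Lightning version): since only the `value` clipping algorithm passes the assertion above, a Trainer configured for fully sharded training with native AMP would look roughly like this.

# Hypothetical sketch: gradient clipping with the fully sharded plugin must use
# the "value" algorithm; "norm" would trip the assertion in `clip_gradients` above.
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=2,
    precision=16,                      # selects the native-AMP fully sharded precision plugin
    plugins="fsdp",                    # name registered by DDPFullyShardedPlugin.register_plugins below
    gradient_clip_val=0.5,
    gradient_clip_algorithm="value",   # "norm" is not supported for FSDP here
)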
pytorch_lightning/plugins/training_type/fully_sharded.py: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Dict, Generator, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
    from fairscale.nn import default_auto_wrap_policy, enable_wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel

class DDPFullyShardedPlugin(DDPPlugin):

    def __init__(
        self,
        cpu_offload: bool = False,
        flatten_parameters: bool = True,
        reshard_after_forward: bool = True,
        move_grads_to_cpu: Optional[bool] = None,
        fp32_reduce_scatter: Optional[bool] = None,
        compute_dtype: Optional[torch.dtype] = None,
        bucket_cap_mb: int = 25,
        min_num_params: int = 1e8,
        state_dict_to_cpu: bool = True,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: ClusterEnvironment = None,
    ):
        """
        Plugin for Fully Sharded Data Parallel provided by FairScale.

        Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
        size while using efficient communication to reduce overhead. In practice, this means we can remain
        at parity with PyTorch DDP while scaling our model sizes dramatically. The technique is similar
        to ZeRO-Stage 3 but has been built for upstreaming to PyTorch.
        For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html.

        .. warning:: ``FullyShardedPlugin`` is in beta and subject to change.

        Defaults have been set and options have been exposed, but they may require configuration
        based on your level of memory/speed efficiency. We suggest having a look at this PR for more information:
        https://github.com/facebookresearch/fairscale/pull/413

        Many of the helpful docstrings below came from the original FairScale documentation:
        https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html

        Arguments:
            cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode.
                (Default: False).
            move_grads_to_cpu: Moves gradient shards to CPU after reduction.
                Only disable if using CPU-based optimizers.
                (Defaults to ``cpu_offload``).
            flatten_parameters: Flattens parameters into a single contiguous tensor for speed efficiency.
                (Default: True).
            reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows
                down training. This is only relevant when resharding individual layers.
                (Default: True).
            fp32_reduce_scatter: Reduce-scatter gradients in FP32. Only relevant in mixed precision.
                (Default: None).
            compute_dtype: dtype for full parameters for computation. Defaults to torch.float32,
                unless using mixed precision, in which case it defaults to torch.float16.
                (Default: None).
            bucket_cap_mb: Bucket parameters so that gradient reduction
                can potentially overlap with backward computation.
                bucket_cap_mb controls the bucket size in megabytes (MB).
                Buckets are sub-divided based on world_size,
                so the max shard size is roughly bucket_cap_mb / world_size.
                Values <= 0 disable bucketing.
                (Default: 25).
            min_num_params: Number of parameters to wrap when using FairScale ``auto_wrap``.
                (Default: 1e8).
            state_dict_to_cpu: Whether to return parameters (returned by :func:`state_dict`) on the CPU device.
                If ``False``, this will default to ``compute_device``.
                (Default: True).
        """

        super().__init__(
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
        )
        self.cpu_offload = cpu_offload
        self.move_grads_to_cpu = move_grads_to_cpu
        self.flatten_parameters = flatten_parameters
        self.reshard_after_forward = reshard_after_forward
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.compute_dtype = compute_dtype
        self.bucket_cap_mb = bucket_cap_mb
        self.min_num_params = min_num_params
        self.state_dict_device = torch.device("cpu") if state_dict_to_cpu else None
        self._process_group = None

    @property
    def process_group(self):
        if self._process_group is None:
            self._process_group = torch.distributed.new_group()
        return self._process_group

    def setup_distributed(self) -> None:
        if not self.on_gpu:
            raise MisconfigurationException(
                "You selected accelerator to be `ddp_fully_sharded`, but GPU is not available."
            )
        super().setup_distributed()
        torch.cuda.set_device(self.root_device)

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        precision = self.lightning_module.trainer.precision

        def wrap_policy(*args, **kwargs):
            return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

        with enable_wrap(
            wrapper_cls=FullyShardedDataParallel,
            auto_wrap_policy=wrap_policy,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            move_grads_to_cpu=self.move_grads_to_cpu,
            flatten_parameters=self.flatten_parameters,
            mixed_precision=precision == "mixed",
            reshard_after_forward=self.reshard_after_forward,
            fp32_reduce_scatter=self.fp32_reduce_scatter,
            compute_dtype=self.compute_dtype,
            bucket_cap_mb=self.bucket_cap_mb,
            state_dict_device=self.state_dict_device,
        ):
            yield

    def connect(self, model: Module) -> None:
        super().connect(model)
        model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
        if not model_call_configure_sharded_model_hook:
            # If the model has not called `configure_sharded_model`, we reset
            # the training type plugin's `call_configure_sharded_model_hook`
            # to give the trainer a chance to configure it.
            self.call_configure_sharded_model_hook = True

    def configure_ddp(self) -> None:
        if not self.cpu_offload:
            # When using CPU offload, FSDP will manage the CUDA movement for us.
            # Note: this would be problematic for large models (which could not fit on a single GPU),
            # as FSDP's module.to(device) would first summon all parameters.
            # (TODO: need to figure out a solution)
            self.model_to_device()

        # Set up optimizers after FSDP has wrapped the LightningModule.
        self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer)

    def pre_dispatch(self) -> None:
        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)
        self.configure_ddp()
        self.barrier()

    def model_to_device(self) -> None:
        # Ensure we update the device type in the LightningModule.
        self.lightning_module.to(self.root_device)

    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        # Currently this is the same as the default TrainingTypePlugin, i.e. it returns
        # the full state dict for FSDP; in the future, we will provide a sharded state dict.
        return super().lightning_module_state_dict()

    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Set up optimizers after the fully sharded model has been created.
        return True

    def training_step(self, *args, **kwargs):
        return self.model.training_step(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self.model.validation_step(*args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self.model.test_step(*args, **kwargs)

    def predict_step(self, *args, **kwargs):
        return self.model.predict_step(*args, **kwargs)

    def post_training_step(self):
        pass

    @classmethod
    def register_plugins(cls, plugin_registry: Dict):
        plugin_registry.register(
            "fsdp",
            cls,
            description="Fully sharded training with checkpointing the full state dict.",
        )
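To see how the pieces above fit together, here is a minimal, hypothetical usage sketch (not part of this diff): the LightningModule defines `configure_sharded_model`, which the trainer runs inside `model_sharded_context`, so FairScale's `wrap`/`auto_wrap` pick up the `FullyShardedDataParallel` wrapper class and the settings passed to `enable_wrap`. The model, layer sizes, and dataloader below are placeholders.

# Hypothetical sketch of using the plugin registered above; the model, layer
# sizes, and dataloader are illustrative only, not part of this commit.
import torch
from torch import nn
import pytorch_lightning as pl
from fairscale.nn import auto_wrap, wrap  # assumed available when FairScale FSDP is installed


class MyModel(pl.LightningModule):
    def configure_sharded_model(self):
        # Layers wrapped here are sharded as they are defined, keeping peak memory low.
        # `auto_wrap` recursively wraps submodules according to the policy configured
        # by the plugin (min_num_params by default); `wrap` wraps a module directly.
        self.backbone = auto_wrap(nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32)))
        self.head = wrap(nn.Linear(32, 2))

    def forward(self, x):
        return self.head(self.backbone(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        # Optimizers are created in `pre_dispatch`, i.e. after FSDP wrapping,
        # as required by `setup_optimizers_in_pre_dispatch` above.
        return torch.optim.SGD(self.parameters(), lr=0.01)


# "fsdp" resolves to DDPFullyShardedPlugin via `register_plugins` above.
trainer = pl.Trainer(gpus=2, precision=16, plugins="fsdp", max_epochs=1)
# trainer.fit(MyModel(), train_dataloader)  # train_dataloader is a placeholder DataLoader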