DeepSpeed Integration #5954

Merged 68 commits on Feb 17, 2021
Commits (68):
2163721  Add initial deepspeed changes (Feb 13, 2021)
14c7b61  Address code review (Feb 14, 2021)
11a8ab9  Merge branch 'master' into feature/deepspeed2 (SeanNaren, Feb 14, 2021)
e8ab7fd  Move static method outside of function (Feb 15, 2021)
5b1e091  Fixes (Feb 15, 2021)
d6d90be  Add missing annotation (Feb 15, 2021)
ab4efdf  Remove seed setting (Feb 15, 2021)
91abbc4  Merge branch 'master' into feature/deepspeed2 (SeanNaren, Feb 15, 2021)
bffb916  Doc changes (Feb 15, 2021)
5c4444d  Doc changes, add address reviews (Feb 15, 2021)
978470c  Fix docs (Feb 15, 2021)
41389b9  Try fixing issue by moving to torch adam (Feb 15, 2021)
b1cf9c0  Clean up check (Feb 16, 2021)
beea306  Changes, better APIs! (Feb 16, 2021)
64da158  Merge branch 'master' into feature/deepspeed2 (Feb 16, 2021)
2c659fe  Add wrapper, swap to git install revision (Feb 16, 2021)
4b295cd  Add special test (Feb 16, 2021)
bb39215  Merge branch 'master' into feature/deepspeed2 (SeanNaren, Feb 16, 2021)
caeac52  Add warning (Feb 16, 2021)
fcc2f99  Address review (Feb 16, 2021)
91cc1e0  Add better disclaimer (Feb 16, 2021)
05c4c51  Merge branch 'master' into feature/deepspeed2 (SeanNaren, Feb 16, 2021)
37542d6  Turn off ZeRO for testing due to compilation (Feb 16, 2021)
a11695d  Add description on modifying parameters via the plugin (Feb 16, 2021)
f4585f0  Doc strings clear (Feb 16, 2021)
887112d  Merge branch 'master' into feature/deepspeed2 (Feb 16, 2021)
52b654d  Small doc fixes (Feb 16, 2021)
17e2252  Merge branch 'master' into feature/deepspeed2 (Feb 16, 2021)
b06bd2c  Fix hash, reduce test (Feb 16, 2021)
83535fb  Added CI change (Feb 16, 2021)
a1e487d  Move to azure pipeline (Feb 16, 2021)
535800c  Fix test name (Feb 16, 2021)
f326068  Merge branch 'master' into feature/deepspeed2 (Feb 16, 2021)
471ccdf  Add missing flag (Feb 16, 2021)
e458c19  Remove sudo... (Feb 16, 2021)
9826ca8  Try conda instead (Feb 16, 2021)
45ea290  Swap to conda base (Feb 16, 2021)
9272a95  Try suggested install (Feb 16, 2021)
e06ec29  Apply suggestions from code review (Borda, Feb 16, 2021)
41cca05  Apply suggestions from code review (Borda, Feb 16, 2021)
37f2d9d  Revert "Apply suggestions from code review" (Feb 17, 2021)
c015530  Revert "Apply suggestions from code review" (Feb 17, 2021)
054b320  Remove setter (Feb 17, 2021)
87e0a92  Merge branch 'master' into feature/deepspeed2 (tchaton, Feb 17, 2021)
301c32d  Address most review (Feb 17, 2021)
3fda074  Move out function, remove DeepSpeed from requirements (Feb 17, 2021)
d969d28  Install deepspeed/mpi4py within container (Feb 17, 2021)
5d993ec  Use special tests, move to master commit for deepspeed (Feb 17, 2021)
62f3048  Export path (Feb 17, 2021)
ec79096  Force compile to happen first (Feb 17, 2021)
894a6dd  Remove! (Feb 17, 2021)
e735358  Debugging ninja (Feb 17, 2021)
60063f2  Fix error in optimizer step logic (Feb 17, 2021)
5aa9acc  Attempt to fix symbolic link (Feb 17, 2021)
b68a539  Reverse to aid debugging (Feb 17, 2021)
0878927  Export path again (Feb 17, 2021)
7dd17d3  Clean up mess (Feb 17, 2021)
3450eac  var (Borda, Feb 17, 2021)
0b9e7d5  Revert "var" (Feb 17, 2021)
e0a2d6b  Address review, add todo (Feb 17, 2021)
c565bb1  Add note about unsupported functionality (Feb 17, 2021)
a9ba173  Merge branch 'master' into feature/deepspeed2 (Feb 17, 2021)
fbaf86f  Update docs/source/advanced/multi_gpu.rst (SeanNaren, Feb 17, 2021)
1e7dcd6  Address review (Feb 17, 2021)
ea1d78c  Remove import (Feb 17, 2021)
10da87e  Add tmpdir (Feb 17, 2021)
f789b77  Add note (Feb 17, 2021)
31d6267  Add note (Feb 17, 2021)
9 changes: 8 additions & 1 deletion azure-pipelines.yml
@@ -62,6 +62,11 @@ jobs:
pip list
displayName: 'Install dependencies'

- bash: |
# Temporary fix till DeepSpeed release, move this into CUDA image
pip install deepspeed@git+https://github.com/microsoft/DeepSpeed@ec8b1cb
displayName: 'Install DeepSpeed'

- script: |
python tests/collect_env_details.py
displayName: 'Env details'
@@ -76,7 +81,9 @@ jobs:
python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=50
displayName: 'Testing: standard'

-      - script: |
+      - bash: |
# Required for Ninja binary for building extensions, which is installed at this location
export PATH=$PATH:/home/AzDevOps_azpcontainer/.local/bin
sh tests/special_tests.sh
displayName: 'Testing: special'

145 changes: 145 additions & 0 deletions docs/source/advanced/multi_gpu.rst
@@ -615,6 +615,8 @@ Lightning currently offers the following methods to leverage model parallelism:
- Sharded Training (partitioning your gradients and optimizer state across multiple GPUs, for reduced memory overhead with **no performance loss**)
- Sequential Model Parallelism with Checkpointing (partition your :class:`nn.Sequential <torch.nn.Sequential>` module across multiple GPUs, leverage checkpointing and microbatching for further memory improvements and device utilization)

.. _sharded:

Sharded Training
^^^^^^^^^^^^^^^^
Lightning integration of optimizer sharded training provided by `FairScale <https://github.com/facebookresearch/fairscale>`_.
@@ -680,6 +682,149 @@ Sharded Training can work across all DDP variants by adding the additional ``--p

Internally we re-initialize your optimizers and shard them across your machines and processes. We handle all communication using PyTorch distributed, so no code changes are required.

----------

.. _deep_speed:

DeepSpeed
^^^^^^^^^

.. note::
The DeepSpeed plugin is in beta and the API is subject to change. Please create an `issue <https://github.com/PyTorchLightning/pytorch-lightning/issues>`_ if you run into any issues.

`DeepSpeed <https://github.com/microsoft/DeepSpeed>`_ offers additional CUDA deep learning training optimizations, similar to `FairScale <https://github.com/facebookresearch/fairscale>`_. DeepSpeed provides lower-level training optimizations and efficient optimizers such as `1-bit Adam <https://www.deepspeed.ai/tutorials/onebit-adam/>`_.
Using this plugin, we were able to **train models with 10 billion parameters and above**; see this `benchmark <https://github.com/huggingface/transformers/issues/9996>`_ and the DeepSpeed `docs <https://www.deepspeed.ai/tutorials/megatron/>`_ for a lot of useful information.
We recommend using DeepSpeed in environments where speed and memory optimizations are important (such as training large, billion-parameter models). In addition, we recommend trying :ref:`sharded` first before trying DeepSpeed's further optimizations, primarily because FairScale Sharded Training is easier to use in scenarios such as multiple optimizers/schedulers.

To use DeepSpeed, you first need to install DeepSpeed using the commands below.

.. code-block:: bash

pip install deepspeed mpi4py

If you run into an issue with the install or later in training, ensure that the CUDA version of the PyTorch you have installed matches your locally installed CUDA toolkit (you can check which version is recognized by running ``nvcc --version``; a quick way to check PyTorch's side is shown below).
Additionally, if you run into any issues installing mpi4py, ensure you have OpenMPI installed using ``sudo apt install libopenmpi-dev`` or ``brew install mpich`` before running ``pip install mpi4py``.
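
For a quick sanity check of the CUDA version your PyTorch build was compiled against (a generic PyTorch check, not a DeepSpeed API):

.. code-block:: python

    import torch

    # Should match the toolkit version reported by `nvcc --version`.
    print(torch.version.cuda)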

.. note::
Currently ``resume_from_checkpoint`` and manual optimization are not supported.

DeepSpeed only supports a single optimizer and a single scheduler (a minimal compatible ``configure_optimizers`` is sketched below).
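
As a reference, here is a minimal sketch that satisfies the single optimizer/scheduler constraint using plain ``torch.optim`` classes; ``MyModel``, the layer sizes and the hyperparameters are illustrative placeholders, not part of the DeepSpeed plugin:

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self(x), y)

        def configure_optimizers(self):
            # A single optimizer and a single scheduler, as required by the DeepSpeed plugin.
            optimizer = torch.optim.Adam(self.parameters(), lr=3e-5)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [scheduler]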

ZeRO-Offload
""""""""""""

Below we show an example of running `ZeRO-Offload <https://www.deepspeed.ai/tutorials/zero-offload/>`_. ZeRO-Offload leverages the host CPU to offload optimizer memory/computation, reducing the overall memory consumption.
For an additional speed benefit, DeepSpeed offers an optimized CPU implementation of Adam to run the offloaded computation, which is faster than the standard PyTorch implementation. We enable ZeRO-Offload by default.

.. note::
To use ZeRO-Offload, you must use ``precision=16`` or set precision via `the DeepSpeed config. <https://www.deepspeed.ai/docs/config-json/#fp16-training-options>`_.

.. code-block:: python

from pytorch_lightning import Trainer

model = MyModel()
trainer = Trainer(gpus=4, plugins='deepspeed', precision=16)
trainer.fit(model)


This can also be done via the command line using a PyTorch Lightning script:

.. code-block:: bash

python train.py --plugins deepspeed --precision 16 --gpus 4


You can also modify the ZeRO-Offload parameters via the plugin, as shown below.

.. code-block:: python

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

model = MyModel()
trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(allgather_bucket_size=5e8, reduce_bucket_size=5e8), precision=16)
trainer.fit(model)


.. note::
We suggest tuning the ``allgather_bucket_size`` and ``reduce_bucket_size`` parameters to find the optimum values for your model size.
These control how large a buffer the model is limited to when reducing gradients/gathering updated parameters. Smaller values use less memory but trade off speed.

DeepSpeed allocates a reduce buffer `roughly 4.5x the configured size <https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage2.py#L1594-L1607>`_, so take that into consideration when tweaking the parameters.

The plugin sets a reasonable default of ``2e8``, which should work for most low VRAM GPUs (less than ``7GB``), allocating roughly ``3.6GB`` of VRAM as buffer. Higher VRAM GPUs should aim for values around ``5e8``.
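
As a rough sanity check, you can estimate the buffer memory yourself. The sketch below assumes fp32 buffer elements and the ~4.5x multiplier linked above; the helper name is purely illustrative:

.. code-block:: python

    def estimated_buffer_gb(bucket_size: float, bytes_per_element: int = 4, multiplier: float = 4.5) -> float:
        """Rough estimate of the VRAM reserved for the reduce buffer."""
        return bucket_size * bytes_per_element * multiplier / 1e9


    print(estimated_buffer_gb(2e8))  # ~3.6 GB, the plugin default
    print(estimated_buffer_gb(5e8))  # ~9 GB, for higher VRAM GPUs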


Custom DeepSpeed Config
"""""""""""""""""""""""

DeepSpeed allows the use of custom DeepSpeed optimizers and schedulers defined within a config file, enabling optimizers such as `1-bit Adam <https://www.deepspeed.ai/tutorials/onebit-adam/>`_.

.. note::
All plugin default parameters will be ignored when a config object is passed.
All compatible arguments can be seen in the `DeepSpeed docs <https://www.deepspeed.ai/docs/config-json/>`_.

.. code-block:: python

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

deepspeed_config = {
"zero_allow_untested_optimizer": True,
"optimizer": {
"type": "OneBitAdam",
"params": {
"lr": 3e-5,
"betas": [0.998, 0.999],
"eps": 1e-5,
"weight_decay": 1e-9,
"cuda_aware": True,
},
},
"scheduler": {
"type": "WarmupLR",
"params": {
"last_batch_iteration": -1,
"warmup_min_lr": 0,
"warmup_max_lr": 3e-5,
"warmup_num_steps": 100,
}
},
"zero_optimization": {
"stage": 2, # Enable Stage 2 ZeRO (Optimizer/Gradient state partitioning)
"cpu_offload": True, # Enable Offloading optimizer state/calculation to the host CPU
"contiguous_gradients": True, # Reduce gradient fragmentation.
"overlap_comm": True, # Overlap reduce/backward operation of gradients for speed.
"allgather_bucket_size": 2e8, # Number of elements to all gather at once.
"reduce_bucket_size": 2e8, # Number of elements we reduce/allreduce at once.
}
}

model = MyModel()
trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin(deepspeed_config), precision=16)
trainer.fit(model)


We also support passing the config as a JSON-formatted file:

.. code-block:: python

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

model = MyModel()
trainer = Trainer(gpus=4, plugins=DeepSpeedPlugin("/path/to/deepspeed_config.json"), precision=16)
trainer.fit(model)


You can also set the config path using an environment variable when launching your PyTorch Lightning script:

.. code-block:: bash

PL_DEEPSPEED_CONFIG_PATH=/path/to/deepspeed_config.json python train.py --plugins deepspeed
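
For completeness, a small sketch of one way to produce such a JSON file from a Python dictionary; the file name and the reduced config below are illustrative placeholders, not defaults:

.. code-block:: python

    import json

    deepspeed_config = {
        "zero_allow_untested_optimizer": True,
        "zero_optimization": {"stage": 2, "cpu_offload": True},
    }

    # Write the config to disk so it can be passed to DeepSpeedPlugin("deepspeed_config.json")
    # or referenced via the PL_DEEPSPEED_CONFIG_PATH environment variable.
    with open("deepspeed_config.json", "w") as f:
        json.dump(deepspeed_config, f, indent=4)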


----------

.. _sequential-parallelism:
8 changes: 5 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -284,7 +284,7 @@ def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Cal
self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
-        optimizer.step(closure=lambda_closure, **kwargs)
+        self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)

def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
"""Zeros all model parameter's gradients"""
@@ -315,9 +315,11 @@ def setup_optimizers(self, trainer: "Trainer"):
trainer: the Trainer, these optimizers should be connected to
model: the model to be optimized by the created optimizers
"""
-        if trainer.testing is True:
+        if trainer.testing:
return
-        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(self.lightning_module)
+        optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
+            trainer=trainer, model=self.lightning_module
+        )
self.optimizers = optimizers
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/__init__.py
@@ -1,12 +1,14 @@
from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
@@ -25,6 +27,8 @@
"DDP2Plugin",
"DDPPlugin",
"DDPSpawnPlugin",
"DeepSpeedPlugin",
"DeepSpeedPrecisionPlugin",
"HorovodPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/__init__.py
@@ -1,4 +1,5 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
61 changes: 61 additions & 0 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -0,0 +1,61 @@
from typing import Callable, Union

import torch
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class DeepSpeedPrecisionPlugin(PrecisionPlugin):

def __init__(self, precision):
super().__init__()
self.precision = precision

def pre_optimizer_step(
self, pl_module: LightningModule, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs
) -> bool:
deepspeed_engine = pl_module.trainer.model
# DeepSpeed does not support closures, so run the closure (training step + backward) manually.
lambda_closure()

if not pl_module.automatic_optimization:
pl_module.trainer.call_hook("on_after_backward")

deepspeed_engine.step()

return False

def backward(
self,
lightning_module: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
if is_overridden('backward', lightning_module):
warning_cache.warn(
"Overridden backward hook in the LightningModule will be ignored since DeepSpeed handles"
"backward logic outside of the LightningModule"
)
# TODO: workaround so that the DeepSpeed engine performs the backward call itself
deepspeed_engine = lightning_module.trainer.model
deepspeed_engine.backward(closure_loss, **kwargs)
# once backward has been applied, release graph
closure_loss = closure_loss.detach()

return closure_loss

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)):
"""
DeepSpeed handles clipping gradients via the training type plugin.
"""
pass
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/training_type/__init__.py
@@ -1,6 +1,7 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin