[docs] Update FSDP instructions and add DeepSpeed evaluate/predict example (#8713)
Sean Naren authored Aug 4, 2021
1 parent 052aefc commit 49df107
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions docs/source/advanced/advanced_gpu.rst
@@ -101,15 +101,13 @@ To reach larger parameter sizes and be memory efficient, we have to shard parame
This means you must create a single model that is treated as a ``torch.nn.Module`` within the ``LightningModule``.
This is a limitation of Fully Sharded Training that will be resolved in the future.

Enabling Module Sharding for Maximum Memory Efficiency
""""""""""""""""""""""""""""""""""""""""""""""""""""""

To activate parameter sharding, you must wrap your model using the provided ``wrap`` or ``auto_wrap`` functions as described below. Internally in Lightning, we enable a context manager around the ``configure_sharded_model`` function to make sure the ``wrap`` and ``auto_wrap`` parameters are passed correctly.

When not using Fully Sharded Training, these wrap functions are a no-op. This means that once the changes have been made, there is no need to remove them when switching to other plugins.

This is a requirement for very large models, and it also saves on instantiation time, since modules are sharded as soon as they are wrapped rather than after the entire model has been created in memory.

``auto_wrap`` will recursively wrap `torch.nn.Modules` within the ``LightningModule`` with nested Fully Sharded Wrappers,
signalling that we'd like to partition these modules across data parallel devices, discarding the full weights when not required (information `here <https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html>`__).
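
As a point of reference, here is a minimal sketch of enabling Fully Sharded Training on the ``Trainer`` for the model defined in the example below; the ``"fsdp"`` plugin alias and the 4-GPU, 16-bit settings are illustrative assumptions, so adjust them to your environment:

.. code-block:: python

    from pytorch_lightning import Trainer

    # ``MyModel`` is the LightningModule from the example below,
    # with its layers wrapped inside ``configure_sharded_model``.
    model = MyModel()

    # ``plugins="fsdp"`` and the 4-GPU / 16-bit settings are illustrative assumptions.
    trainer = Trainer(gpus=4, plugins="fsdp", precision=16)
    trainer.fit(model)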

@@ -129,22 +127,28 @@ Below is an example of using both ``wrap`` and ``auto_wrap`` to create your mode
    # imports shown here for completeness
    import torch.nn as nn
    import pytorch_lightning as pl
    from fairscale.nn import checkpoint_wrapper, wrap, auto_wrap


    class MyModel(pl.LightningModule):
        ...

        def __init__(self):
            super().__init__()
            self.linear_layer = nn.Linear(32, 32)
            self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
            self.final_block = nn.Sequential(nn.Linear(32, 32), nn.ReLU())

        def configure_sharded_model(self):
            # modules are sharded across processes
            # as soon as they are wrapped with ``wrap`` or ``auto_wrap``.
            # During the forward/backward passes, weights get synced across processes
            # and de-allocated once computation is complete, saving memory.

            # Wraps the layer in a Fully Sharded Wrapper automatically
            linear_layer = wrap(self.linear_layer)

            # Wraps the module recursively
            # based on a minimum number of parameters (default 100M parameters)
            block = auto_wrap(self.block)

            # For best memory efficiency,
            # add FairScale activation checkpointing
            final_block = auto_wrap(checkpoint_wrapper(self.final_block))
            self.model = nn.Sequential(linear_layer, nn.ReLU(), block, final_block)

        def configure_optimizers(self):
@@ -359,6 +363,23 @@ Also please have a look at our :ref:`deepspeed-zero-stage-3-tips` which contains
trainer.predict()
You can also use the Lightning Trainer to run prediction or evaluation with DeepSpeed once the model has been trained.

.. code-block:: python

    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin


    class MyModel(pl.LightningModule):
        ...


    model = MyModel()

    trainer = Trainer(gpus=4, plugins="deepspeed_stage_3", precision=16)
    trainer.test(model, ckpt_path="my_saved_deepspeed_checkpoint.ckpt")
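
For prediction, a minimal sketch along the same lines, assuming ``MyModel`` implements a suitable ``forward``/``predict_step``; the toy ``DataLoader`` is purely illustrative, so substitute your own prediction data:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import Trainer

    model = MyModel()

    # Illustrative random inputs only; replace with your real prediction dataset.
    predict_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    trainer = Trainer(gpus=4, plugins="deepspeed_stage_3", precision=16)
    trainer.predict(model, dataloaders=predict_data)
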
Shard Model Instantly to Reduce Initialization Time/Memory
""""""""""""""""""""""""""""""""""""""""""""""""""""""""""

