Extend FSDP guide with checkpointing (#18374)
awaelchli authored Aug 23, 2023
1 parent fc6f43f commit f4825e5
Showing 2 changed files with 77 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source-fabric/advanced/model_parallel/fsdp.rst
@@ -386,6 +386,8 @@ You can easily load checkpoints saved by Fabric to resume training:
# model.load_state_dict(torch.load("path/to/checkpoint/file"))
Fabric will automatically recognize whether the provided path contains a checkpoint saved with ``state_dict_type="full"`` or ``state_dict_type="sharded"``.
Checkpoints saved with ``state_dict_type="full"`` can be loaded by all strategies, but sharded checkpoints can only be loaded by FSDP.
Read :doc:`the checkpoints guide <../../guide/checkpoint>` to explore more features.
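
For illustration, resuming with Fabric could look like the following minimal sketch (the toy model, optimizer, and device count are assumptions made for this example):

.. code-block:: python

    import torch
    import lightning as L
    from lightning.fabric.strategies import FSDPStrategy

    fabric = L.Fabric(devices=2, strategy=FSDPStrategy())
    fabric.launch()

    # A toy model and optimizer stand in for the real training setup
    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    # Objects in the state dictionary are restored in place
    state = {"model": model, "optimizer": optimizer}
    fabric.load("path/to/checkpoint/file", state)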


----
75 changes: 75 additions & 0 deletions docs/source-pytorch/advanced/model_parallel/fsdp.rst
@@ -354,6 +354,81 @@ In our example, we see a 3.5x memory saving, but a significant increase in itera
----


*****************
Save a checkpoint
*****************

Since training large models can be very expensive, it is best practice to checkpoint the training state periodically in case it gets interrupted unexpectedly.
Lightning saves a checkpoint every epoch by default, and there are :ref:`several settings to configure the checkpointing behavior in detail <checkpointing>`.

.. code-block:: python

    # Default: Saves a checkpoint every epoch
    trainer = L.Trainer()
    trainer.fit(model)

    # You can also manually trigger a checkpoint at any time
    trainer.save_checkpoint("path/to/checkpoint/file")

    # DON'T do this (inefficient):
    # torch.save(model.state_dict(), "path/to/checkpoint/file")
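
For finer control over when checkpoints are written, the standard ``ModelCheckpoint`` callback can be configured; a minimal sketch (the directory path and step interval are illustrative values):

.. code-block:: python

    import lightning as L
    from lightning.pytorch.callbacks import ModelCheckpoint

    # Save a checkpoint every 1000 training steps and keep all of them
    checkpoint_callback = ModelCheckpoint(
        dirpath="path/to/checkpoints",
        every_n_train_steps=1000,
        save_top_k=-1,
    )
    trainer = L.Trainer(callbacks=[checkpoint_callback])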

For single-machine training this typically works fine, but for larger models, saving a checkpoint can become slow (minutes instead of seconds) or exhaust CPU memory (OOM), depending on the system.
To reduce memory peaks and speed up the saving to disk, set ``state_dict_type="sharded"``:

.. code-block:: python

    # Default: Save a single, consolidated checkpoint file
    strategy = FSDPStrategy(state_dict_type="full")

    # Save individual files with state from each process
    strategy = FSDPStrategy(state_dict_type="sharded")

With this, each process/GPU will save its own file into a folder at the given path by default.
The resulting checkpoint folder will have this structure:

.. code-block:: text

    path/to/checkpoint/file
    ├── .metadata
    ├── __0_0.distcp
    ├── __1_0.distcp
    └── meta.pt

The “sharded” checkpoint format is the most efficient to save and load in Lightning.
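
Putting it together, a Trainer configured for sharded checkpointing might look like the following sketch (the device count and the ``LitModel`` module are illustrative assumptions):

.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    strategy = FSDPStrategy(state_dict_type="sharded")
    trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)

    # LitModel is a placeholder LightningModule defined elsewhere
    trainer.fit(LitModel())

    # Each checkpoint is now written as a folder of per-process shards
    trainer.save_checkpoint("path/to/checkpoint/file")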

**Which checkpoint format should I use?**

- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable - you can’t easily load the checkpoint in raw PyTorch (in the future, Lightning will provide utilities to convert the checkpoint though).
- ``state_dict_type="full"``: Use when pre-training small to moderately large models (less than 10B parameters), when fine-tuning, and when portability is required.


----


*****************
Load a checkpoint
*****************

You can easily :ref:`load checkpoints <checkpointing>` saved by Lightning to resume training:

.. code-block:: python

    trainer = L.Trainer(...)

    # Restore the training progress, weights, and optimizer state
    trainer.fit(model, ckpt_path="path/to/checkpoint/file")

The Trainer will automatically recognize whether the provided path contains a checkpoint saved with ``state_dict_type="full"`` or ``state_dict_type="sharded"``.
Checkpoints saved with ``state_dict_type="full"`` can be loaded by all strategies, but sharded checkpoints can only be loaded by FSDP.
Read :ref:`the checkpoints guide <checkpointing>` to explore more features.
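
Because a checkpoint saved with ``state_dict_type="full"`` is a regular Lightning checkpoint file, it can also be inspected or loaded outside the Trainer; a minimal sketch (``LitModel`` is a placeholder for your LightningModule class):

.. code-block:: python

    import torch

    # A full checkpoint is a plain dictionary with a "state_dict" entry
    checkpoint = torch.load("path/to/checkpoint/file", map_location="cpu")
    print(checkpoint["state_dict"].keys())

    # Or restore the weights directly into your LightningModule subclass
    model = LitModel.load_from_checkpoint("path/to/checkpoint/file")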


----


**********************************
Advanced performance optimizations
**********************************
