Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sharded state dicts save correctly when save_weights_only=True #19524

Merged
merged 15 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed infinite recursion error in precision plugin graveyard ([#19542](https://github.com/Lightning-AI/pytorch-lightning/pull/19542))


- Fixed a KeyError when saving a FSDP sharded checkpoint and setting `save_weights_only=True` ([#19524](https://github.com/Lightning-AI/pytorch-lightning/pull/19524))


## [2.2.0] - 2024-02-08

Expand Down
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,8 @@ def save_checkpoint(

converted_state = {"model": checkpoint.pop("state_dict")}
converted_state.update({
f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))
f"optimizer_{idx}": optim_state
for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))
})

_distributed_checkpoint_save(converted_state, path)
Expand Down
10 changes: 6 additions & 4 deletions tests/tests_pytorch/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.save_checkpoint(model_path.with_name("after-test"))
trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=model.__class__)
if not model_path.is_dir(): # TODO (@awaelchli): Add support for asserting equality of sharded checkpoints
_assert_save_equality(trainer, model_path, cls=model.__class__)

with torch.inference_mode():
# Test entry point
Expand Down Expand Up @@ -277,13 +278,14 @@ def training_step(self, batch, batch_idx):
trainer.fit(model)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
def test_fsdp_strategy_checkpoint(tmpdir, precision):
@pytest.mark.parametrize("state_dict_type", ("sharded", "full"))
def test_fsdp_strategy_checkpoint(state_dict_type, precision, tmpdir):
"""Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
model = TestFSDPModel()
strategy = FSDPStrategy(state_dict_type=state_dict_type)
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="fsdp", precision=precision, max_epochs=1
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision=precision, max_epochs=1
)
_run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))

Expand Down