Skip to content

Commit

Permalink
Sharded state dicts save correctly when save_weights_only=True (#19524
Browse files Browse the repository at this point in the history
)

Co-authored-by: Dimitri <dvoytan@sparkcognition.com>
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
  • Loading branch information
4 people authored Mar 13, 2024
1 parent 8549a93 commit b3275e0
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed a KeyError when saving a FSDP sharded checkpoint and setting `save_weights_only=True` ([#19524](https://github.com/Lightning-AI/pytorch-lightning/pull/19524))


-

Expand Down
3 changes: 2 additions & 1 deletion src/lightning/pytorch/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,8 @@ def save_checkpoint(

converted_state = {"model": checkpoint.pop("state_dict")}
converted_state.update({
f"optimizer_{idx}": optim_state for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))
f"optimizer_{idx}": optim_state
for idx, optim_state in enumerate(checkpoint.pop("optimizer_states", []))
})

_distributed_checkpoint_save(converted_state, path)
Expand Down
9 changes: 6 additions & 3 deletions tests/tests_pytorch/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
trainer.save_checkpoint(model_path.with_name("after-test"))
trainer.save_checkpoint(model_path, weights_only=True)

_assert_save_equality(trainer, model_path, cls=model.__class__)
if not model_path.is_dir(): # TODO (@awaelchli): Add support for asserting equality of sharded checkpoints
_assert_save_equality(trainer, model_path, cls=model.__class__)

with torch.inference_mode():
# Test entry point
Expand Down Expand Up @@ -279,11 +280,13 @@ def training_step(self, batch, batch_idx):

@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
def test_fsdp_strategy_checkpoint(tmpdir, precision):
@pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
def test_fsdp_strategy_checkpoint(state_dict_type, precision, tmpdir):
"""Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
model = TestFSDPModel()
strategy = FSDPStrategy(state_dict_type=state_dict_type)
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="fsdp", precision=precision, max_epochs=1
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision=precision, max_epochs=1
)
_run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))

Expand Down

0 comments on commit b3275e0

Please sign in to comment.