FSDP state_dict_type 'sharded' not working with ModelCheckpoint(save_weights_only=True) #19492
Labels: bug, checkpointing, strategy: fsdp
Bug description
When training with `FSDPStrategy(state_dict_type='sharded')` and a checkpoint callback that saves only the weights, e.g. `checkpoint_callback = ModelCheckpoint(dirpath='.', save_weights_only=True)`, the run fails when the checkpoint is saved.

The fix could be very simple: in `lightning/pytorch/strategies/fsdp.py`, add a check for the `"optimizer_states"` key.
Currently we have (line 574 on):
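(A sketch of that code, paraphrased from memory of v2.2 rather than copied verbatim; the exact lines in `lightning/pytorch/strategies/fsdp.py` may differ slightly.)

```python
# Sketch of the current 'sharded' branch: the optimizer states are popped
# unconditionally, but with save_weights_only=True the "optimizer_states"
# key is never added to the checkpoint dict, so this pop fails.
converted_state = {"model": checkpoint.pop("state_dict")}
converted_state.update(
    {
        f"optimizer_{idx}": optim_state
        for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))
    }
)
```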
With the correction:
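(Again a sketch showing the shape of the proposed guard, not an exact diff.)

```python
# Proposed change: only convert optimizer states when the key is
# actually present in the checkpoint dict.
converted_state = {"model": checkpoint.pop("state_dict")}
if "optimizer_states" in checkpoint:
    converted_state.update(
        {
            f"optimizer_{idx}": optim_state
            for idx, optim_state in enumerate(checkpoint.pop("optimizer_states"))
        }
    )
```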
Happy to submit a PR if there are no problems with this proposed solution; I can add unit tests, etc.
What version are you seeing the problem on?
v2.2
How to reproduce the bug
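A minimal reproduction would look roughly like this (a sketch based on the description above; I am assuming `BoringModel` from `lightning.pytorch.demos.boring_classes` as a stand-in model and a machine with 2 GPUs):

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import FSDPStrategy

if __name__ == "__main__":
    model = BoringModel()
    # Saving only the weights means no "optimizer_states" entry is added
    # to the checkpoint dict, which the sharded save path does not expect.
    checkpoint_callback = ModelCheckpoint(dirpath=".", save_weights_only=True)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy=FSDPStrategy(state_dict_type="sharded"),
        callbacks=[checkpoint_callback],
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
    )
    trainer.fit(model)  # fails when the sharded checkpoint is written
```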
Error messages and logs
Environment
More info
No response
cc @awaelchli @carmocca