Introduce shard-merging util for FSDP #2772
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you, @muellerzr, for this nice utility function to merge the sharded FSDP state dicts! It would be great to have an accompanying test for this.
@pacman100 hope the tests are to your liking :)
Thank you @muellerzr for iterating! Very useful feature.
Thanks for adding this and creating a great test for it! Left a few minor comments.
```python
def test_merge_weights_safetensors(model, path):
    # Should now be saved at `path/merged.safetensors`
    merge_fsdp_weights(path / "pytorch_model_fsdp_0", path, use_safetensors=True)

    safe_state_dict = load_file(path / "merged.safetensors")
    safe_loaded_model = TinyModel()
    check_weights("diff", model.state_dict(), safe_loaded_model.state_dict())
    safe_loaded_model.load_state_dict(safe_state_dict)
    check_weights("same", model.state_dict(), safe_loaded_model.state_dict())
```
nice
What does this PR do?
This PR brings a shard-merging util from PyTorch into accelerate. The reason for doing so is that we can specifically save these weights as `safetensors` rather than `.bin`. It is applicable to users who use `SHARDED_STATE_DICT` during FSDP.

Note that this is a CPU-bound process, so you just need enough RAM to load the model you want to merge into memory.
New API

Command Line:
- The `checkpoint_dir` can optionally be removed after merging via `--remove_checkpoint_dir`
- Pass `--use_pytorch` to save as `.bin` instead of `safetensors`

Python API
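To make the end-to-end workflow concrete, here is a self-contained sketch of what the Python API does. It uses plain pickle files in place of real distributed checkpoint shards, and `merge_fsdp_weights_sketch` is a hypothetical stand-in, not accelerate's actual function:

```python
import pickle
from pathlib import Path
from tempfile import TemporaryDirectory

def merge_fsdp_weights_sketch(checkpoint_dir, output_dir):
    """Hypothetical stand-in for the Python API: read every per-rank
    shard file in `checkpoint_dir`, concatenate per-parameter slices in
    rank order, and write one merged file to `output_dir`."""
    merged = {}
    for shard_file in sorted(Path(checkpoint_dir).glob("rank*.pkl")):
        shard = pickle.loads(shard_file.read_bytes())
        for name, chunk in shard.items():
            merged.setdefault(name, []).extend(chunk)
    out = Path(output_dir) / "merged.pkl"
    out.write_bytes(pickle.dumps(merged))
    return out

with TemporaryDirectory() as tmp:
    # Fake a sharded checkpoint directory with two ranks
    ckpt = Path(tmp) / "pytorch_model_fsdp_0"
    ckpt.mkdir()
    (ckpt / "rank0.pkl").write_bytes(pickle.dumps({"w": [1.0, 2.0]}))
    (ckpt / "rank1.pkl").write_bytes(pickle.dumps({"w": [3.0, 4.0]}))
    merged_path = merge_fsdp_weights_sketch(ckpt, tmp)
    print(pickle.loads(merged_path.read_bytes()))  # {'w': [1.0, 2.0, 3.0, 4.0]}
```

The real API additionally chooses the output format (`safetensors` vs `.bin`) and, as the test above shows, writes the result as `merged.safetensors` when safetensors output is requested.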
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@pacman100 @SunMarc
TODO: