Use mmap option to load_state_dict #28331
torch bin files are now deprecated in favor of safetensors, but no harm in improving this!
Could you add a test as well in test_modeling_common? 🤗
return torch.load(checkpoint_file, map_location=map_location, weights_only=True)
extra_args = {}
# mmap can only be used with files serialized with zipfile-based format.
How often does this come up (the zipfile-based format)?
The zipfile-based format comes up often, as it is the default for torch.save (https://pytorch.org/docs/stable/generated/torch.save.html).
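As a rough illustration of the check under discussion, Python's standard `zipfile.is_zipfile` can distinguish a zipfile-based checkpoint from a file in another format. This is only a sketch: the helper name `is_zip_checkpoint` is hypothetical and is not taken from the PR.

```python
import zipfile


def is_zip_checkpoint(checkpoint_file):
    """Hypothetical helper: True if the path points to a zipfile-based file."""
    # Only string paths are checked here; file-like objects are treated as
    # "not eligible" in this sketch.
    return isinstance(checkpoint_file, str) and zipfile.is_zipfile(checkpoint_file)
```

A checkpoint written by a default `torch.save` call would be expected to pass this check, while one written in the legacy (non-zip) format would not.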
Sure. Please let me know if my test case is correct.
Thanks! My last query would be to make sure this works for the deepspeed / sdpa case (device_map = "meta")!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Do you mean adding a test case for device_map == "meta"?
That would be a good way of making sure this will be fine with DeepSpeedZero (the if branch) 🤗
Currently, none of the existing tests covers that branch. Anyway, I added a guard for the "meta" device.
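To make the guard concrete, here is a minimal sketch of how the conditions mentioned in this thread could be combined before passing `mmap=True` to `torch.load`. It is an assumption-laden illustration, not the merged code: the function name `select_load_kwargs`, the `torch_version` string argument, and the hand-rolled version parsing are all hypothetical. The one external fact relied on is that `torch.load` accepts `mmap` starting with PyTorch 2.1.

```python
import zipfile


def select_load_kwargs(checkpoint_file, map_location, torch_version):
    """Hypothetical helper: extra kwargs for torch.load, or {} if mmap is unsafe."""
    # Parse "2.1.0+cu118" style strings down to (major, minor).
    major, minor = (int(p) for p in torch_version.split("+")[0].split(".")[:2])
    if (
        isinstance(checkpoint_file, str)
        # Guard: skip mmap when loading onto the "meta" device.
        and str(map_location) != "meta"
        # Assumption: torch.load(mmap=...) requires PyTorch >= 2.1.
        and (major, minor) >= (2, 1)
        # mmap can only be used with files serialized with the zipfile-based format.
        and zipfile.is_zipfile(checkpoint_file)
    ):
        return {"mmap": True}
    return {}
```

The result would then be splatted into the call, for example `torch.load(checkpoint_file, map_location=map_location, weights_only=True, **extra_args)`.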
Good, thanks for iterating!
Use mmap option to load_state_dict (huggingface#28331)
Use torch.load(mmap=True) to accelerate checkpoint loading
#28332
cc @SunMarc @sgugger