Use safetensors by default for PyTorchModelHubMixin #2033
Conversation
Hi @bmuskalla, thanks for your contribution! 🔥 That will help start the discussion :) In my opinion we should:
- when saving: start saving new files as .safetensors
- when loading (see the sketch after this list):
  - check if the local folder or remote repository contains a .safetensors file => if so, load it
  - check if the local folder or remote repository contains a pytorch_model.bin file => if so, load it
  - otherwise => raise an exception
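A minimal sketch of that loading order for a local folder (the helper name and constants here are my own illustration, not the actual `huggingface_hub` code):

```python
import os

# Assumed filenames; huggingface_hub defines its own constants for these.
SAFETENSORS_WEIGHTS_NAME = "model.safetensors"
PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"

def resolve_weights_file(model_dir: str) -> str:
    """Prefer the safetensors checkpoint, fall back to the pickle-based one."""
    for filename in (SAFETENSORS_WEIGHTS_NAME, PYTORCH_WEIGHTS_NAME):
        path = os.path.join(model_dir, filename)
        if os.path.isfile(path):
            return path
    raise FileNotFoundError(
        f"No {SAFETENSORS_WEIGHTS_NAME} or {PYTORCH_WEIGHTS_NAME} found in {model_dir}"
    )
```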
I would not directly update the `PYTORCH_WEIGHTS_NAME` constant, since other libraries/users might rely on it in their workflows. What you can do is create a new `PYTORCH_SAFE_WEIGHTS_NAME` constant, as done in `transformers`.
When `transformers` made the change, they introduced a new parameter `safe_serialization: bool`, first set to `False` (with a warning?) and then, a few releases later, set to `True` by default. The goal is to make the transition as smooth as possible.
EDIT: given the lower usage of this class (compared to `transformers`), we can skip the `safe_serialization: bool` parameter (i.e. no need to add it; let's make safetensors the default).
Pinging @LysandreJik, who handled this process in `transformers` if I remember correctly (or at least kept an eye on it 🤗).
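For reference, a rough sketch of the transition pattern described above (the signature and warning text are illustrative, not the actual `transformers` code):

```python
import warnings

def save_pretrained(self, save_directory: str, safe_serialization: bool = False):
    # Phase 1: keep the pickle default but warn; a later release
    # flips the default to safe_serialization=True.
    if not safe_serialization:
        warnings.warn(
            "Pickle-based serialization will be replaced by safetensors in a "
            "future release. Pass `safe_serialization=True` to opt in early.",
            FutureWarning,
        )
    ...
```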
src/huggingface_hub/hub_mixin.py
Outdated
from safetensors import safe_open
from safetensors.torch import save_file
These should not be in the base imports, as they should stay optional for users. What you should do is import them only inside an `if is_safetensors_available()` block below, as done for `torch` (see L14). Since `huggingface_hub` is a collection of many helpers used in various situations, we want to limit the number of required dependencies.
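A minimal sketch of that guarded-import pattern (the `is_safetensors_available` helper is assumed here, mirroring the existing torch availability check):

```python
import importlib.util

def is_safetensors_available() -> bool:
    # Same spirit as the torch availability check in huggingface_hub.
    return importlib.util.find_spec("safetensors") is not None

if is_safetensors_available():
    # Imported lazily so that `safetensors` stays an optional dependency.
    from safetensors import safe_open
    from safetensors.torch import save_file
```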
Thanks for the feedback, done in 859e230
thanks for working on this @bmuskalla!
or drop the `PYTORCH` prefix?
@Wauplin Good call, I've implemented the fallback for now. We can still look into whether we should issue a warning later down the road.
@julien-c My pleasure. I've reused the constants for safetensors that were already present, no PYTORCH in the name anymore.
Thanks for iterating on this PR @bmuskalla! Looks good to me logic-wise. Thanks for taking care of the tests as well. Left a few comments, mainly on styling matters, but other than that we should be close to merging it :)
setup.py
Outdated
- extras["torch"] = [
-     "torch",
- ]
+ extras["torch"] = ["torch", "safetensors"]
extras["torch"] = ["torch", "safetensors"] | |
extras["torch"] = [ | |
"safetensors", | |
"torch", | |
] |
(nit)
I'd prefer that as well, but `make style` puts it on a single line ;)
You need to add a trailing comma (`,`) to the last line, otherwise `ruff` will indeed fold it. Made the change in 8ca8550.
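For context, this is the "magic trailing comma" behavior of the formatter; a small illustration (not project code):

```python
# Without a trailing comma, the formatter is free to collapse the list:
extras = ["safetensors", "torch"]

# A trailing comma after the last element forces one item per line:
extras = [
    "safetensors",
    "torch",
]
```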
tests/test_hub_mixin_pytorch.py
Outdated
DummyModel().save_pretrained(self.cache_dir, config=TOKEN)
return self.cache_dir / "model.safetensors"

@patch.object(DummyModel, "_hf_hub_download")
Suggested change:
- @patch.object(DummyModel, "_hf_hub_download")
+ @patch("huggingface_hub.hf_hub_download")
Mocking like this should work and avoid the alias
Ha, I didn't get this to work with the existing imports. My lack of Python mock experience is exposed ;) Switching to `import huggingface_hub` makes it work. If you can enlighten me on whether there is a way to use `@patch` with a fully qualified name while using `from .file_download import hf_hub_download`, I'd be more than happy to update the PR.
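For what it's worth, `unittest.mock.patch` must target the namespace where a name is *looked up*, not where it is defined. A minimal sketch of the distinction (module paths are illustrative):

```python
from unittest.mock import patch

# hub_mixin.py does `from .file_download import hf_hub_download`, so the
# function is re-bound in hub_mixin's own namespace. Patching the defining
# module leaves that local binding untouched:
with patch("huggingface_hub.file_download.hf_hub_download"):
    ...  # hub_mixin still calls the real function

# Patching the namespace the code under test actually reads from works:
with patch("huggingface_hub.hub_mixin.hf_hub_download"):
    ...  # hub_mixin now calls the mock
```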
Today I learned - thanks for taking care of that
Great! Thanks for the PR @bmuskalla! Hope that's fine with you, but I've pushed 2 commits to address the last comments (see above). We should now be good to merge the PR as soon as the CI is green, so that it ships in the coming release! 🚀
EDIT: looks like we have issues in the CI, but they're unrelated to this PR 😞
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Failing tests are unrelated, so I'm merging this PR. Very nice contribution here 🔥
Thanks a lot for your contribution! 🙌 FYI, there was a comment by the Meta authors, since the PyTorch pickle format was used previously: facebookresearch/hiera#26 (comment). Perhaps we can add an explicit flag to allow users to still do this, as in the Transformers library: https://github.com/huggingface/transformers/blob/15f8296a9b493eaa0770557fe2e931677fb62e2f/src/transformers/modeling_utils.py#L2182. It could then default to `True`.
@NielsRogge I don't think it's worth adding back support for saving to the pickle format.
Ok @Wauplin, sounds good to me!
Switches `PyTorchModelHubMixin` to use safetensors.
To be discussed: should `PyTorchModelHubMixin` stay backward compatible and keep reading the pickle format via `_from_pretrained`? Not being sure whether the model we're loading is safe is certainly a concern. Do you think the old way should continue working but issue a warning?
Fixes #1989
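For anyone landing here later, a minimal usage sketch of the behavior after this PR (file layout per the discussion above; treat the details as approximate):

```python
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class MyModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

model = MyModel()
model.save_pretrained("my-model")               # now writes model.safetensors
reloaded = MyModel.from_pretrained("my-model")  # prefers safetensors, falls back to pytorch_model.bin
```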