-
Notifications
You must be signed in to change notification settings - Fork 27k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Automatic safetensors conversion when lacking these files #29390
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -20,6 +20,7 @@ | |||||
import os.path | ||||||
import sys | ||||||
import tempfile | ||||||
import threading | ||||||
import unittest | ||||||
import unittest.mock as mock | ||||||
import uuid | ||||||
|
@@ -1428,7 +1429,7 @@ def test_safetensors_on_the_fly_wrong_user_opened_pr(self): | |||||
bot_opened_pr_title = None | ||||||
|
||||||
for discussion in discussions: | ||||||
if discussion.author == "SFconvertBot": | ||||||
if discussion.author == "SFconvertbot": | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ➕ on @julien-c's comment; we have had feedback that this is not explicit enough.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can't change the account name now, but we will think of a way to make it clearer in the UI that it's an "official bot". There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good 👍🏻 |
||||||
bot_opened_pr = True | ||||||
bot_opened_pr_title = discussion.title | ||||||
|
||||||
|
@@ -1451,6 +1452,51 @@ def test_safetensors_on_the_fly_specific_revision(self): | |||||
with self.assertRaises(EnvironmentError): | ||||||
BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch") | ||||||
|
||||||
def test_absence_of_safetensors_triggers_conversion(self): | ||||||
config = BertConfig( | ||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 | ||||||
) | ||||||
initial_model = BertModel(config) | ||||||
|
||||||
# Push a model on `main` | ||||||
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) | ||||||
|
||||||
# Download the model that doesn't have safetensors | ||||||
BertModel.from_pretrained(self.repo_name, token=self.token) | ||||||
|
||||||
for thread in threading.enumerate(): | ||||||
if thread.name == "Thread-autoconversion": | ||||||
thread.join(timeout=10) | ||||||
|
||||||
with self.subTest("PR was open with the safetensors account"): | ||||||
discussions = self.api.get_repo_discussions(self.repo_name) | ||||||
|
||||||
bot_opened_pr = None | ||||||
bot_opened_pr_title = None | ||||||
|
||||||
for discussion in discussions: | ||||||
if discussion.author == "SFconvertbot": | ||||||
bot_opened_pr = True | ||||||
bot_opened_pr_title = discussion.title | ||||||
|
||||||
self.assertTrue(bot_opened_pr) | ||||||
self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model") | ||||||
|
||||||
@mock.patch("transformers.safetensors_conversion.spawn_conversion") | ||||||
def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock): | ||||||
spawn_conversion_mock.side_effect = HTTPError() | ||||||
|
||||||
config = BertConfig( | ||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 | ||||||
) | ||||||
initial_model = BertModel(config) | ||||||
|
||||||
# Push a model on `main` | ||||||
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) | ||||||
|
||||||
# The auto conversion is mocked to always raise; ensure that it doesn't raise in the main thread | ||||||
BertModel.from_pretrained(self.repo_name, token=self.token) | ||||||
|
||||||
|
||||||
@require_torch | ||||||
@is_staging_test | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing I would be wary of is that if we convert a big checkpoint from torch to safetensors and we want to load it in
Flax
, sharded safetensors are not supported yet. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Flax defaults to loading Flax checkpoints, not safetensors, so it won't be affected by a repo where there are sharded safetensors.