-
Notifications
You must be signed in to change notification settings - Fork 570
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve Mixin #34
Improve Mixin #34
Changes from all commits
e412bdb
f2656f4
86d11f0
a3d8b1c
250195e
926a112
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,24 @@ | ||
import os | ||
import shutil | ||
import time | ||
import unittest | ||
|
||
from huggingface_hub import HfApi | ||
from huggingface_hub.file_download import is_torch_available | ||
from huggingface_hub.hub_mixin import ModelHubMixin | ||
|
||
from .testing_constants import ENDPOINT_STAGING, PASS, USER | ||
|
||
if is_torch_available(): | ||
import torch.nn as nn | ||
|
||
REPO_NAME = "mixin-repo-{}".format(int(time.time() * 10e3)) | ||
|
||
HUGGINGFACE_ID = "vasudevgupta" | ||
DUMMY_REPO_NAME = "dummy" | ||
WORKING_REPO_SUBDIR = "fixtures/working_repo_2" | ||
WORKING_REPO_DIR = os.path.join( | ||
os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR | ||
) | ||
|
||
if is_torch_available(): | ||
import torch.nn as nn | ||
|
||
|
||
def require_torch(test_case): | ||
|
@@ -25,42 +34,101 @@ def require_torch(test_case): | |
return test_case | ||
|
||
|
||
@require_torch | ||
class DummyModel(ModelHubMixin): | ||
def __init__(self, **kwargs): | ||
super().__init__() | ||
self.config = kwargs.pop("config", None) | ||
self.l1 = nn.Linear(2, 2) | ||
if is_torch_available(): | ||
|
||
def forward(self, x): | ||
return self.l1(x) | ||
class DummyModel(nn.Module, ModelHubMixin): | ||
def __init__(self, **kwargs): | ||
super().__init__() | ||
self.config = kwargs.pop("config", None) | ||
self.l1 = nn.Linear(2, 2) | ||
|
||
def forward(self, x): | ||
return self.l1(x) | ||
|
||
|
||
else: | ||
DummyModel = None | ||
|
||
|
||
@require_torch | ||
class DummyModelTest(unittest.TestCase): | ||
class HubMixingCommonTest(unittest.TestCase): | ||
_api = HfApi(endpoint=ENDPOINT_STAGING) | ||
|
||
|
||
@require_torch | ||
class HubMixingTest(HubMixingCommonTest): | ||
def tearDown(self) -> None: | ||
try: | ||
shutil.rmtree(WORKING_REPO_DIR) | ||
except FileNotFoundError: | ||
pass | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
""" | ||
Share this valid token in all tests below. | ||
""" | ||
cls._token = cls._api.login(username=USER, password=PASS) | ||
|
||
def test_save_pretrained(self): | ||
model = DummyModel() | ||
model.save_pretrained(DUMMY_REPO_NAME) | ||
|
||
model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}") | ||
files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") | ||
self.assertTrue("pytorch_model.bin" in files) | ||
self.assertEqual(len(files), 1) | ||
|
||
model.save_pretrained( | ||
DUMMY_REPO_NAME, config={"num": 12, "act": "gelu"}, push_to_hub=True | ||
f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"} | ||
) | ||
files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") | ||
self.assertTrue("config.json" in files) | ||
self.assertTrue("pytorch_model.bin" in files) | ||
self.assertEqual(len(files), 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
|
||
def test_rel_path_from_pretrained(self): | ||
model = DummyModel() | ||
model.save_pretrained( | ||
DUMMY_REPO_NAME, config={"num": 24, "act": "relu"}, push_to_hub=True | ||
f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED", | ||
config={"num": 10, "act": "gelu_fast"}, | ||
) | ||
model.save_pretrained( | ||
"dummy-wts", config=None, push_to_hub=True, model_id=DUMMY_REPO_NAME | ||
|
||
model = DummyModel.from_pretrained( | ||
f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED" | ||
) | ||
self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed! |
||
|
||
def test_from_pretrained(self): | ||
def test_abs_path_from_pretrained(self): | ||
model = DummyModel() | ||
model.save_pretrained( | ||
DUMMY_REPO_NAME, config={"num": 7, "act": "gelu_fast"}, push_to_hub=True | ||
f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED", | ||
config={"num": 10, "act": "gelu_fast"}, | ||
) | ||
|
||
model = DummyModel.from_pretrained(f"{HUGGINGFACE_ID}/{DUMMY_REPO_NAME}@main") | ||
self.assertTrue(model.config == {"num": 7, "act": "gelu_fast"}) | ||
model = DummyModel.from_pretrained( | ||
f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED" | ||
) | ||
self.assertDictEqual(model.config, {"num": 10, "act": "gelu_fast"}) | ||
|
||
def test_push_to_hub(self): | ||
model = DummyModel() | ||
model.save_pretrained("dummy-wts", push_to_hub=False) | ||
model.push_to_hub("dummy-wts", model_id=DUMMY_REPO_NAME) | ||
model.save_pretrained( | ||
f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB", | ||
config={"num": 7, "act": "gelu_fast"}, | ||
) | ||
|
||
model.push_to_hub( | ||
f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB", | ||
f"{REPO_NAME}-PUSH_TO_HUB", | ||
api_endpoint=ENDPOINT_STAGING, | ||
use_auth_token=self._token, | ||
git_user="ci", | ||
git_email="ci@dummy.com", | ||
) | ||
|
||
model_info = self._api.model_info( | ||
f"{USER}/{REPO_NAME}-PUSH_TO_HUB", | ||
) | ||
self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}-PUSH_TO_HUB") | ||
|
||
self._api.delete_repo(token=self._token, name=f"{REPO_NAME}-PUSH_TO_HUB") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's use the same logic as elsewhere, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The issue with using the same logic as elsewhere is that the API becomes weird to use, in order to cover a rare use-case. Namely, the following will never work:
it needs to be the following:
I fail to see when a user would want to push to the hub without having an `auth_token`, as it's necessary to create a repo. If the repo already exists and one wants to push to it, then the user already has to specify the `repo_url` to push to. I think the API is cleaner by having `model.push_to_hub("xxx")` work if you already have an authentication token in your HF folder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok