From e412bdbc3ec0b6a2148fb53709d526376e49c2a0 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 20 Apr 2021 19:58:17 -0400 Subject: [PATCH 1/6] Improve Mixin --- setup.py | 6 +- src/huggingface_hub/hub_mixin.py | 41 +++++++++++--- tests/test_hubmixin.py | 94 ++++++++++++++++++++++++++------ 3 files changed, 116 insertions(+), 25 deletions(-) diff --git a/setup.py b/setup.py index 17963ed097..6a20431dc7 100644 --- a/setup.py +++ b/setup.py @@ -20,9 +20,13 @@ def get_version() -> str: extras = {} +extras["torch"] = [ + "torch", +] + extras["testing"] = [ "pytest", -] +] + extras["torch"] extras["quality"] = [ "black>=20.8b1", diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 081be72bfc..051a2591c7 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -1,7 +1,7 @@ import json import logging import os -from typing import Dict, Optional +from typing import Dict, Optional, Union import requests @@ -150,7 +150,7 @@ def from_pretrained( if len(model_id.split("@")) == 2: model_id, revision = model_id.split("@") - if model_id in os.listdir() and CONFIG_NAME in os.listdir(model_id): + if os.path.isdir(model_id) and CONFIG_NAME in os.listdir(model_id): config_file = os.path.join(model_id, CONFIG_NAME) else: try: @@ -170,7 +170,7 @@ def from_pretrained( logger.warning("config.json NOT FOUND in HuggingFace Hub") config_file = None - if model_id in os.listdir(): + if os.path.isdir(model_id): print("LOADING weights from local directory") model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) else: @@ -208,6 +208,10 @@ def push_to_hub( commit_message: Optional[str] = "add model", organization: Optional[str] = None, private: bool = None, + api_endpoint=None, + use_auth_token: Union[bool, str, None] = None, + git_user: Optional[str] = None, + git_email: Optional[str] = None, ) -> str: """ Parameters: @@ -223,6 +227,15 @@ def push_to_hub( private: Whether the model repo should be private (requires a paid huggingface.co account) commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`): Message to commit while pushing + api_endpoint (:obj:`str`, `optional`): + The API endpoint to use when pushing the model to the hub. + use_auth_token (``str`` or ``bool``, `optional`, defaults ``None``): + huggingface_token can be extract from ``HfApi().login(username, password)`` and is used to authenticate + against the hub (useful from Google Colab for instance). + git_user (``str``, `optional`, defaults ``None``): + will override the ``git config user.name`` for committing and pushing files to the hub. + git_email (``str``, `optional`, defaults ``None``): + will override the ``git config user.email`` for committing and pushing files to the hub. Returns: url to commit on remote repo. @@ -230,10 +243,18 @@ def push_to_hub( if model_id is None: model_id = save_directory - token = HfFolder.get_token() + # The auth token is necessary to create a repo + if isinstance(use_auth_token, str): + huggingface_token = use_auth_token + elif use_auth_token is None and repo_url is not None: + # If the repo url exists, then no need for a token + huggingface_token = None + else: + huggingface_token = HfFolder.get_token() + if repo_url is None: - repo_url = HfApi().create_repo( - token, + repo_url = HfApi(endpoint=api_endpoint).create_repo( + huggingface_token, model_id, organization=organization, private=private, @@ -241,6 +262,12 @@ def push_to_hub( exist_ok=True, ) - repo = Repository(save_directory, clone_from=repo_url, use_auth_token=token) + repo = Repository( + save_directory, + clone_from=repo_url, + use_auth_token=use_auth_token, + git_user=git_user, + git_email=git_email, + ) return repo.push_to_hub(commit_message=commit_message) diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index 23a3e5a249..478dc7f5f3 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -1,15 +1,24 @@ +import os +import shutil +import time import unittest +from huggingface_hub import HfApi from huggingface_hub.file_download import is_torch_available from huggingface_hub.hub_mixin import ModelHubMixin +from .testing_constants import ENDPOINT_STAGING, PASS, USER -if is_torch_available(): - import torch.nn as nn +REPO_NAME = "mixin-repo-{}".format(int(time.time() * 10e3)) + +WORKING_REPO_SUBDIR = "fixtures/working_repo_2" +WORKING_REPO_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR +) -HUGGINGFACE_ID = "vasudevgupta" -DUMMY_REPO_NAME = "dummy" +if is_torch_available(): + import torch.nn as nn def require_torch(test_case): @@ -26,7 +35,7 @@ def require_torch(test_case): @require_torch -class DummyModel(ModelHubMixin): +class DummyModel(nn.Module, ModelHubMixin): def __init__(self, **kwargs): super().__init__() self.config = kwargs.pop("config", None) @@ -37,30 +46,81 @@ def forward(self, x): @require_torch -class DummyModelTest(unittest.TestCase): +class HubMixingCommonTest(unittest.TestCase): + _api = HfApi(endpoint=ENDPOINT_STAGING) + + +class HubMixingTest(HubMixingCommonTest): + def tearDown(self) -> None: + try: + shutil.rmtree(WORKING_REPO_DIR) + except FileNotFoundError: + pass + + @classmethod + def setUpClass(cls): + """ + Share this valid token in all tests below. + """ + cls._token = cls._api.login(username=USER, password=PASS) + def test_save_pretrained(self): model = DummyModel() - model.save_pretrained(DUMMY_REPO_NAME) + + model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}") + files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") + self.assertTrue("pytorch_model.bin" in files) + self.assertEqual(len(files), 1) + model.save_pretrained( - DUMMY_REPO_NAME, config={"num": 12, "act": "gelu"}, push_to_hub=True + f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"} ) + files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") + self.assertTrue("config.json" in files) + self.assertTrue("pytorch_model.bin" in files) + self.assertEqual(len(files), 2) + + def test_rel_path_from_pretrained(self): + model = DummyModel() model.save_pretrained( - DUMMY_REPO_NAME, config={"num": 24, "act": "relu"}, push_to_hub=True + f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED", + config={"num": 10, "act": "gelu_fast"}, ) - model.save_pretrained( - "dummy-wts", config=None, push_to_hub=True, model_id=DUMMY_REPO_NAME + + model = DummyModel.from_pretrained( + f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED" ) + self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) - def test_from_pretrained(self): + def test_abs_path_from_pretrained(self): model = DummyModel() model.save_pretrained( - DUMMY_REPO_NAME, config={"num": 7, "act": "gelu_fast"}, push_to_hub=True + f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED", + config={"num": 10, "act": "gelu_fast"}, ) - model = DummyModel.from_pretrained(f"{HUGGINGFACE_ID}/{DUMMY_REPO_NAME}@main") - self.assertTrue(model.config == {"num": 7, "act": "gelu_fast"}) + model = DummyModel.from_pretrained( + f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED" + ) + self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) def test_push_to_hub(self): model = DummyModel() - model.save_pretrained("dummy-wts", push_to_hub=False) - model.push_to_hub("dummy-wts", model_id=DUMMY_REPO_NAME) + model.save_pretrained( + f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB", + config={"num": 7, "act": "gelu_fast"}, + ) + + model.push_to_hub( + f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB", + f"{REPO_NAME}-PUSH_TO_HUB", + api_endpoint=ENDPOINT_STAGING, + use_auth_token=self._token, + ) + + model_info = self._api.model_info( + f"{USER}/{REPO_NAME}-PUSH_TO_HUB", + ) + self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}-PUSH_TO_HUB") + + self._api.delete_repo(token=self._token, name=f"{REPO_NAME}-PUSH_TO_HUB") From f2656f4f204e27dbc635e525ac4b125462e37145 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 20 Apr 2021 20:09:43 -0400 Subject: [PATCH 2/6] Specify user and email --- tests/test_hubmixin.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index 478dc7f5f3..b82a84072d 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -116,6 +116,8 @@ def test_push_to_hub(self): f"{REPO_NAME}-PUSH_TO_HUB", api_endpoint=ENDPOINT_STAGING, use_auth_token=self._token, + git_user="ci", + git_email="ci@dummy.com", ) model_info = self._api.model_info( From 86d11f0b7f6d48c4a0d87dfae0f25d0e3d185b2d Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 22 Apr 2021 10:36:23 -0400 Subject: [PATCH 3/6] Address comments --- .github/workflows/python-tests.yml | 20 +++++++++++++++++++- setup.py | 2 +- src/huggingface_hub/hub_mixin.py | 3 ++- tests/test_hubmixin.py | 2 +- 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 0236a37292..eee8532705 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -14,6 +14,24 @@ jobs: matrix: python-version: ["3.6", "3.9"] + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[testing] + + - run: pytest -sv ./tests/ + + build_pytorch: + runs-on: ubuntu-latest + steps: - uses: actions/checkout@v2 @@ -25,7 +43,7 @@ jobs: - name: Install dependencies run: | pip install --upgrade pip - pip install .[testing] + pip install .[testing,torch] - run: pytest -sv ./tests/ diff --git a/setup.py b/setup.py index 6a20431dc7..e91e00fc5c 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ def get_version() -> str: extras["testing"] = [ "pytest", -] + extras["torch"] +] extras["quality"] = [ "black>=20.8b1", diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 051a2591c7..697341c6f0 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -1,6 +1,7 @@ import json import logging import os +from pathlib import Path from typing import Dict, Optional, Union import requests @@ -241,7 +242,7 @@ def push_to_hub( url to commit on remote repo. """ if model_id is None: - model_id = save_directory + model_id = save_directory.split("/")[-1] # The auth token is necessary to create a repo if isinstance(use_auth_token, str): diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index b82a84072d..2810a6265e 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -102,7 +102,7 @@ def test_abs_path_from_pretrained(self): model = DummyModel.from_pretrained( f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED" ) - self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) + self.assertDictEqual(model.config, {"num": 10, "act": "gelu_fast"}) def test_push_to_hub(self): model = DummyModel() From a3d8b1cd7cb68c1980233f0a0584777b6c92bd25 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 22 Apr 2021 10:40:17 -0400 Subject: [PATCH 4/6] Require torch :) --- tests/test_hubmixin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index 2810a6265e..a89cb72962 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -50,6 +50,7 @@ class HubMixingCommonTest(unittest.TestCase): _api = HfApi(endpoint=ENDPOINT_STAGING) +@require_torch class HubMixingTest(HubMixingCommonTest): def tearDown(self) -> None: try: From 250195eb1468912af3a167408c0f85b3234ff331 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 22 Apr 2021 10:44:27 -0400 Subject: [PATCH 5/6] Only create model if torch is available --- tests/test_hubmixin.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index a89cb72962..09aa7e75d8 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -34,15 +34,20 @@ def require_torch(test_case): return test_case -@require_torch -class DummyModel(nn.Module, ModelHubMixin): - def __init__(self, **kwargs): - super().__init__() - self.config = kwargs.pop("config", None) - self.l1 = nn.Linear(2, 2) - - def forward(self, x): - return self.l1(x) +if is_torch_available(): + + class DummyModel(nn.Module, ModelHubMixin): + def __init__(self, **kwargs): + super().__init__() + self.config = kwargs.pop("config", None) + self.l1 = nn.Linear(2, 2) + + def forward(self, x): + return self.l1(x) + + +else: + DummyModel = None @require_torch From 926a11270c225fee038e078c444394b69fe03a8f Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 22 Apr 2021 10:50:44 -0400 Subject: [PATCH 6/6] Quality --- src/huggingface_hub/hub_mixin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 697341c6f0..a38473b182 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -1,7 +1,6 @@ import json import logging import os -from pathlib import Path from typing import Dict, Optional, Union import requests