diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 0236a37292..eee8532705 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -14,6 +14,24 @@ jobs: matrix: python-version: ["3.6", "3.9"] + steps: + - uses: actions/checkout@v2 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install .[testing] + + - run: pytest -sv ./tests/ + + build_pytorch: + runs-on: ubuntu-latest + steps: - uses: actions/checkout@v2 @@ -25,7 +43,7 @@ jobs: - name: Install dependencies run: | pip install --upgrade pip - pip install .[testing] + pip install .[testing,torch] - run: pytest -sv ./tests/ diff --git a/setup.py b/setup.py index 17963ed097..e91e00fc5c 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,10 @@ def get_version() -> str: extras = {} +extras["torch"] = [ + "torch", +] + extras["testing"] = [ "pytest", ] diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 081be72bfc..a38473b182 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -1,7 +1,7 @@ import json import logging import os -from typing import Dict, Optional +from typing import Dict, Optional, Union import requests @@ -150,7 +150,7 @@ def from_pretrained( if len(model_id.split("@")) == 2: model_id, revision = model_id.split("@") - if model_id in os.listdir() and CONFIG_NAME in os.listdir(model_id): + if os.path.isdir(model_id) and CONFIG_NAME in os.listdir(model_id): config_file = os.path.join(model_id, CONFIG_NAME) else: try: @@ -170,7 +170,7 @@ def from_pretrained( logger.warning("config.json NOT FOUND in HuggingFace Hub") config_file = None - if model_id in os.listdir(): + if os.path.isdir(model_id): print("LOADING weights from local directory") model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) else: @@ -208,6 +208,10 @@ def push_to_hub( commit_message: Optional[str] = "add model", organization: Optional[str] = None, private: bool = None, + api_endpoint=None, + use_auth_token: Union[bool, str, None] = None, + git_user: Optional[str] = None, + git_email: Optional[str] = None, ) -> str: """ Parameters: @@ -223,17 +227,34 @@ def push_to_hub( private: Whether the model repo should be private (requires a paid huggingface.co account) commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`): Message to commit while pushing + api_endpoint (:obj:`str`, `optional`): + The API endpoint to use when pushing the model to the hub. + use_auth_token (``str`` or ``bool``, `optional`, defaults ``None``): + huggingface_token can be extract from ``HfApi().login(username, password)`` and is used to authenticate + against the hub (useful from Google Colab for instance). + git_user (``str``, `optional`, defaults ``None``): + will override the ``git config user.name`` for committing and pushing files to the hub. + git_email (``str``, `optional`, defaults ``None``): + will override the ``git config user.email`` for committing and pushing files to the hub. Returns: url to commit on remote repo. """ if model_id is None: - model_id = save_directory + model_id = save_directory.split("/")[-1] + + # The auth token is necessary to create a repo + if isinstance(use_auth_token, str): + huggingface_token = use_auth_token + elif use_auth_token is None and repo_url is not None: + # If the repo url exists, then no need for a token + huggingface_token = None + else: + huggingface_token = HfFolder.get_token() - token = HfFolder.get_token() if repo_url is None: - repo_url = HfApi().create_repo( - token, + repo_url = HfApi(endpoint=api_endpoint).create_repo( + huggingface_token, model_id, organization=organization, private=private, @@ -241,6 +262,12 @@ def push_to_hub( exist_ok=True, ) - repo = Repository(save_directory, clone_from=repo_url, use_auth_token=token) + repo = Repository( + save_directory, + clone_from=repo_url, + use_auth_token=use_auth_token, + git_user=git_user, + git_email=git_email, + ) return repo.push_to_hub(commit_message=commit_message) diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index 23a3e5a249..09aa7e75d8 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -1,15 +1,24 @@ +import os +import shutil +import time import unittest +from huggingface_hub import HfApi from huggingface_hub.file_download import is_torch_available from huggingface_hub.hub_mixin import ModelHubMixin +from .testing_constants import ENDPOINT_STAGING, PASS, USER -if is_torch_available(): - import torch.nn as nn +REPO_NAME = "mixin-repo-{}".format(int(time.time() * 10e3)) -HUGGINGFACE_ID = "vasudevgupta" -DUMMY_REPO_NAME = "dummy" +WORKING_REPO_SUBDIR = "fixtures/working_repo_2" +WORKING_REPO_DIR = os.path.join( + os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR +) + +if is_torch_available(): + import torch.nn as nn def require_torch(test_case): @@ -25,42 +34,101 @@ def require_torch(test_case): return test_case -@require_torch -class DummyModel(ModelHubMixin): - def __init__(self, **kwargs): - super().__init__() - self.config = kwargs.pop("config", None) - self.l1 = nn.Linear(2, 2) +if is_torch_available(): - def forward(self, x): - return self.l1(x) + class DummyModel(nn.Module, ModelHubMixin): + def __init__(self, **kwargs): + super().__init__() + self.config = kwargs.pop("config", None) + self.l1 = nn.Linear(2, 2) + + def forward(self, x): + return self.l1(x) + + +else: + DummyModel = None @require_torch -class DummyModelTest(unittest.TestCase): +class HubMixingCommonTest(unittest.TestCase): + _api = HfApi(endpoint=ENDPOINT_STAGING) + + +@require_torch +class HubMixingTest(HubMixingCommonTest): + def tearDown(self) -> None: + try: + shutil.rmtree(WORKING_REPO_DIR) + except FileNotFoundError: + pass + + @classmethod + def setUpClass(cls): + """ + Share this valid token in all tests below. + """ + cls._token = cls._api.login(username=USER, password=PASS) + def test_save_pretrained(self): model = DummyModel() - model.save_pretrained(DUMMY_REPO_NAME) + + model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}") + files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") + self.assertTrue("pytorch_model.bin" in files) + self.assertEqual(len(files), 1) + model.save_pretrained( - DUMMY_REPO_NAME, config={"num": 12, "act": "gelu"}, push_to_hub=True + f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"} ) + files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") + self.assertTrue("config.json" in files) + self.assertTrue("pytorch_model.bin" in files) + self.assertEqual(len(files), 2) + + def test_rel_path_from_pretrained(self): + model = DummyModel() model.save_pretrained( - DUMMY_REPO_NAME, config={"num": 24, "act": "relu"}, push_to_hub=True + f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED", + config={"num": 10, "act": "gelu_fast"}, ) - model.save_pretrained( - "dummy-wts", config=None, push_to_hub=True, model_id=DUMMY_REPO_NAME + + model = DummyModel.from_pretrained( + f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED" ) + self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) - def test_from_pretrained(self): + def test_abs_path_from_pretrained(self): model = DummyModel() model.save_pretrained( - DUMMY_REPO_NAME, config={"num": 7, "act": "gelu_fast"}, push_to_hub=True + f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED", + config={"num": 10, "act": "gelu_fast"}, ) - model = DummyModel.from_pretrained(f"{HUGGINGFACE_ID}/{DUMMY_REPO_NAME}@main") - self.assertTrue(model.config == {"num": 7, "act": "gelu_fast"}) + model = DummyModel.from_pretrained( + f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED" + ) + self.assertDictEqual(model.config, {"num": 10, "act": "gelu_fast"}) def test_push_to_hub(self): model = DummyModel() - model.save_pretrained("dummy-wts", push_to_hub=False) - model.push_to_hub("dummy-wts", model_id=DUMMY_REPO_NAME) + model.save_pretrained( + f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB", + config={"num": 7, "act": "gelu_fast"}, + ) + + model.push_to_hub( + f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB", + f"{REPO_NAME}-PUSH_TO_HUB", + api_endpoint=ENDPOINT_STAGING, + use_auth_token=self._token, + git_user="ci", + git_email="ci@dummy.com", + ) + + model_info = self._api.model_info( + f"{USER}/{REPO_NAME}-PUSH_TO_HUB", + ) + self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}-PUSH_TO_HUB") + + self._api.delete_repo(token=self._token, name=f"{REPO_NAME}-PUSH_TO_HUB")