Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Mixin #34

Merged
merged 6 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,24 @@ jobs:
matrix:
python-version: ["3.6", "3.9"]

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[testing]

- run: pytest -sv ./tests/

build_pytorch:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2

Expand All @@ -25,7 +43,7 @@ jobs:
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[testing]
pip install .[testing,torch]

- run: pytest -sv ./tests/

Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def get_version() -> str:

extras = {}

extras["torch"] = [
"torch",
]

extras["testing"] = [
"pytest",
]
Expand Down
43 changes: 35 additions & 8 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import logging
import os
from typing import Dict, Optional
from typing import Dict, Optional, Union

import requests

Expand Down Expand Up @@ -150,7 +150,7 @@ def from_pretrained(
if len(model_id.split("@")) == 2:
model_id, revision = model_id.split("@")

if model_id in os.listdir() and CONFIG_NAME in os.listdir(model_id):
if os.path.isdir(model_id) and CONFIG_NAME in os.listdir(model_id):
config_file = os.path.join(model_id, CONFIG_NAME)
else:
try:
Expand All @@ -170,7 +170,7 @@ def from_pretrained(
logger.warning("config.json NOT FOUND in HuggingFace Hub")
config_file = None

if model_id in os.listdir():
if os.path.isdir(model_id):
print("LOADING weights from local directory")
model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
else:
Expand Down Expand Up @@ -208,6 +208,10 @@ def push_to_hub(
commit_message: Optional[str] = "add model",
organization: Optional[str] = None,
private: bool = None,
api_endpoint=None,
use_auth_token: Union[bool, str, None] = None,
git_user: Optional[str] = None,
git_email: Optional[str] = None,
) -> str:
"""
Parameters:
Expand All @@ -223,24 +227,47 @@ def push_to_hub(
private: Whether the model repo should be private (requires a paid huggingface.co account)
commit_message (:obj:`str`, `optional`, defaults to :obj:`add model`):
Message to commit while pushing
api_endpoint (:obj:`str`, `optional`):
The API endpoint to use when pushing the model to the hub.
use_auth_token (``str`` or ``bool``, `optional`, defaults ``None``):
huggingface_token can be extract from ``HfApi().login(username, password)`` and is used to authenticate
against the hub (useful from Google Colab for instance).
git_user (``str``, `optional`, defaults ``None``):
will override the ``git config user.name`` for committing and pushing files to the hub.
git_email (``str``, `optional`, defaults ``None``):
will override the ``git config user.email`` for committing and pushing files to the hub.

Returns:
url to commit on remote repo.
"""
if model_id is None:
model_id = save_directory
model_id = save_directory.split("/")[-1]

# The auth token is necessary to create a repo
if isinstance(use_auth_token, str):
huggingface_token = use_auth_token
elif use_auth_token is None and repo_url is not None:
# If the repo url exists, then no need for a token
huggingface_token = None
else:
huggingface_token = HfFolder.get_token()
Comment on lines +247 to +253
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(use_auth_token, str):
huggingface_token = use_auth_token
elif use_auth_token is None and repo_url is not None:
# If the repo url exists, then no need for a token
huggingface_token = None
else:
huggingface_token = HfFolder.get_token()
if isinstance(use_auth_token, str):
huggingface_token = use_auth_token
elif use_auth_token:
huggingface_token = HfFolder.get_token()
else:
huggingface_token = None

Let's use the same logic as elsewhere, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue with using the same logic as elsewhere is that the API becomes awkward to use, just to cover a rare use case. Namely, the following will never work:

model.push_to_hub("xxx")

it needs to be the following:

model.push_to_hub("xxx", use_auth_token=True)

I fail to see when a user would want to push to the hub without having an auth token, since a token is necessary to create a repo. If the repo already exists and one wants to push to it, then the user already has to specify the repo_url to push to.

I think the API is cleaner by having model.push_to_hub("xxx") work if you already have an authentication token in your HF folder.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok


token = HfFolder.get_token()
if repo_url is None:
repo_url = HfApi().create_repo(
token,
repo_url = HfApi(endpoint=api_endpoint).create_repo(
huggingface_token,
model_id,
organization=organization,
private=private,
repo_type=None,
exist_ok=True,
)

repo = Repository(save_directory, clone_from=repo_url, use_auth_token=token)
repo = Repository(
save_directory,
clone_from=repo_url,
use_auth_token=use_auth_token,
git_user=git_user,
git_email=git_email,
)

return repo.push_to_hub(commit_message=commit_message)
116 changes: 92 additions & 24 deletions tests/test_hubmixin.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import os
import shutil
import time
import unittest

from huggingface_hub import HfApi
from huggingface_hub.file_download import is_torch_available
from huggingface_hub.hub_mixin import ModelHubMixin

from .testing_constants import ENDPOINT_STAGING, PASS, USER

if is_torch_available():
import torch.nn as nn

REPO_NAME = "mixin-repo-{}".format(int(time.time() * 10e3))

HUGGINGFACE_ID = "vasudevgupta"
DUMMY_REPO_NAME = "dummy"
WORKING_REPO_SUBDIR = "fixtures/working_repo_2"
WORKING_REPO_DIR = os.path.join(
os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR
)

if is_torch_available():
import torch.nn as nn


def require_torch(test_case):
Expand All @@ -25,42 +34,101 @@ def require_torch(test_case):
return test_case


@require_torch
class DummyModel(ModelHubMixin):
def __init__(self, **kwargs):
super().__init__()
self.config = kwargs.pop("config", None)
self.l1 = nn.Linear(2, 2)
if is_torch_available():

def forward(self, x):
return self.l1(x)
class DummyModel(nn.Module, ModelHubMixin):
def __init__(self, **kwargs):
super().__init__()
self.config = kwargs.pop("config", None)
self.l1 = nn.Linear(2, 2)

def forward(self, x):
return self.l1(x)


else:
DummyModel = None


@require_torch
class DummyModelTest(unittest.TestCase):
class HubMixingCommonTest(unittest.TestCase):
_api = HfApi(endpoint=ENDPOINT_STAGING)


@require_torch
class HubMixingTest(HubMixingCommonTest):
def tearDown(self) -> None:
try:
shutil.rmtree(WORKING_REPO_DIR)
except FileNotFoundError:
pass

@classmethod
def setUpClass(cls):
"""
Share this valid token in all tests below.
"""
cls._token = cls._api.login(username=USER, password=PASS)

def test_save_pretrained(self):
model = DummyModel()
model.save_pretrained(DUMMY_REPO_NAME)

model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}")
files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}")
self.assertTrue("pytorch_model.bin" in files)
self.assertEqual(len(files), 1)

model.save_pretrained(
DUMMY_REPO_NAME, config={"num": 12, "act": "gelu"}, push_to_hub=True
f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"}
)
files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}")
self.assertTrue("config.json" in files)
self.assertTrue("pytorch_model.bin" in files)
self.assertEqual(len(files), 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice


def test_rel_path_from_pretrained(self):
model = DummyModel()
model.save_pretrained(
DUMMY_REPO_NAME, config={"num": 24, "act": "relu"}, push_to_hub=True
f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED",
config={"num": 10, "act": "gelu_fast"},
)
model.save_pretrained(
"dummy-wts", config=None, push_to_hub=True, model_id=DUMMY_REPO_NAME

model = DummyModel.from_pretrained(
f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED"
)
self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.assertEqual should work, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed!


def test_from_pretrained(self):
def test_abs_path_from_pretrained(self):
model = DummyModel()
model.save_pretrained(
DUMMY_REPO_NAME, config={"num": 7, "act": "gelu_fast"}, push_to_hub=True
f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED",
config={"num": 10, "act": "gelu_fast"},
)

model = DummyModel.from_pretrained(f"{HUGGINGFACE_ID}/{DUMMY_REPO_NAME}@main")
self.assertTrue(model.config == {"num": 7, "act": "gelu_fast"})
model = DummyModel.from_pretrained(
f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED"
)
self.assertDictEqual(model.config, {"num": 10, "act": "gelu_fast"})

def test_push_to_hub(self):
model = DummyModel()
model.save_pretrained("dummy-wts", push_to_hub=False)
model.push_to_hub("dummy-wts", model_id=DUMMY_REPO_NAME)
model.save_pretrained(
f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB",
config={"num": 7, "act": "gelu_fast"},
)

model.push_to_hub(
f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB",
f"{REPO_NAME}-PUSH_TO_HUB",
api_endpoint=ENDPOINT_STAGING,
use_auth_token=self._token,
git_user="ci",
git_email="ci@dummy.com",
)

model_info = self._api.model_info(
f"{USER}/{REPO_NAME}-PUSH_TO_HUB",
)
self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}-PUSH_TO_HUB")

self._api.delete_repo(token=self._token, name=f"{REPO_NAME}-PUSH_TO_HUB")