Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add changes for push_to_hub_fastai to use the new http-based approach. #1040

Merged
103 changes: 60 additions & 43 deletions src/huggingface_hub/fastai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import os
from pathlib import Path
from pickle import DEFAULT_PROTOCOL, PicklingError
from typing import Any, Dict, Optional
import tempfile
from typing import Any, Dict, Optional, Union, List
nandwalritik marked this conversation as resolved.
Show resolved Hide resolved

from packaging import version

from huggingface_hub import snapshot_download
from huggingface_hub import hf_api
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.file_download import (
_PY_VERSION,
Expand All @@ -16,7 +18,7 @@
from huggingface_hub.hf_api import HfApi, HfFolder
from huggingface_hub.repository import Repository
from huggingface_hub.utils import logging

from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args

logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -349,14 +351,42 @@ def from_pretrained_fastai(
return load_learner(os.path.join(storage_folder, "model.pkl"))


@_deprecate_positional_args(version="0.12")
@_deprecate_arguments(
version="0.12",
deprecated_args={
"git_user",
"git_email",
},
)
def push_to_hub_fastai(
# NOTE: New arguments since 0.9
learner,
repo_id: str,
commit_message: Optional[str] = "Add model",
private: Optional[bool] = None,
token: Optional[str] = None,
config: Optional[dict] = None,
**kwargs,
branch: Optional[str] = None,
create_pr: Optional[bool] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
api_endpoint: Optional[str] = None,
# NOTE: deprecated signature that will change in 0.12
git_user: Optional[str] = None,
git_email: Optional[str] = None,
# TODO (release 0.12): signature must be the following
# learner,
# repo_id: Optional[str] = None, # optional only until 0.12
# commit_message: Optional[str] = "Add model",
# private: Optional[bool] = None,
# token: Optional[str] = None,
# config: Optional[dict] = None,
# branch: Optional[str] = None,
# create_pr: Optional[bool] = None,
# allow_patterns: Optional[Union[List[str], str]] = None,
# ignore_patterns: Optional[Union[List[str], str]] = None,
# api_endpoint: Optional[str] = None,
):
"""
Upload learner checkpoint files to the Hub while synchronizing a local clone of the repo in
Expand All @@ -375,15 +405,15 @@ def push_to_hub_fastai(
The Hugging Face account token to use as HTTP bearer authorization for remote files. If :obj:`None`, the token will be asked by a prompt.
config (`dict`, *optional*):
Configuration object to be saved alongside the model weights.

Keyword Args:
branch (`str`, *optional*):
The git branch on which to push the model. This defaults to
the default branch as specified in your repository, which
defaults to `"main"`.
create_pr (`boolean`, *optional*):
Whether or not to create a Pull Request from `branch` with that commit.
Defaults to `False`.
api_endpoint (`str`, *optional*):
The API endpoint to use when pushing the model to the hub.
git_user (`str`, *optional*):
Will override the ``git config user.name`` for committing and pushing files to the hub.
git_email (`str`, *optional*):
Will override the ``git config user.email`` for committing and pushing files to the hub.

Returns:
The url of the commit of your model in the given repository.

Expand All @@ -398,41 +428,28 @@ def push_to_hub_fastai(
"""

_check_fastai_fastcore_versions()

api_endpoint: str = kwargs.get("api_endpoint", None)
git_user: str = kwargs.get("git_user", None)
git_email: str = kwargs.get("git_email", None)

if token is None:
token = HfFolder.get_token()

if token is None:
raise ValueError(
"You must login to the Hugging Face Hub. There are two options: "
"(1) Type `huggingface-cli login` in your terminal and enter your token. "
"(2) Enter your token in the `token` argument. "
"Your token is available in the Settings of your Hugging Face account. "
)

# Create repo using `HfApi()`.
repo_url = HfApi(endpoint=api_endpoint).create_repo(
repo_id,
token, _ = hf_api._validate_or_retrieve_token(token)
api = HfApi(endpoint=api_endpoint)
api.create_repo(
repo_id=repo_id,
repo_type="model",
token=token,
private=private,
repo_type=None,
exist_ok=True,
)

# If repository exists in the Hugging Face Hub then clone it locally in `repo_id`.
repo = Repository(
repo_id,
clone_from=repo_url,
use_auth_token=token,
git_user=git_user,
git_email=git_email,
)
repo.git_pull(rebase=True)

_save_pretrained_fastai(learner, repo_id, config=config)

return repo.push_to_hub(commit_message=commit_message)
# Push the files to the repo in a single commit
with tempfile.TemporaryDirectory() as tmp:
saved_path = Path(tmp) / repo_id
_save_pretrained_fastai(learner, saved_path, config=config)
return api.upload_folder(
repo_id=repo_id,
repo_type="model",
token=token,
folder_path=saved_path,
commit_message=commit_message,
revision=branch,
create_pr=create_pr,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)