Skip to content

Commit

Permalink
Add changes for push_to_hub_fastai to use the new http-based approach. (
Browse files Browse the repository at this point in the history
#1040)

* Add changes for push_to_hub_fastai to use the new http-based approach.

* Add updates for http based approach

* Update docstrings

* Update src/huggingface_hub/fastai_utils.py

* Update src/huggingface_hub/fastai_utils.py

* Update src/huggingface_hub/fastai_utils.py

* Fix isort

* Remove unused imports

Co-authored-by: Lucain <lucainp@gmail.com>
  • Loading branch information
nandwalritik and Wauplin authored Sep 12, 2022
1 parent 0c7b7ef commit b9d8617
Showing 1 changed file with 70 additions and 47 deletions.
117 changes: 70 additions & 47 deletions src/huggingface_hub/fastai_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
import json
import os
import tempfile
from pathlib import Path
from pickle import DEFAULT_PROTOCOL, PicklingError
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Union

from packaging import version

from huggingface_hub import snapshot_download
from huggingface_hub import hf_api, snapshot_download
from huggingface_hub.constants import CONFIG_NAME
from huggingface_hub.file_download import (
_PY_VERSION,
get_fastai_version,
get_fastcore_version,
)
from huggingface_hub.hf_api import HfApi, HfFolder
from huggingface_hub.repository import Repository
from huggingface_hub.hf_api import HfApi

from .utils import logging, validate_hf_hub_args
from .utils._deprecation import _deprecate_arguments, _deprecate_positional_args


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -351,19 +352,50 @@ def from_pretrained_fastai(
return load_learner(os.path.join(storage_folder, "model.pkl"))


@_deprecate_positional_args(version="0.12")
@_deprecate_arguments(
version="0.12",
deprecated_args={
"git_user",
"git_email",
},
)
@validate_hf_hub_args
def push_to_hub_fastai(
# NOTE: New arguments since 0.9
learner,
*,
repo_id: str,
commit_message: Optional[str] = "Add model",
private: bool = False,
token: Optional[str] = None,
config: Optional[dict] = None,
**kwargs,
branch: Optional[str] = None,
create_pr: Optional[bool] = None,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
api_endpoint: Optional[str] = None,
# NOTE: deprecated signature that will change in 0.12
git_user: Optional[str] = None,
git_email: Optional[str] = None,
# TODO (release 0.12): signature must be the following
# learner,
# repo_id: Optional[str] = None, # optional only until 0.12
# commit_message: Optional[str] = "Add model",
# private: Optional[bool] = None,
# token: Optional[str] = None,
# config: Optional[dict] = None,
# branch: Optional[str] = None,
# create_pr: Optional[bool] = None,
# allow_patterns: Optional[Union[List[str], str]] = None,
# ignore_patterns: Optional[Union[List[str], str]] = None,
# api_endpoint: Optional[str] = None,
):
"""
Upload learner checkpoint files to the Hub while synchronizing a local clone of the repo in
:obj:`repo_id`.
Upload learner checkpoint files to the Hub.
Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be
pushed to the hub. See [`upload_folder`] reference for more details.
Args:
learner (`Learner`):
Expand All @@ -378,15 +410,19 @@ def push_to_hub_fastai(
The Hugging Face account token to use as HTTP bearer authorization for remote files. If :obj:`None`, the token will be asked by a prompt.
config (`dict`, *optional*):
Configuration object to be saved alongside the model weights.
Keyword Args:
branch (`str`, *optional*):
The git branch on which to push the model. This defaults to
the default branch as specified in your repository, which
defaults to `"main"`.
create_pr (`boolean`, *optional*):
Whether or not to create a Pull Request from `branch` with that commit.
Defaults to `False`.
api_endpoint (`str`, *optional*):
The API endpoint to use when pushing the model to the hub.
git_user (`str`, *optional*):
Will override the ``git config user.name`` for committing and pushing files to the hub.
git_email (`str`, *optional*):
Will override the ``git config user.email`` for committing and pushing files to the hub.
allow_patterns (`List[str]` or `str`, *optional*):
If provided, only files matching at least one pattern are pushed.
ignore_patterns (`List[str]` or `str`, *optional*):
If provided, files matching any of the patterns are not pushed.
Returns:
The url of the commit of your model in the given repository.
Expand All @@ -401,41 +437,28 @@ def push_to_hub_fastai(
"""

_check_fastai_fastcore_versions()

api_endpoint: str = kwargs.get("api_endpoint", None)
git_user: str = kwargs.get("git_user", None)
git_email: str = kwargs.get("git_email", None)

if token is None:
token = HfFolder.get_token()

if token is None:
raise ValueError(
"You must login to the Hugging Face Hub. There are two options: "
"(1) Type `huggingface-cli login` in your terminal and enter your token. "
"(2) Enter your token in the `token` argument. "
"Your token is available in the Settings of your Hugging Face account. "
)

# Create repo using `HfApi()`.
repo_url = HfApi(endpoint=api_endpoint).create_repo(
repo_id,
token, _ = hf_api._validate_or_retrieve_token(token)
api = HfApi(endpoint=api_endpoint)
api.create_repo(
repo_id=repo_id,
repo_type="model",
token=token,
private=private,
repo_type=None,
exist_ok=True,
)

# If repository exists in the Hugging Face Hub then clone it locally in `repo_id`.
repo = Repository(
repo_id,
clone_from=repo_url,
use_auth_token=token,
git_user=git_user,
git_email=git_email,
)
repo.git_pull(rebase=True)

_save_pretrained_fastai(learner, repo_id, config=config)

return repo.push_to_hub(commit_message=commit_message)
# Push the files to the repo in a single commit
with tempfile.TemporaryDirectory() as tmp:
saved_path = Path(tmp) / repo_id
_save_pretrained_fastai(learner, saved_path, config=config)
return api.upload_folder(
repo_id=repo_id,
repo_type="model",
token=token,
folder_path=saved_path,
commit_message=commit_message,
revision=branch,
create_pr=create_pr,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)

0 comments on commit b9d8617

Please sign in to comment.