Skip to content

Commit

Permalink
refactor push_to_hub helper
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin authored and rwightman committed Nov 23, 2022
1 parent c037db0 commit ce4d348
Showing 1 changed file with 43 additions and 49 deletions.
92 changes: 43 additions & 49 deletions timm/models/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import os
from functools import partial
from pathlib import Path
from typing import Union
from tempfile import TemporaryDirectory
from typing import Optional, Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse

try:
from torch.hub import get_dir
except ImportError:
Expand All @@ -15,7 +17,10 @@
from timm import __version__

try:
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
from huggingface_hub import (create_repo, get_hf_file_metadata,
hf_hub_download, hf_hub_url,
repo_type_and_id_from_hf_id, upload_folder)
from huggingface_hub.utils import EntryNotFoundError
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
_has_hf_hub = True
except ImportError:
Expand Down Expand Up @@ -121,56 +126,45 @@ def save_for_hf(model, save_directory, model_config=None):

def push_to_hf_hub(
model,
local_dir,
repo_namespace_or_url=None,
commit_message='Add model',
use_auth_token=True,
git_email=None,
git_user=None,
revision=None,
model_config=None,
repo_id: str,
commit_message: str ='Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_config: Optional[dict] = None,
):
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise ValueError(
"You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
"token as the `use_auth_token` argument."
)

if repo_namespace_or_url:
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
else:
repo_owner = HfApi().whoami(token)['name']
repo_name = Path(local_dir).name

repo_id = f'{repo_owner}/{repo_name}'
repo_url = f'https://huggingface.co/{repo_id}'

# Create repo if doesn't exist yet
HfApi().create_repo(repo_id, token=use_auth_token, exist_ok=True)

repo = Repository(
local_dir,
clone_from=repo_url,
use_auth_token=use_auth_token,
git_user=git_user,
git_email=git_email,
revision=revision,
)

# Prepare a default model card that includes the necessary tags to enable inference.
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
with repo.commit(commit_message):
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

# Infer complete repo_id from repo_url
# Can be different from the input `repo_id` if repo_owner was implicit
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
repo_id = f"{repo_owner}/{repo_name}"

# Check if README file already exist in repo
try:
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
has_readme = True
except EntryNotFoundError:
has_readme = False

# Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
# Save model weights and config.
save_for_hf(model, repo.local_dir, model_config=model_config)
save_for_hf(model, tmpdir, model_config=model_config)

This comment has been minimized.

Copy link
@Abolish0098

Abolish0098 Dec 16, 2022

2.md


# Save a model card if it doesn't exist.
readme_path = Path(repo.local_dir) / 'README.md'
if not readme_path.exists():
# Add readme if does not exist
if not has_readme:
readme_path = Path(tmpdir) / "README.md"
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
readme_path.write_text(readme_text)

return repo.git_remote_url()
# Upload model and return
return upload_folder(
repo_id=repo_id,
folder_path=tmpdir,
revision=revision,
create_pr=create_pr,
commit_message=commit_message,
)

0 comments on commit ce4d348

Please sign in to comment.