Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add download alias for hf_hub_download to HfApi #1580

Merged
merged 2 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/huggingface_hub/_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
def snapshot_download(
repo_id: str,
*,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
endpoint: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
local_dir: Union[str, Path, None] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
Expand Down Expand Up @@ -71,12 +72,15 @@ def snapshot_download(
Args:
repo_id (`str`):
A user or an organization name and a repo name separated by a `/`.
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if downloading from a dataset or space,
`None` or `"model"` if downloading from a model. Default is `None`.
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
endpoint (`str`, *optional*):
Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT`
environment variable.
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.
local_dir (`str` or `Path`, *optional*:
Expand Down Expand Up @@ -181,7 +185,7 @@ def snapshot_download(
)

# if we have internet connection we retrieve the correct folder name from the huggingface api
api = HfApi(library_name=library_name, library_version=library_version, user_agent=user_agent)
api = HfApi(library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint)
repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."

Expand Down Expand Up @@ -212,6 +216,7 @@ def _inner_hf_hub_download(repo_file: str):
filename=repo_file,
repo_type=repo_type,
revision=commit_hash,
endpoint=endpoint,
cache_dir=cache_dir,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
Expand Down
22 changes: 17 additions & 5 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import __version__ # noqa: F401 # for backward compatibility
from .constants import (
DEFAULT_REVISION,
ENDPOINT,
HF_HUB_DISABLE_SYMLINKS_WARNING,
HF_HUB_ENABLE_HF_TRANSFER,
HUGGINGFACE_CO_URL_TEMPLATE,
Expand Down Expand Up @@ -176,6 +177,7 @@ def hf_hub_url(
subfolder: Optional[str] = None,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
endpoint: Optional[str] = None,
) -> str:
"""Construct the URL of a file from the given information.

Expand All @@ -197,6 +199,9 @@ def hf_hub_url(
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
endpoint (`str`, *optional*):
Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT`
environment variable.

Example:

Expand Down Expand Up @@ -247,11 +252,13 @@ def hf_hub_url(

if revision is None:
revision = DEFAULT_REVISION
return HUGGINGFACE_CO_URL_TEMPLATE.format(
repo_id=repo_id,
revision=quote(revision, safe=""),
filename=quote(filename),
url = HUGGINGFACE_CO_URL_TEMPLATE.format(
repo_id=repo_id, revision=quote(revision, safe=""), filename=quote(filename)
)
# Update endpoint if provided
if endpoint is not None and url.startswith(ENDPOINT):
url = endpoint + url[len(ENDPOINT) :]
return url


def url_to_filename(url: str, etag: Optional[str] = None) -> str:
Expand Down Expand Up @@ -962,6 +969,7 @@ def hf_hub_download(
subfolder: Optional[str] = None,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
endpoint: Optional[str] = None,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
Expand Down Expand Up @@ -1035,6 +1043,9 @@ def hf_hub_download(
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
endpoint (`str`, *optional*):
Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT`
environment variable.
library_name (`str`, *optional*):
The name of the library to which the object corresponds.
library_version (`str`, *optional*):
Expand Down Expand Up @@ -1116,6 +1127,7 @@ def hf_hub_download(
subfolder=subfolder,
repo_type=repo_type,
revision=revision,
endpoint=endpoint,
)

return cached_download(
Expand Down Expand Up @@ -1175,7 +1187,7 @@ def hf_hub_download(
return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
return pointer_path

url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision)
url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint)

headers = build_hf_headers(
token=token,
Expand Down
Loading