Skip to content

Commit

Permalink
Add download alias for hf_hub_download to HfApi (#1580)
Browse files Browse the repository at this point in the history
* accept endpoint in hf_hub_download

* Add download aliases + test
  • Loading branch information
Wauplin authored Aug 8, 2023
1 parent 37a5eaa commit f4532c9
Show file tree
Hide file tree
Showing 5 changed files with 408 additions and 13 deletions.
15 changes: 10 additions & 5 deletions src/huggingface_hub/_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
def snapshot_download(
repo_id: str,
*,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
endpoint: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
local_dir: Union[str, Path, None] = None,
local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto",
Expand Down Expand Up @@ -71,12 +72,15 @@ def snapshot_download(
Args:
repo_id (`str`):
A user or an organization name and a repo name separated by a `/`.
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if downloading from a dataset or space,
`None` or `"model"` if downloading from a model. Default is `None`.
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
endpoint (`str`, *optional*):
Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT`
environment variable.
cache_dir (`str`, `Path`, *optional*):
Path to the folder where cached files are stored.
local_dir (`str` or `Path`, *optional*:
Expand Down Expand Up @@ -181,7 +185,7 @@ def snapshot_download(
)

# if we have internet connection we retrieve the correct folder name from the huggingface api
api = HfApi(library_name=library_name, library_version=library_version, user_agent=user_agent)
api = HfApi(library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint)
repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."

Expand Down Expand Up @@ -212,6 +216,7 @@ def _inner_hf_hub_download(repo_file: str):
filename=repo_file,
repo_type=repo_type,
revision=commit_hash,
endpoint=endpoint,
cache_dir=cache_dir,
local_dir=local_dir,
local_dir_use_symlinks=local_dir_use_symlinks,
Expand Down
22 changes: 17 additions & 5 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import __version__ # noqa: F401 # for backward compatibility
from .constants import (
DEFAULT_REVISION,
ENDPOINT,
HF_HUB_DISABLE_SYMLINKS_WARNING,
HF_HUB_ENABLE_HF_TRANSFER,
HUGGINGFACE_CO_URL_TEMPLATE,
Expand Down Expand Up @@ -176,6 +177,7 @@ def hf_hub_url(
subfolder: Optional[str] = None,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
endpoint: Optional[str] = None,
) -> str:
"""Construct the URL of a file from the given information.
Expand All @@ -197,6 +199,9 @@ def hf_hub_url(
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
endpoint (`str`, *optional*):
Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT`
environment variable.
Example:
Expand Down Expand Up @@ -247,11 +252,13 @@ def hf_hub_url(

if revision is None:
revision = DEFAULT_REVISION
return HUGGINGFACE_CO_URL_TEMPLATE.format(
repo_id=repo_id,
revision=quote(revision, safe=""),
filename=quote(filename),
url = HUGGINGFACE_CO_URL_TEMPLATE.format(
repo_id=repo_id, revision=quote(revision, safe=""), filename=quote(filename)
)
# Update endpoint if provided
if endpoint is not None and url.startswith(ENDPOINT):
url = endpoint + url[len(ENDPOINT) :]
return url


def url_to_filename(url: str, etag: Optional[str] = None) -> str:
Expand Down Expand Up @@ -962,6 +969,7 @@ def hf_hub_download(
subfolder: Optional[str] = None,
repo_type: Optional[str] = None,
revision: Optional[str] = None,
endpoint: Optional[str] = None,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
cache_dir: Union[str, Path, None] = None,
Expand Down Expand Up @@ -1035,6 +1043,9 @@ def hf_hub_download(
revision (`str`, *optional*):
An optional Git revision id which can be a branch name, a tag, or a
commit hash.
endpoint (`str`, *optional*):
Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT`
environment variable.
library_name (`str`, *optional*):
The name of the library to which the object corresponds.
library_version (`str`, *optional*):
Expand Down Expand Up @@ -1116,6 +1127,7 @@ def hf_hub_download(
subfolder=subfolder,
repo_type=repo_type,
revision=revision,
endpoint=endpoint,
)

return cached_download(
Expand Down Expand Up @@ -1175,7 +1187,7 @@ def hf_hub_download(
return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
return pointer_path

url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision)
url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint)

headers = build_hf_headers(
token=token,
Expand Down
Loading

0 comments on commit f4532c9

Please sign in to comment.