Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing custom headers to HfApi #2098

Merged
merged 7 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions src/huggingface_hub/_commit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info
from .utils import (
EntryNotFoundError,
build_hf_headers,
chunk_iterable,
hf_raise_for_status,
logging,
Expand Down Expand Up @@ -319,7 +318,7 @@ def _upload_lfs_files(
additions: List[CommitOperationAdd],
repo_type: str,
repo_id: str,
token: Optional[str],
headers: Dict[str, str],
endpoint: Optional[str] = None,
num_threads: int = 5,
revision: Optional[str] = None,
Expand All @@ -338,8 +337,8 @@ def _upload_lfs_files(
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
token (`str`, *optional*):
An authentication token ( See https://huggingface.co/settings/tokens )
headers (`Dict[str, str]`):
Headers to use for the request, including authorization headers and user agent.
num_threads (`int`, *optional*):
The number of concurrent threads to use when uploading. Defaults to 5.
revision (`str`, *optional*):
Expand All @@ -360,11 +359,12 @@ def _upload_lfs_files(
for chunk in chunk_iterable(additions, chunk_size=256):
batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info(
upload_infos=[op.upload_info for op in chunk],
token=token,
repo_id=repo_id,
repo_type=repo_type,
revision=revision,
endpoint=endpoint,
headers=headers,
token=None, # already passed in 'headers'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the token generally be expected in the headers in that case, rather than passed as a kwarg? Or is it a public-facing function whose API we can't change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I removed `token` wherever possible, but since this one is not a private method, I kept the `token` parameter just in case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(though it's not a documented or promoted method)

)

# If at least 1 error, we do not retrieve information for other chunks
Expand Down Expand Up @@ -399,7 +399,7 @@ def _upload_lfs_files(
def _wrapped_lfs_upload(batch_action) -> None:
try:
operation = oid2addop[batch_action["oid"]]
lfs_upload(operation=operation, lfs_batch_action=batch_action, token=token)
lfs_upload(operation=operation, lfs_batch_action=batch_action, headers=headers)
except Exception as exc:
raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc

Expand Down Expand Up @@ -443,7 +443,7 @@ def _fetch_upload_modes(
additions: Iterable[CommitOperationAdd],
repo_type: str,
repo_id: str,
token: Optional[str],
headers: Dict[str, str],
revision: str,
endpoint: Optional[str] = None,
create_pr: bool = False,
Expand All @@ -462,8 +462,8 @@ def _fetch_upload_modes(
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
token (`str`, *optional*):
An authentication token ( See https://huggingface.co/settings/tokens )
headers (`Dict[str, str]`):
Headers to use for the request, including authorization headers and user agent.
revision (`str`):
The git revision to upload the files to. Can be any valid git revision.
gitignore_content (`str`, *optional*):
Expand All @@ -478,7 +478,6 @@ def _fetch_upload_modes(
If the Hub API response is improperly formatted.
"""
endpoint = endpoint if endpoint is not None else ENDPOINT
headers = build_hf_headers(token=token)

# Fetch upload mode (LFS or regular) chunk by chunk.
upload_modes: Dict[str, UploadMode] = {}
Expand Down Expand Up @@ -527,7 +526,7 @@ def _fetch_files_to_copy(
copies: Iterable[CommitOperationCopy],
repo_type: str,
repo_id: str,
token: Optional[str],
headers: Dict[str, str],
revision: str,
endpoint: Optional[str] = None,
) -> Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]]:
Expand All @@ -546,8 +545,8 @@ def _fetch_files_to_copy(
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
token (`str`, *optional*):
An authentication token ( See https://huggingface.co/settings/tokens )
headers (`Dict[str, str]`):
Headers to use for the request, including authorization headers and user agent.
revision (`str`):
The git revision to upload the files to. Can be any valid git revision.

Expand All @@ -563,7 +562,7 @@ def _fetch_files_to_copy(
"""
from .hf_api import HfApi, RepoFolder

hf_api = HfApi(endpoint=endpoint, token=token)
hf_api = HfApi(endpoint=endpoint, headers=headers)
files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]] = {}
for src_revision, operations in groupby(copies, key=lambda op: op.src_revision):
operations = list(operations) # type: ignore
Expand All @@ -582,7 +581,6 @@ def _fetch_files_to_copy(
files_to_copy[(src_repo_file.path, src_revision)] = src_repo_file
else:
# TODO: (optimization) download regular files to copy concurrently
headers = build_hf_headers(token=token)
url = hf_hub_url(
endpoint=endpoint,
repo_type=repo_type,
Expand Down
10 changes: 9 additions & 1 deletion src/huggingface_hub/_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def snapshot_download(
ignore_patterns: Optional[Union[List[str], str]] = None,
max_workers: int = 8,
tqdm_class: Optional[base_tqdm] = None,
headers: Optional[Dict[str, str]] = None,
endpoint: Optional[str] = None,
) -> str:
"""Download repo files.
Expand Down Expand Up @@ -120,6 +121,8 @@ def snapshot_download(
- If `True`, the token is read from the HuggingFace config
folder.
- If a string, it's used as the authentication token.
headers (`dict`, *optional*):
Additional headers to include in the request. Those headers take precedence over the others.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
Expand Down Expand Up @@ -174,7 +177,11 @@ def snapshot_download(
try:
# if we have internet connection we want to list files to download
api = HfApi(
library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
endpoint=endpoint,
headers=headers,
)
repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
Expand Down Expand Up @@ -297,6 +304,7 @@ def _inner_hf_hub_download(repo_file: str):
resume_download=resume_download,
force_download=force_download,
token=token,
headers=headers,
)

if HF_HUB_ENABLE_HF_TRANSFER:
Expand Down
13 changes: 12 additions & 1 deletion src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,7 @@ def hf_hub_download(
resume_download: bool = False,
token: Union[bool, str, None] = None,
local_files_only: bool = False,
headers: Optional[Dict[str, str]] = None,
legacy_cache_layout: bool = False,
endpoint: Optional[str] = None,
) -> str:
Expand Down Expand Up @@ -1120,6 +1121,8 @@ def hf_hub_download(
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, avoid downloading the file and return the path to the
local cached file if it exists.
headers (`dict`, *optional*):
Additional headers to be sent with the request.
legacy_cache_layout (`bool`, *optional*, defaults to `False`):
If `True`, uses the legacy file cache layout i.e. just call [`hf_hub_url`]
then `cached_download`. This is deprecated as the new cache layout is
Expand Down Expand Up @@ -1237,6 +1240,7 @@ def hf_hub_download(
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
headers=headers,
)

url_to_download = url
Expand Down Expand Up @@ -1619,6 +1623,7 @@ def get_hf_file_metadata(
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
headers: Optional[Dict[str, str]] = None,
) -> HfFileMetadata:
"""Fetch metadata of a file versioned on the Hub for a given url.

Expand All @@ -1642,13 +1647,19 @@ def get_hf_file_metadata(
The version of the library.
user_agent (`dict`, `str`, *optional*):
The user-agent info in the form of a dictionary or a string.
headers (`dict`, *optional*):
Additional headers to be sent with the request.

Returns:
A [`HfFileMetadata`] object containing metadata such as location, etag, size and
commit_hash.
"""
headers = build_hf_headers(
token=token, library_name=library_name, library_version=library_version, user_agent=user_agent
token=token,
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
headers=headers,
)
headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file

Expand Down
Loading
Loading