Skip to content

Commit

Permalink
Send user_agent in HEAD calls (#1854)
Browse files Browse the repository at this point in the history
* Send user_agent in HEAD calls + add get_hf_file_metadata to HfApi + test

* docstring
  • Loading branch information
Wauplin authored Nov 22, 2023
1 parent 1b1049a commit 5d69111
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 5 deletions.
18 changes: 17 additions & 1 deletion src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
tqdm,
validate_hf_hub_args,
)
from .utils._deprecation import _deprecate_method
from .utils._headers import _http_user_agent
from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility
from .utils._typing import HTTP_METHOD_T
Expand Down Expand Up @@ -345,6 +346,7 @@ def filename_to_url(
return url, etag


@_deprecate_method(version="0.22.0", message="Use `huggingface_hub.utils.build_hf_headers` instead.")
def http_user_agent(
*,
library_name: Optional[str] = None,
Expand Down Expand Up @@ -1249,6 +1251,9 @@ def hf_hub_download(
token=token,
proxies=proxies,
timeout=etag_timeout,
library_name=library_name,
library_version=library_version,
user_agent=user_agent,
)
except EntryNotFoundError as http_error:
# Cache the non-existence of the file and raise
Expand Down Expand Up @@ -1595,6 +1600,9 @@ def get_hf_file_metadata(
token: Union[bool, str, None] = None,
proxies: Optional[Dict] = None,
timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT,
library_name: Optional[str] = None,
library_version: Optional[str] = None,
user_agent: Union[Dict, str, None] = None,
) -> HfFileMetadata:
"""Fetch metadata of a file versioned on the Hub for a given url.
Expand All @@ -1612,12 +1620,20 @@ def get_hf_file_metadata(
`requests.request`.
timeout (`float`, *optional*, defaults to 10):
How many seconds to wait for the server to send metadata before giving up.
library_name (`str`, *optional*):
The name of the library to which the object corresponds.
library_version (`str`, *optional*):
The version of the library.
user_agent (`dict`, `str`, *optional*):
The user-agent info in the form of a dictionary or a string.
Returns:
A [`HfFileMetadata`] object containing metadata such as location, etag, size and
commit_hash.
"""
headers = build_hf_headers(token=token)
headers = build_hf_headers(
token=token, library_name=library_name, library_version=library_version, user_agent=user_agent
)
headers["Accept-Encoding"] = "identity" # prevent any compression => we want to know the real size of the file

# Retrieve metadata
Expand Down
48 changes: 44 additions & 4 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
)
from .constants import (
DEFAULT_ETAG_TIMEOUT,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_REVISION,
DISCUSSION_STATUS,
DISCUSSION_TYPES,
Expand All @@ -106,10 +107,7 @@
DiscussionStatusFilter,
DiscussionTypeFilter,
)
from .file_download import (
get_hf_file_metadata,
hf_hub_url,
)
from .file_download import HfFileMetadata, get_hf_file_metadata, hf_hub_url
from .repocard_data import DatasetCardData, ModelCardData, SpaceCardData
from .utils import ( # noqa: F401 # imported for backward compatibility
BadRequestError,
Expand Down Expand Up @@ -4614,6 +4612,48 @@ def delete_folder(
parent_commit=parent_commit,
)

@validate_hf_hub_args
def get_hf_file_metadata(
self,
*,
url: str,
token: Union[bool, str, None] = None,
proxies: Optional[Dict] = None,
timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT,
) -> HfFileMetadata:
"""Fetch metadata of a file versioned on the Hub for a given url.
Args:
url (`str`):
File url, for example returned by [`hf_hub_url`].
token (`str` or `bool`, *optional*):
A token to be used for the download.
- If `True`, the token is read from the HuggingFace config
folder.
- If `False` or `None`, no token is provided.
- If a string, it's used as the authentication token.
proxies (`dict`, *optional*):
Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
timeout (`float`, *optional*, defaults to 10):
How many seconds to wait for the server to send metadata before giving up.
Returns:
A [`HfFileMetadata`] object containing metadata such as location, etag, size and commit_hash.
"""
if token is None:
# Cannot do `token = token or self.token` as token can be `False`.
token = self.token

return get_hf_file_metadata(
url=url,
token=token,
proxies=proxies,
timeout=timeout,
library_name=self.library_name,
library_version=self.library_version,
user_agent=self.user_agent,
)

@validate_hf_hub_args
def hf_hub_download(
self,
Expand Down
45 changes: 45 additions & 0 deletions tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_create_symlink,
_get_pointer_path,
_normalize_etag,
_request_wrapper,
_to_local_dir,
cached_download,
filename_to_url,
Expand Down Expand Up @@ -388,6 +389,50 @@ def test_hf_hub_download_offline_no_refs(self):
cache_dir=cache_dir,
)

def test_hf_hub_download_with_user_agent(self):
"""
Check that user agent is correctly sent to the HEAD call when downloading a file.
Regression test for #1854.
See https://github.com/huggingface/huggingface_hub/pull/1854.
"""

def _check_user_agent(headers: dict):
assert "user-agent" in headers
assert "test/1.0.0" in headers["user-agent"]
assert "foo/bar" in headers["user-agent"]

with SoftTemporaryDirectory() as cache_dir:
with patch("huggingface_hub.file_download._request_wrapper", wraps=_request_wrapper) as mock_request:
# First download
hf_hub_download(
DUMMY_MODEL_ID,
filename=CONFIG_NAME,
cache_dir=cache_dir,
library_name="test",
library_version="1.0.0",
user_agent="foo/bar",
)
calls = mock_request.call_args_list
assert len(calls) == 3 # HEAD, HEAD, GET
for call in calls:
_check_user_agent(call.kwargs["headers"])

with patch("huggingface_hub.file_download._request_wrapper", wraps=_request_wrapper) as mock_request:
# Second download: no GET call
hf_hub_download(
DUMMY_MODEL_ID,
filename=CONFIG_NAME,
cache_dir=cache_dir,
library_name="test",
library_version="1.0.0",
user_agent="foo/bar",
)
calls = mock_request.call_args_list
assert len(calls) == 2 # HEAD, HEAD
for call in calls:
_check_user_agent(call.kwargs["headers"])

def test_hf_hub_url_with_empty_subfolder(self):
"""
Check subfolder arg is processed correctly when empty string is passed to
Expand Down

0 comments on commit 5d69111

Please sign in to comment.