Skip to content

Commit

Permalink
Fix a bug to respect the HF_HUB_ETAG_TIMEOUT. (#1728)
Browse files Browse the repository at this point in the history
* Fix a bug to respect the HF_HUB_ETAG_TIMEOUT.

* style

* add test

---------

Co-authored-by: Lucain <lucainp@gmail.com>
  • Loading branch information
Shahafgo and Wauplin authored Oct 13, 2023
1 parent 30f5ed4 commit c36eb68
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
HF_HUB_DISABLE_SYMLINKS_WARNING,
HF_HUB_DOWNLOAD_TIMEOUT,
HF_HUB_ENABLE_HF_TRANSFER,
HF_HUB_ETAG_TIMEOUT,
HUGGINGFACE_CO_URL_TEMPLATE,
HUGGINGFACE_HEADER_X_LINKED_ETAG,
HUGGINGFACE_HEADER_X_LINKED_SIZE,
Expand Down Expand Up @@ -664,9 +665,9 @@ def cached_download(
</Tip>
"""
if HF_HUB_DOWNLOAD_TIMEOUT != DEFAULT_DOWNLOAD_TIMEOUT:
if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT:
# Respect environment variable above user value
etag_timeout = HF_HUB_DOWNLOAD_TIMEOUT
etag_timeout = HF_HUB_ETAG_TIMEOUT

if not legacy_cache_layout:
warnings.warn(
Expand Down Expand Up @@ -1150,9 +1151,9 @@ def hf_hub_download(
</Tip>
"""
if HF_HUB_DOWNLOAD_TIMEOUT != DEFAULT_DOWNLOAD_TIMEOUT:
if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT:
# Respect environment variable above user value
etag_timeout = HF_HUB_DOWNLOAD_TIMEOUT
etag_timeout = HF_HUB_ETAG_TIMEOUT

if force_filename is not None:
warnings.warn(
Expand Down
59 changes: 59 additions & 0 deletions tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,65 @@ def _get_etag_and_normalize(response: Response) -> str:
return _normalize_etag(response.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or response.headers.get("ETag"))


@with_production_testing
class TestEtagTimeoutConfig(unittest.TestCase):
@patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10)
@patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 10)
def test_etag_timeout_default_value(self):
with SoftTemporaryDirectory() as cache_dir:
with patch.object(
huggingface_hub.file_download,
"get_hf_file_metadata",
wraps=huggingface_hub.file_download.get_hf_file_metadata,
) as mock_etag_call:
hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir)
kwargs = mock_etag_call.call_args.kwargs
self.assertIn("timeout", kwargs)
self.assertEqual(kwargs["timeout"], 10)

@patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10)
@patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 10)
def test_etag_timeout_parameter_value(self):
with SoftTemporaryDirectory() as cache_dir:
with patch.object(
huggingface_hub.file_download,
"get_hf_file_metadata",
wraps=huggingface_hub.file_download.get_hf_file_metadata,
) as mock_etag_call:
hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12)
kwargs = mock_etag_call.call_args.kwargs
self.assertIn("timeout", kwargs)
self.assertEqual(kwargs["timeout"], 12) # passed as parameter, takes priority

@patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10)
@patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 15) # takes priority
def test_etag_timeout_set_as_env_variable(self):
with SoftTemporaryDirectory() as cache_dir:
with patch.object(
huggingface_hub.file_download,
"get_hf_file_metadata",
wraps=huggingface_hub.file_download.get_hf_file_metadata,
) as mock_etag_call:
hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir)
kwargs = mock_etag_call.call_args.kwargs
self.assertIn("timeout", kwargs)
self.assertEqual(kwargs["timeout"], 15)

@patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10)
@patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 15) # takes priority
def test_etag_timeout_set_as_env_variable_parameter_ignored(self):
with SoftTemporaryDirectory() as cache_dir:
with patch.object(
huggingface_hub.file_download,
"get_hf_file_metadata",
wraps=huggingface_hub.file_download.get_hf_file_metadata,
) as mock_etag_call:
hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12)
kwargs = mock_etag_call.call_args.kwargs
self.assertIn("timeout", kwargs)
self.assertEqual(kwargs["timeout"], 15) # passed value ignored, HF_HUB_ETAG_TIMEOUT takes priority


def _recursive_chmod(path: str, mode: int) -> None:
# Taken from https://stackoverflow.com/a/2853934
for root, dirs, files in os.walk(path):
Expand Down

0 comments on commit c36eb68

Please sign in to comment.