From c36eb68d93fb880715e97b2ecfdcbbea3381497d Mon Sep 17 00:00:00 2001 From: Shahaf Golan <31520128+Shahafgo@users.noreply.github.com> Date: Fri, 13 Oct 2023 10:36:03 +0300 Subject: [PATCH] Fix a bug to respect the HF_HUB_ETAG_TIMEOUT. (#1728) * Fix a bug to respect the HF_HUB_ETAG_TIMEOUT. * style * add test --------- Co-authored-by: Lucain --- src/huggingface_hub/file_download.py | 9 +++-- tests/test_file_download.py | 59 ++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index d7f2bf0084..7f540eea17 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -33,6 +33,7 @@ HF_HUB_DISABLE_SYMLINKS_WARNING, HF_HUB_DOWNLOAD_TIMEOUT, HF_HUB_ENABLE_HF_TRANSFER, + HF_HUB_ETAG_TIMEOUT, HUGGINGFACE_CO_URL_TEMPLATE, HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_LINKED_SIZE, @@ -664,9 +665,9 @@ def cached_download( """ - if HF_HUB_DOWNLOAD_TIMEOUT != DEFAULT_DOWNLOAD_TIMEOUT: + if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT: # Respect environment variable above user value - etag_timeout = HF_HUB_DOWNLOAD_TIMEOUT + etag_timeout = HF_HUB_ETAG_TIMEOUT if not legacy_cache_layout: warnings.warn( @@ -1150,9 +1151,9 @@ def hf_hub_download( """ - if HF_HUB_DOWNLOAD_TIMEOUT != DEFAULT_DOWNLOAD_TIMEOUT: + if HF_HUB_ETAG_TIMEOUT != DEFAULT_ETAG_TIMEOUT: # Respect environment variable above user value - etag_timeout = HF_HUB_DOWNLOAD_TIMEOUT + etag_timeout = HF_HUB_ETAG_TIMEOUT if force_filename is not None: warnings.warn( diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 8c90e07e87..63de900c89 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -1025,6 +1025,65 @@ def _get_etag_and_normalize(response: Response) -> str: return _normalize_etag(response.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or response.headers.get("ETag")) +@with_production_testing +class TestEtagTimeoutConfig(unittest.TestCase): + @patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10) + @patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 10) + def test_etag_timeout_default_value(self): + with SoftTemporaryDirectory() as cache_dir: + with patch.object( + huggingface_hub.file_download, + "get_hf_file_metadata", + wraps=huggingface_hub.file_download.get_hf_file_metadata, + ) as mock_etag_call: + hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir) + kwargs = mock_etag_call.call_args.kwargs + self.assertIn("timeout", kwargs) + self.assertEqual(kwargs["timeout"], 10) + + @patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10) + @patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 10) + def test_etag_timeout_parameter_value(self): + with SoftTemporaryDirectory() as cache_dir: + with patch.object( + huggingface_hub.file_download, + "get_hf_file_metadata", + wraps=huggingface_hub.file_download.get_hf_file_metadata, + ) as mock_etag_call: + hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12) + kwargs = mock_etag_call.call_args.kwargs + self.assertIn("timeout", kwargs) + self.assertEqual(kwargs["timeout"], 12) # passed as parameter, takes priority + + @patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10) + @patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 15) # takes priority + def test_etag_timeout_set_as_env_variable(self): + with SoftTemporaryDirectory() as cache_dir: + with patch.object( + huggingface_hub.file_download, + "get_hf_file_metadata", + wraps=huggingface_hub.file_download.get_hf_file_metadata, + ) as mock_etag_call: + hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir) + kwargs = mock_etag_call.call_args.kwargs + self.assertIn("timeout", kwargs) + self.assertEqual(kwargs["timeout"], 15) + + @patch("huggingface_hub.file_download.DEFAULT_ETAG_TIMEOUT", 10) + @patch("huggingface_hub.file_download.HF_HUB_ETAG_TIMEOUT", 15) # takes priority + def test_etag_timeout_set_as_env_variable_parameter_ignored(self): + with SoftTemporaryDirectory() as cache_dir: + with patch.object( + huggingface_hub.file_download, + "get_hf_file_metadata", + wraps=huggingface_hub.file_download.get_hf_file_metadata, + ) as mock_etag_call: + hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, cache_dir=cache_dir, etag_timeout=12) + kwargs = mock_etag_call.call_args.kwargs + self.assertIn("timeout", kwargs) + self.assertEqual(kwargs["timeout"], 15) # passed value ignored, HF_HUB_ETAG_TIMEOUT takes priority + + def _recursive_chmod(path: str, mode: int) -> None: # Taken from https://stackoverflow.com/a/2853934 for root, dirs, files in os.walk(path):