From ba9ba24c43be5624e5e37894ac18ad6cac3c9984 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 7 Sep 2022 13:34:11 -0400 Subject: [PATCH] try_to_load_from_cache returns cached non-existence --- src/huggingface_hub/file_download.py | 24 +++++++++++++++++++++++- tests/test_file_download.py | 17 +++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index f3fd28eaa7..92b1fa831f 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -169,6 +169,8 @@ def get_jinja_version(): return _jinja_version +# Return value when trying to load a file from cache but the file does not exist in the distant repo. +_CACHED_NO_EXIST = object() REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$") @@ -1330,8 +1332,25 @@ def try_to_load_from_cache( This function will not raise any exception if the file in not cached. + Args: + cache_dir (`str` or `os.PathLike`): + The folder where the cached files lie. + repo_id (`str`): + The ID of the repo on huggingface.co. + filename (`str`): + The filename to look for inside `repo_id`. + revision (`str`, *optional*): + The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is + provided either. + repo_type (`str`, *optional*): + The type of the repository. Will default to `"model"`. + Returns: - Local path (string) of file or `None` if no cached file is found. + `Optional[str]` or `_CACHED_NO_EXIST`: + Will return `None` if the file was not cached. Otherwise: + - The exact path to the cached file if it's found in the cache + - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was + cached. 
""" if revision is None: revision = "main" @@ -1360,6 +1379,9 @@ def try_to_load_from_cache( with open(os.path.join(repo_cache, "refs", revision)) as f: revision = f.read() + if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)): + return _CACHED_NO_EXIST + cached_shas = os.listdir(os.path.join(repo_cache, "snapshots")) if revision not in cached_shas: # No cache for this revision and we won't try to return a random revision diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 50ce54d506..13ae0e9df8 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -25,6 +25,7 @@ REPO_TYPE_DATASET, ) from huggingface_hub.file_download import ( + _CACHED_NO_EXIST, cached_download, filename_to_url, hf_hub_download, @@ -320,3 +321,19 @@ def test_try_to_load_from_cache(self): ) # Same for uncached models self.assertIsNone(try_to_load_from_cache("bert-base", filename=CONFIG_NAME)) + + def test_try_to_load_from_cache_no_exist(self): + # Make sure the non-existence of the file is cached + with self.assertRaises(EntryNotFoundError): + _ = hf_hub_download(DUMMY_MODEL_ID, filename="dummy") + + new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy") + self.assertEqual(new_file_path, _CACHED_NO_EXIST) + + new_file_path = try_to_load_from_cache( + DUMMY_MODEL_ID, filename="dummy", revision="main" + ) + self.assertEqual(new_file_path, _CACHED_NO_EXIST) + + # If file non-existence is not cached, returns None + self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy2"))