Skip to content

Commit

Permalink
try_to_load_from_cache returns cached non-existence (#1039)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger authored Sep 9, 2022
1 parent 59873e7 commit 0c7b7ef
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
24 changes: 23 additions & 1 deletion src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ def get_jinja_version():
return _jinja_version


# Return value when trying to load a file from cache but the file does not exist in the remote repo.
_CACHED_NO_EXIST = object()
REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")


Expand Down Expand Up @@ -1318,8 +1320,25 @@ def try_to_load_from_cache(
This function will not raise any exception if the file is not cached.
Args:
cache_dir (`str` or `os.PathLike`):
The folder where the cached files lie.
repo_id (`str`):
The ID of the repo on huggingface.co.
filename (`str`):
The filename to look for inside `repo_id`.
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
repo_type (`str`, *optional*):
The type of the repository. Will default to `"model"`.
Returns:
Local path (string) of file or `None` if no cached file is found.
`Optional[str]` or `_CACHED_NO_EXIST`:
Will return `None` if the file was not cached. Otherwise:
- The exact path to the cached file if it's found in the cache
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
cached.
"""
if revision is None:
revision = "main"
Expand Down Expand Up @@ -1348,6 +1367,9 @@ def try_to_load_from_cache(
with open(os.path.join(repo_cache, "refs", revision)) as f:
revision = f.read()

if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
return _CACHED_NO_EXIST

cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
Expand Down
17 changes: 17 additions & 0 deletions tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
REPO_TYPE_DATASET,
)
from huggingface_hub.file_download import (
_CACHED_NO_EXIST,
cached_download,
filename_to_url,
hf_hub_download,
Expand Down Expand Up @@ -320,3 +321,19 @@ def test_try_to_load_from_cache(self):
)
# Same for uncached models
self.assertIsNone(try_to_load_from_cache("bert-base", filename=CONFIG_NAME))

def test_try_to_load_from_cache_no_exist(self):
    """Check that a cached "file does not exist" result is reported via the sentinel.

    After a download attempt for a missing file fails with `EntryNotFoundError`,
    `try_to_load_from_cache` is expected to return `_CACHED_NO_EXIST` for that
    filename, and `None` for a filename that was never requested.
    """
    # Trigger a failed download so that the non-existence of the file gets cached
    with self.assertRaises(EntryNotFoundError):
        _ = hf_hub_download(DUMMY_MODEL_ID, filename="dummy")

    # `_CACHED_NO_EXIST` is a sentinel object: compare by identity, not equality
    new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy")
    self.assertIs(new_file_path, _CACHED_NO_EXIST)

    # Same result when the revision is passed explicitly
    new_file_path = try_to_load_from_cache(
        DUMMY_MODEL_ID, filename="dummy", revision="main"
    )
    self.assertIs(new_file_path, _CACHED_NO_EXIST)

    # If file non-existence is not cached, returns None
    self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy2"))

0 comments on commit 0c7b7ef

Please sign in to comment.