Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Cache non-existence of files or completeness of repo #986

Merged
merged 11 commits (source and target branch names omitted in this capture)
Aug 16, 2022
44 changes: 37 additions & 7 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
)
from .hf_api import HfFolder
from .utils import logging, tqdm
from .utils._errors import LocalEntryNotFoundError, _raise_for_status
from .utils._errors import (
EntryNotFoundError,
LocalEntryNotFoundError,
_raise_for_status,
)


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -844,6 +848,20 @@ def _create_relative_symlink(src: str, dst: str) -> None:
raise


def _cache_commit_hash_for_specific_revision(
storage_folder: str, revision: str, commit_hash: str
) -> None:
"""Cache reference between a revision (tag, branch or truncated commit hash) and the corresponding commit hash.

Does nothing if `revision` is already a proper `commit_hash` or reference is already cached.
"""
if revision != commit_hash:
ref_path = os.path.join(storage_folder, "refs", revision)
os.makedirs(os.path.dirname(ref_path), exist_ok=True)
with open(ref_path, "w") as f:
f.write(commit_hash)


def repo_folder_name(*, repo_id: str, repo_type: str) -> str:
"""Return a serialized version of a hf.co repo name and type, safe for disk storage
as a single non-nested folder.
Expand Down Expand Up @@ -1077,7 +1095,23 @@ def hf_hub_download(
proxies=proxies,
timeout=etag_timeout,
)
_raise_for_status(r)
try:
_raise_for_status(r)
except EntryNotFoundError:
commit_hash = r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
if commit_hash is not None and not legacy_cache_layout:
no_exist_file_path = (
Path(storage_folder)
/ ".no_exist"
/ commit_hash
/ relative_filename
)
no_exist_file_path.parent.mkdir(parents=True, exist_ok=True)
no_exist_file_path.touch()
_cache_commit_hash_for_specific_revision(
storage_folder, revision, commit_hash
)
raise
commit_hash = r.headers[HUGGINGFACE_HEADER_X_REPO_COMMIT]
if commit_hash is None:
raise OSError(
Expand Down Expand Up @@ -1173,11 +1207,7 @@ def hf_hub_download(
# if passed revision is not identical to commit_hash
# then revision has to be a branch name or tag name.
# In that case store a ref.
if revision != commit_hash:
ref_path = os.path.join(storage_folder, "refs", revision)
os.makedirs(os.path.dirname(ref_path), exist_ok=True)
with open(ref_path, "w") as f:
f.write(commit_hash)
_cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash)

if os.path.exists(pointer_path) and not force_download:
return pointer_path
Expand Down
148 changes: 83 additions & 65 deletions tests/test_cache_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
upload_file,
)
from huggingface_hub.utils import logging
from huggingface_hub.utils._errors import EntryNotFoundError

from .testing_constants import ENDPOINT_STAGING, TOKEN, USER
from .testing_utils import repo_name, with_production_testing
Expand All @@ -32,92 +33,109 @@ def get_file_contents(path):
@with_production_testing
class CacheFileLayoutHfHubDownload(unittest.TestCase):
def test_file_downloaded_in_cache(self):
with tempfile.TemporaryDirectory() as cache:
hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache)

expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}'
expected_path = os.path.join(cache, expected_directory_name)

refs = os.listdir(os.path.join(expected_path, "refs"))
snapshots = os.listdir(os.path.join(expected_path, "snapshots"))

expected_reference = "main"

# Only reference should be `main`.
self.assertListEqual(refs, [expected_reference])

with open(os.path.join(expected_path, "refs", expected_reference)) as f:
snapshot_name = f.readline().strip()

# The `main` reference should point to the only snapshot we have downloaded
self.assertListEqual(snapshots, [snapshot_name])
for revision, expected_reference in (
(None, "main"),
("file-2", "file-2"),
):
with self.subTest(revision), tempfile.TemporaryDirectory() as cache:
with tempfile.TemporaryDirectory() as cache:
hf_hub_download(
MODEL_IDENTIFIER,
"file_0.txt",
cache_dir=cache,
revision=revision,
)

snapshot_path = os.path.join(expected_path, "snapshots", snapshot_name)
snapshot_content = os.listdir(snapshot_path)
expected_directory_name = (
f'models--{MODEL_IDENTIFIER.replace("/", "--")}'
)
expected_path = os.path.join(cache, expected_directory_name)

# Only a single file in the snapshot
self.assertEqual(len(snapshot_content), 1)
refs = os.listdir(os.path.join(expected_path, "refs"))
snapshots = os.listdir(os.path.join(expected_path, "snapshots"))

snapshot_content_path = os.path.join(snapshot_path, snapshot_content[0])
# Only reference should be the expected one.
self.assertListEqual(refs, [expected_reference])

# The snapshot content should link to a blob
self.assertTrue(os.path.islink(snapshot_content_path))
with open(
os.path.join(expected_path, "refs", expected_reference)
) as f:
snapshot_name = f.readline().strip()

resolved_blob_relative = os.readlink(snapshot_content_path)
resolved_blob_absolute = os.path.normpath(
os.path.join(snapshot_path, resolved_blob_relative)
)
# The `main` reference should point to the only snapshot we have downloaded
self.assertListEqual(snapshots, [snapshot_name])

with open(resolved_blob_absolute) as f:
blob_contents = f.read().strip()
snapshot_path = os.path.join(
expected_path, "snapshots", snapshot_name
)
snapshot_content = os.listdir(snapshot_path)

# The contents of the file should be 'File 0'.
self.assertEqual(blob_contents, "File 0")
# Only a single file in the snapshot
self.assertEqual(len(snapshot_content), 1)

def test_file_downloaded_in_cache_with_revision(self):
with tempfile.TemporaryDirectory() as cache:
hf_hub_download(
MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2"
)
snapshot_content_path = os.path.join(
snapshot_path, snapshot_content[0]
)

expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}'
expected_path = os.path.join(cache, expected_directory_name)
# The snapshot content should link to a blob
self.assertTrue(os.path.islink(snapshot_content_path))

refs = os.listdir(os.path.join(expected_path, "refs"))
snapshots = os.listdir(os.path.join(expected_path, "snapshots"))
resolved_blob_relative = os.readlink(snapshot_content_path)
resolved_blob_absolute = os.path.normpath(
os.path.join(snapshot_path, resolved_blob_relative)
)

expected_reference = "file-2"
with open(resolved_blob_absolute) as f:
blob_contents = f.readline().strip()

# The contents of the file should be 'File 0'.
self.assertEqual(blob_contents, "File 0")

def test_no_exist_file_is_cached(self):
revisions = [None, "file-2"]
expected_references = ["main", "file-2"]
for revision, expected_reference in zip(revisions, expected_references):
with self.subTest(revision), tempfile.TemporaryDirectory() as cache:
filename = "this_does_not_exist.txt"
with self.assertRaises(EntryNotFoundError):
# The file does not exist, so we get an exception.
hf_hub_download(
MODEL_IDENTIFIER, filename, cache_dir=cache, revision=revision
)

# Only reference should be `file-2`.
self.assertListEqual(refs, [expected_reference])
expected_directory_name = (
f'models--{MODEL_IDENTIFIER.replace("/", "--")}'
)
expected_path = os.path.join(cache, expected_directory_name)

with open(os.path.join(expected_path, "refs", expected_reference)) as f:
snapshot_name = f.read().strip()
refs = os.listdir(os.path.join(expected_path, "refs"))
no_exist_snapshots = os.listdir(
os.path.join(expected_path, ".no_exist")
)

# The `main` reference should point to the only snapshot we have downloaded
self.assertListEqual(snapshots, [snapshot_name])
# Only reference should be `main`.
self.assertListEqual(refs, [expected_reference])

snapshot_path = os.path.join(expected_path, "snapshots", snapshot_name)
snapshot_content = os.listdir(snapshot_path)
with open(os.path.join(expected_path, "refs", expected_reference)) as f:
snapshot_name = f.readline().strip()

# Only a single file in the snapshot
self.assertEqual(len(snapshot_content), 1)
# The `main` reference should point to the only snapshot we have downloaded
self.assertListEqual(no_exist_snapshots, [snapshot_name])

snapshot_content_path = os.path.join(snapshot_path, snapshot_content[0])
no_exist_path = os.path.join(expected_path, ".no_exist", snapshot_name)
no_exist_content = os.listdir(no_exist_path)

# The snapshot content should link to a blob
self.assertTrue(os.path.islink(snapshot_content_path))
# Only a single file in the no_exist snapshot
self.assertEqual(len(no_exist_content), 1)

resolved_blob_relative = os.readlink(snapshot_content_path)
resolved_blob_absolute = os.path.normpath(
os.path.join(snapshot_path, resolved_blob_relative)
)
# The no_exist content should be our file
self.assertEqual(no_exist_content[0], filename)

with open(resolved_blob_absolute) as f:
blob_contents = f.readline().strip()
with open(os.path.join(no_exist_path, filename)) as f:
content = f.read().strip()

# The contents of the file should be 'File 0'.
self.assertEqual(blob_contents, "File 0")
# The contents of the file should be empty.
self.assertEqual(content, "")

def test_file_download_happens_once(self):
# Tests that a file is only downloaded once if it's not updated.
Expand Down