From eb05af72bf2773452933be84d9c12277238e23ed Mon Sep 17 00:00:00 2001 From: worldlight425 Date: Wed, 25 May 2022 16:52:04 +0200 Subject: [PATCH] New git-aware cache file layout (#801) * light typing (cherry picked from commit b2c8f9b970c505cdf2c685e645e9e36cc472b0d3) * remove this seminal comment (cherry picked from commit 12a841a605c94733154f3b22e812c0f5e69ef37b) * I don't understand why we don't early return here cc @patrickvonplaten care to take a look? cc @LysandreJik (cherry picked from commit 259ab36f03ab3eed6eeb4fc4984bc259619b442f) * following last commit, unnest this (cherry picked from commit 54957f3f049d887af21dd8f6950873a2823c4247) * [BIG] This should work for all repo_types not just models! (cherry picked from commit 9a3f96ccb2de6663cf4cf2d9a60dd7f415227c1b) * one more (cherry picked from commit b74871250616c44a2125b26d5de29b1189e82e12) * forgot a repo_type and reorder code (cherry picked from commit 3ef7d79a44087e971e10e35d3b9f5bea3474f297) * also rename this cache folder (cherry picked from commit 4c518b861723a6d28d59108403c37edf5208f2fe) * Use `hf_hub_download`, will be simpler later (cherry picked from commit c7478d58fe62da02625b8ca17796ad1419a048b1) * in this new version, `force_filename` does not make sense anymore (cherry picked from commit 9a674bc795d5c8a26aecf5429d391fff92e47e8d) * Just inline everything inside `hf_hub_download` for now (cherry picked from commit ee49f8f57ba4e7e66f237df8f64c804862fe3ee8) * Big prototype! it works! 
:tada: (cherry picked from commit 7fe19ec66a2c5a7386a956cb9b65616cb209608a) * wip wip * do not touch `cached_download` * Prompt user to upgrade to `hf_hub_download` * Add a `legacy_cache_layout=True` to preserve old behavior, just in case * Create `relative symlinks` + add some doc * Fix behavior when no network * This test now is legacy * Fix-ish conflict-ish * minimize diff * refactor `repo_folder_name` * windows support + shortcut if user passes a commit hash * Rewrite `snapshot_download` and make it more robust * OOops * Create example-transformers-tf.py * Fix + add a way more complete example (running on Ubuntu) * Apply suggestions from code review Co-authored-by: Lysandre Debut Co-authored-by: Patrick von Platen * Update src/huggingface_hub/file_download.py Co-authored-by: Lysandre Debut * Update src/huggingface_hub/file_download.py Co-authored-by: Lysandre Debut * Only allow full revision hashes otherwise the `revision != commit_hash` test is not reliable * add a little bit more doc + consistency * Update src/huggingface_hub/snapshot_download.py Co-authored-by: Patrick von Platen * Update snapshot download * First pass on tests * Wrap up tests * :wolf: Fix for bug reported by @thomwolf see https://github.com/huggingface/huggingface_hub/pull/801#issuecomment-1134576435 * Special case for Windows * Address comments and docs * Clean up with ternary cc @julien-c * Add argument to `cached_download` * Opt-in for filename_to-url * Opt-in for filename_to-url * Pass the flag * Update docs/source/package_reference/file_download.mdx Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/huggingface_hub/file_download.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Address review comments Co-authored-by: Lysandre Debut Co-authored-by: Patrick von Platen Co-authored-by: Lysandre Debut Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- .../package_reference/file_download.mdx | 106 
+++++ setup.py | 3 +- src/huggingface_hub/_snapshot_download.py | 215 ++++----- src/huggingface_hub/constants.py | 7 + src/huggingface_hub/file_download.py | 416 +++++++++++++++++- src/huggingface_hub/hf_api.py | 4 +- tests/test_cache_layout.py | 390 ++++++++++++++++ tests/test_file_download.py | 45 +- tests/test_hf_api.py | 16 +- tests/test_snapshot_download.py | 8 - 10 files changed, 1037 insertions(+), 173 deletions(-) create mode 100644 tests/test_cache_layout.py diff --git a/docs/source/package_reference/file_download.mdx b/docs/source/package_reference/file_download.mdx index 79d1063..1a05dfd 100644 --- a/docs/source/package_reference/file_download.mdx +++ b/docs/source/package_reference/file_download.mdx @@ -8,3 +8,109 @@ [[autodoc]] huggingface_hub.hf_hub_url +## Caching + +The methods displayed above are designed to work with a caching system that prevents re-downloading files. +The caching system was updated in v0.8.0 to allow directory structure and file sharing across +libraries that depend on the hub. + +The caching system is designed as follows: + +``` +<CACHE_DIR> +├─ <MODELS> +├─ <DATASETS> +├─ <SPACES> +``` + +The `<CACHE_DIR>` is usually your user's home directory. However, it is customizable with the +`cache_dir` argument on all methods, or by specifying the `HF_HOME` environment variable. + +Models, datasets and spaces share a common root. Each of these repositories contains the namespace +(organization, username) if it exists, alongside the repository name: + +``` +<CACHE_DIR> +├─ models--julien-c--EsperBERTo-small +├─ models--lysandrejik--arxiv-nlp +├─ models--bert-base-cased +├─ datasets--glue +├─ datasets--huggingface--DataMeasurementsFiles +├─ spaces--dalle-mini--dalle-mini +``` + +It is within these folders that all files will now be downloaded from the hub.
Caching ensures that +a file isn't downloaded twice if it already exists and wasn't updated; but if it was updated, +and you're asking for the latest file, then it will download the latest file (while keeping +the previous file intact in case you need it again). + +In order to achieve this, all folders contain the same skeleton: + +``` +<CACHE_DIR> +├─ datasets--glue +│ ├─ refs +│ ├─ blobs +│ ├─ snapshots +... +``` + +Each folder is designed to contain the following: + +### Refs + +The `refs` folder contains files which indicate the latest revision of the given reference. For example, +if we have previously fetched a file from the `main` branch of a repository, the `refs` +folder will contain a file named `main`, which will itself contain the commit identifier of the current head. + +If the latest commit of `main` has `aaaaaa` as identifier, then it will contain `aaaaaa`. + +If that same branch gets updated with a new commit, that has `bbbbbb` as an identifier, then +redownloading a file from that reference will update the `refs/main` file to contain `bbbbbb`. + +### Blobs + +The `blobs` folder contains the actual files that we have downloaded. The name of each file is its hash. + +### Snapshots + +The `snapshots` folder contains symlinks to the blobs mentioned above. It is itself made up of several folders: +one per known revision! + +In the explanation above, we had initially fetched a file from the `aaaaaa` revision, before fetching a file from +the `bbbbbb` revision. In this situation, we would now have two folders in the `snapshots` folder: `aaaaaa` +and `bbbbbb`. + +In each of these folders, live symlinks that have the names of the files that we have downloaded. For example, +if we had downloaded the `README.md` file at revision `aaaaaa`, we would have the following path: + +``` +<CACHE_DIR>/<REPO_NAME>/snapshots/aaaaaa/README.md +``` + +That `README.md` file is actually a symlink linking to the blob that has the hash of the file.
+ +Creating the skeleton this way opens up the mechanism for file sharing: if the same file was fetched in +revision `bbbbbb`, it would have the same hash and the file would not need to be redownloaded. + +### In practice + +In practice, it should look like the following tree in your cache: + +``` + [ 96] . + └── [ 160] models--julien-c--EsperBERTo-small + ├── [ 160] blobs + │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e + │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 + ├── [ 96] refs + │ └── [ 40] main + └── [ 128] snapshots + ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 + ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e + └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd +``` \ No newline at end of file diff --git a/setup.py b/setup.py index 748f39d..0ba23a6 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,8 @@ def get_version() -> str: author="Hugging Face, Inc.", author_email="julien@huggingface.co", description=( - "Client library to download and publish models on the huggingface.co hub" + "Client library to download and publish models, datasets and other repos on the" + " huggingface.co hub" ), long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 5cd0a5f..df65e4f 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -1,22 +1,42 @@ import os from fnmatch import fnmatch -from glob import glob from pathlib import Path
from typing import Dict, List, Optional, Union -from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE -from .file_download import cached_download, hf_hub_url +from .constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE, REPO_TYPES +from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name from .hf_api import HfApi, HfFolder from .utils import logging from .utils._deprecation import _deprecate_positional_args -REPO_ID_SEPARATOR = "--" -# ^ this substring is not allowed in repo_ids on hf.co -# and is the canonical one we use for serialization of repo ids elsewhere. +logger = logging.get_logger(__name__) -logger = logging.get_logger(__name__) +def _filter_repo_files( + *, + repo_files: List[str], + allow_regex: Optional[Union[List[str], str]] = None, + ignore_regex: Optional[Union[List[str], str]] = None, +) -> List[str]: + allow_regex = [allow_regex] if isinstance(allow_regex, str) else allow_regex + ignore_regex = [ignore_regex] if isinstance(ignore_regex, str) else ignore_regex + filtered_files = [] + for repo_file in repo_files: + # if there's an allowlist, skip download if file does not match any regex + if allow_regex is not None and not any( + fnmatch(repo_file, r) for r in allow_regex + ): + continue + + # if there's a denylist, skip download if file does matches any regex + if ignore_regex is not None and any( + fnmatch(repo_file, r) for r in ignore_regex + ): + continue + + filtered_files.append(repo_file) + return filtered_files @_deprecate_positional_args @@ -24,6 +44,7 @@ def snapshot_download( repo_id: str, *, revision: Optional[str] = None, + repo_type: Optional[str] = None, cache_dir: Union[str, Path, None] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, @@ -52,6 +73,9 @@ def snapshot_download( revision (`str`, *optional*): An optional Git revision id which can be a branch name, a tag, or a commit hash. 
+ repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. cache_dir (`str`, `Path`, *optional*): Path to the folder where cached files are stored. library_name (`str`, *optional*): @@ -97,9 +121,6 @@ def snapshot_download( """ - # Note: at some point maybe this format of storage should actually replace - # the flat storage structure we've used so far (initially from allennlp - # if I remember correctly). if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE @@ -120,122 +141,76 @@ def snapshot_download( else: token = None - # remove all `/` occurrences to correctly convert repo to directory name - repo_id_flattened = repo_id.replace("/", REPO_ID_SEPARATOR) - - # if we have no internet connection we will look for the - # last modified folder in the cache - if local_files_only: - # possible repos have / prefix - repo_folders_prefix = os.path.join(cache_dir, repo_id_flattened) - - # list all possible folders that can correspond to the repo_id - # and are of the format .. - # now let's list all cached repos that have to be included in the revision. - # There are 3 cases that we have to consider. - - # 1) cached repos of format .{revision}. - # -> in this case {revision} has to be a branch - repo_folders_branch = glob(repo_folders_prefix + "." + revision + ".*") - - # 2) cached repos of format .{revision} - # -> in this case {revision} has to be a commit sha - repo_folders_commit_only = glob(repo_folders_prefix + "." + revision) - - # 3) cached repos of format ..{revision} - # -> in this case {revision} also has to be a commit sha - repo_folders_branch_commit = glob(repo_folders_prefix + ".*."
+ revision) - - # combine all possible fetched cached repos - repo_folders = ( - repo_folders_branch + repo_folders_commit_only + repo_folders_branch_commit + if repo_type is None: + repo_type = "model" + if repo_type not in REPO_TYPES: + raise ValueError( + f"Invalid repo type: {repo_type}. Accepted repo types are:" + f" {str(REPO_TYPES)}" ) - if len(repo_folders) == 0: - raise ValueError( - "Cannot find the requested files in the cached path and outgoing" - " traffic has been disabled. To enable model look-ups and downloads" - " online, set 'local_files_only' to False." - ) + storage_folder = os.path.join( + cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) + ) - # check if repo id was previously cached from a commit sha revision - # and passed {revision} is not a commit sha - # in this case snapshotting repos locally might lead to unexpected - # behavior the user should be warned about - - # get all folders that were cached with just a sha commit revision - all_repo_folders_from_sha = set(glob(repo_folders_prefix + ".*")) - set( - glob(repo_folders_prefix + ".*.*") - ) - # 1) is there any repo id that was previously cached from a commit sha? - has_a_sha_revision_been_cached = len(all_repo_folders_from_sha) > 0 - # 2) is the passed {revision} is a branch - is_revision_a_branch = ( - len(repo_folders_commit_only + repo_folders_branch_commit) == 0 - ) - - if has_a_sha_revision_been_cached and is_revision_a_branch: - # -> in this case let's warn the user - logger.warn( - f"The repo {repo_id} was previously downloaded from a commit hash" - " revision and has created the following cached directories" - f" {all_repo_folders_from_sha}. In this case, trying to load a repo" - f" from the branch {revision} in offline mode might lead to unexpected" - " behavior by not taking into account the latest commits." 
- ) - - # find last modified folder - storage_folder = max(repo_folders, key=os.path.getmtime) - - # get commit sha - repo_id_sha = storage_folder.split(".")[-1] - model_files = os.listdir(storage_folder) - else: - # if we have internet connection we retrieve the correct folder name from the huggingface api - _api = HfApi() - model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token) - - storage_folder = os.path.join(cache_dir, repo_id_flattened + "." + revision) - - # if passed revision is not identical to the commit sha - # then revision has to be a branch name, e.g. "main" - # in this case make sure that the branch name is included - # cached storage folder name - if revision != model_info.sha: - storage_folder += f".{model_info.sha}" - - repo_id_sha = model_info.sha - model_files = [f.rfilename for f in model_info.siblings] - - allow_regex = [allow_regex] if isinstance(allow_regex, str) else allow_regex - ignore_regex = [ignore_regex] if isinstance(ignore_regex, str) else ignore_regex + # if we have no internet connection we will look for an + # appropriate folder in the cache + # If the specified revision is a commit hash, look inside "snapshots". + # If the specified revision is a branch or tag, look inside "refs". 
+ if local_files_only: - for model_file in model_files: - # if there's an allowlist, skip download if file does not match any regex - if allow_regex is not None and not any( - fnmatch(model_file, r) for r in allow_regex - ): - continue + if REGEX_COMMIT_HASH.match(revision): + commit_hash = revision + else: + # retrieve commit_hash from file + ref_path = os.path.join(storage_folder, "refs", revision) + with open(ref_path) as f: + commit_hash = f.read() - # if there's a denylist, skip download if file does matches any regex - if ignore_regex is not None and any( - fnmatch(model_file, r) for r in ignore_regex - ): - continue + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) - url = hf_hub_url(repo_id, filename=model_file, revision=repo_id_sha) - relative_filepath = os.path.join(*model_file.split("/")) + if os.path.exists(snapshot_folder): + return snapshot_folder - # Create potential nested dir - nested_dirname = os.path.dirname( - os.path.join(storage_folder, relative_filepath) + raise ValueError( + "Cannot find an appropriate cached snapshot folder for the specified" + " revision on the local disk and outgoing traffic has been disabled. To" + " enable repo look-ups and downloads online, set 'local_files_only' to" + " False." 
) - os.makedirs(nested_dirname, exist_ok=True) - path = cached_download( - url, - cache_dir=storage_folder, - force_filename=relative_filepath, + # if we have internet connection we retrieve the correct folder name from the huggingface api + _api = HfApi() + repo_info = _api.repo_info( + repo_id=repo_id, repo_type=repo_type, revision=revision, token=token + ) + filtered_repo_files = _filter_repo_files( + repo_files=[f.rfilename for f in repo_info.siblings], + allow_regex=allow_regex, + ignore_regex=ignore_regex, + ) + commit_hash = repo_info.sha + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + # if passed revision is not identical to commit_hash + # then revision has to be a branch name or tag name. + # In that case store a ref. + if revision != commit_hash: + ref_path = os.path.join(storage_folder, "refs", revision) + os.makedirs(os.path.dirname(ref_path), exist_ok=True) + with open(ref_path, "w") as f: + f.write(commit_hash) + + # we pass the commit_hash to hf_hub_download + # so no network call happens if we already + # have the file locally. 
+ + for repo_file in filtered_repo_files: + _ = hf_hub_download( + repo_id, + filename=repo_file, + repo_type=repo_type, + revision=commit_hash, + cache_dir=cache_dir, library_name=library_name, library_version=library_version, user_agent=user_agent, @@ -243,10 +218,6 @@ def snapshot_download( etag_timeout=etag_timeout, resume_download=resume_download, use_auth_token=use_auth_token, - local_files_only=local_files_only, ) - if os.path.exists(path + ".lock"): - os.remove(path + ".lock") - - return storage_folder + return snapshot_folder diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index 91f930d..728f45a 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -29,6 +29,13 @@ ) HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" +HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit" +HUGGINGFACE_HEADER_X_LINKED_ETAG = "X-Linked-Etag" + +REPO_ID_SEPARATOR = "--" +# ^ this substring is not allowed in repo_ids on hf.co +# and is the canonical one we use for serialization of repo ids elsewhere. 
+ REPO_TYPE_DATASET = "dataset" REPO_TYPE_SPACE = "space" diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 1720036..c3e602f 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -3,9 +3,11 @@ import io import json import os +import re import sys import tempfile import time +import warnings from contextlib import contextmanager from functools import partial from hashlib import sha256 @@ -23,7 +25,10 @@ from .constants import ( DEFAULT_REVISION, HUGGINGFACE_CO_URL_TEMPLATE, + HUGGINGFACE_HEADER_X_LINKED_ETAG, + HUGGINGFACE_HEADER_X_REPO_COMMIT, HUGGINGFACE_HUB_CACHE, + REPO_ID_SEPARATOR, REPO_TYPES, REPO_TYPES_URL_PREFIXES, ) @@ -142,6 +147,9 @@ def get_fastcore_version(): return _fastcore_version +REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$") + + @_deprecate_positional_args def hf_hub_url( repo_id: str, @@ -255,11 +263,31 @@ def url_to_filename(url: str, etag: Optional[str] = None) -> str: return filename -def filename_to_url(filename, cache_dir=None) -> Tuple[str, str]: +def filename_to_url( + filename, + cache_dir: Optional[str] = None, + legacy_cache_layout: Optional[bool] = False, +) -> Tuple[str, str]: """ Return the url and etag (which may be `None`) stored for `filename`. Raise `EnvironmentError` if `filename` or its stored metadata do not exist. + + Args: + filename (`str`): + The name of the file + cache_dir (`str`, *optional*): + The cache directory to use instead of the default one. + legacy_cache_layout (`bool`, *optional*, defaults to `False`): + If `True`, uses the legacy file cache layout i.e. just call `hf_hub_url` + then `cached_download`. This is deprecated as the new cache layout is + more powerful. 
""" + if not legacy_cache_layout: + warnings.warn( + "`filename_to_url` uses the legacy way cache file layout", + FutureWarning, + ) + if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE if isinstance(cache_dir, Path): @@ -455,6 +483,7 @@ def cached_download( resume_download: Optional[bool] = False, use_auth_token: Union[bool, str, None] = None, local_files_only: Optional[bool] = False, + legacy_cache_layout: Optional[bool] = False, ) -> Optional[str]: # pragma: no cover """ Download from a given URL and cache it if it's not already present in the @@ -496,6 +525,11 @@ def cached_download( local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. + legacy_cache_layout (`bool`, *optional*, defaults to `False`): + Set this parameter to `True` to mention that you'd like to continue + the old cache layout. Putting this to `True` manually will not raise + any warning when using `cached_download`. We recommend using + `hf_hub_download` to take advantage of the new cache. Returns: Local path (string) of file or if networking is off, last version of @@ -514,6 +548,13 @@ def cached_download( """ + if not legacy_cache_layout: + warnings.warn( + "`cached_download` is the legacy way to download files from the HF hub," + " please consider upgrading to `hf_hub_download`", + FutureWarning, + ) + if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE if isinstance(cache_dir, Path): @@ -689,6 +730,66 @@ def _resumable_file_manager() -> "io.BufferedWriter": return cache_path +def _normalize_etag(etag: str) -> str: + """Normalize ETag HTTP header, so it can be used to create nice filepaths. + + The HTTP spec allows two forms of ETag: + ETag: W/"" + ETag: "" + + The hf.co hub guarantees to only send the second form. + + Args: + etag (`str`): HTTP header + + Returns: + `str`: string that can be used as a nice directory name. 
+ """ + return etag.strip('"') + + +def _create_relative_symlink(src: str, dst: str) -> None: + """Create a symbolic link named dst pointing to src as a relative path to dst. + + The relative part is mostly because it seems more elegant to the author. + + The result layout looks something like + └── [ 128] snapshots + ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + """ + relative_src = os.path.relpath(src, start=os.path.dirname(dst)) + try: + os.remove(dst) + except OSError: + pass + try: + os.symlink(relative_src, dst) + except OSError: + # Likely running on Windows + if os.name == "nt": + raise OSError( + "Windows requires Developer Mode to be activated, or to run Python as " + "an administrator, in order to create symlinks.\nIn order to " + "activate Developer Mode, see this article: " + "https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development" + ) + else: + raise + + +def repo_folder_name(*, repo_id: str, repo_type: str) -> str: + """Return a serialized version of a hf.co repo name and type, safe for disk storage + as a single non-nested folder. + + Example: models--julien-c--EsperBERTo-small + """ + # remove all `/` occurrences to correctly convert repo to directory name + parts = [f"{repo_type}s", *repo_id.split("/")] + return REPO_ID_SEPARATOR.join(parts) + + @_deprecate_positional_args def hf_hub_download( repo_id: str, @@ -708,9 +809,36 @@ def hf_hub_download( resume_download: Optional[bool] = False, use_auth_token: Union[bool, str, None] = None, local_files_only: Optional[bool] = False, + legacy_cache_layout: Optional[bool] = False, ): """Download a given file if it's not already present in the local cache. 
+ The new cache file layout looks like this: + - The cache directory contains one subfolder per repo_id (namespaced by repo type) + - inside each repo folder: + - refs is a list of the latest known revision => commit_hash pairs + - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on + whether they're LFS files or not) + - snapshots contains one subfolder per commit, each "commit" contains the subset of the files + that have been resolved at that particular commit. Each filename is a symlink to the blob + at that particular commit. + + [ 96] . + └── [ 160] models--julien-c--EsperBERTo-small + ├── [ 160] blobs + │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e + │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 + ├── [ 96] refs + │ └── [ 40] main + └── [ 128] snapshots + ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 + ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e + └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + Args: repo_id (`str`): A user or an organization name and a repo name separated by a `/`. @@ -735,8 +863,6 @@ def hf_hub_download( force_download (`bool`, *optional*, defaults to `False`): Whether the file should be downloaded even if it already exists in the local cache. - force_filename (`str`, *optional*): - Use this name instead of a generated file name. proxies (`dict`, *optional*): Dictionary mapping protocol to the URL of the proxy passed to `requests.request`. 
@@ -753,6 +879,10 @@ def hf_hub_download( local_files_only (`bool`, *optional*, defaults to `False`): If `True`, avoid downloading the file and return the path to the local cached file if it exists. + legacy_cache_layout (`bool`, *optional*, defaults to `False`): + If `True`, uses the legacy file cache layout i.e. just call [`hf_hub_url`] + then `cached_download`. This is deprecated as the new cache layout is + more powerful. Returns: Local path (string) of file or if networking is off, last version of @@ -771,21 +901,271 @@ def hf_hub_download( """ - url = hf_hub_url( - repo_id, filename, subfolder=subfolder, repo_type=repo_type, revision=revision + if force_filename is not None: + warnings.warn( + "The `force_filename` parameter is deprecated as a new caching system, " + "which keeps the filenames as they are on the Hub, is now in place.", + FutureWarning, + ) + legacy_cache_layout = True + + if legacy_cache_layout: + url = hf_hub_url( + repo_id, + filename, + subfolder=subfolder, + repo_type=repo_type, + revision=revision, + ) + + return cached_download( + url, + library_name=library_name, + library_version=library_version, + cache_dir=cache_dir, + user_agent=user_agent, + force_download=force_download, + force_filename=force_filename, + proxies=proxies, + etag_timeout=etag_timeout, + resume_download=resume_download, + use_auth_token=use_auth_token, + local_files_only=local_files_only, + legacy_cache_layout=legacy_cache_layout, + ) + + if cache_dir is None: + cache_dir = HUGGINGFACE_HUB_CACHE + if revision is None: + revision = DEFAULT_REVISION + if isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if subfolder is not None: + # This is used to create a URL, and not a local path, hence the forward slash. + filename = f"{subfolder}/{filename}" + + if repo_type is None: + repo_type = "model" + if repo_type not in REPO_TYPES: + raise ValueError( + f"Invalid repo type: {repo_type}. 
Accepted repo types are:" + f" {str(REPO_TYPES)}" + ) + + storage_folder = os.path.join( + cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) ) + os.makedirs(storage_folder, exist_ok=True) - return cached_download( - url, - library_name=library_name, - library_version=library_version, - cache_dir=cache_dir, - user_agent=user_agent, - force_download=force_download, - force_filename=force_filename, - proxies=proxies, - etag_timeout=etag_timeout, - resume_download=resume_download, - use_auth_token=use_auth_token, - local_files_only=local_files_only, + # cross platform transcription of filename, to be used as a local file path. + relative_filename = os.path.join(*filename.split("/")) + + # if user provides a commit_hash and they already have the file on disk, + # shortcut everything. + if REGEX_COMMIT_HASH.match(revision): + pointer_path = os.path.join( + storage_folder, "snapshots", revision, relative_filename + ) + if os.path.exists(pointer_path): + return pointer_path + + url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision) + + headers = { + "user-agent": http_user_agent( + library_name=library_name, + library_version=library_version, + user_agent=user_agent, + ) + } + if isinstance(use_auth_token, str): + headers["authorization"] = f"Bearer {use_auth_token}" + elif use_auth_token: + token = HfFolder.get_token() + if token is None: + raise EnvironmentError( + "You specified use_auth_token=True, but a huggingface token was not" + " found." 
+ ) + headers["authorization"] = f"Bearer {token}" + + url_to_download = url + etag = None + commit_hash = None + if not local_files_only: + try: + r = _request_with_retry( + method="HEAD", + url=url, + headers=headers, + allow_redirects=False, + proxies=proxies, + timeout=etag_timeout, + ) + r.raise_for_status() + commit_hash = r.headers[HUGGINGFACE_HEADER_X_REPO_COMMIT] + if commit_hash is None: + raise OSError( + "Distant resource does not seem to be on huggingface.co (missing" + " commit header)." + ) + etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get( + "ETag" + ) + # We favor a custom header indicating the etag of the linked resource, and + # we fallback to the regular etag header. + # If we don't have any of those, raise an error. + if etag is None: + raise OSError( + "Distant resource does not have an ETag, we won't be able to" + " reliably ensure reproducibility." + ) + etag = _normalize_etag(etag) + # In case of a redirect, + # save an extra redirect on the request.get call, + # and ensure we download the exact atomic version even if it changed + # between the HEAD and the GET (unlikely, but hey). + if 300 <= r.status_code <= 399: + url_to_download = r.headers["Location"] + except (requests.exceptions.SSLError, requests.exceptions.ProxyError): + # Actually raise for those subclasses of ConnectionError + raise + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + OfflineModeIsEnabled, + ): + # Otherwise, our Internet connection is down. + # etag is None + pass + + # etag is None == we don't have a connection or we passed local_files_only. + # try to get the last downloaded one from the specified revision. + # If the specified revision is a commit hash, look inside "snapshots". + # If the specified revision is a branch or tag, look inside "refs". + if etag is None: + # In those cases, we cannot force download. 
+ if force_download: + raise ValueError( + "We have no connection or you passed local_files_only, so" + " force_download is not an accepted option." + ) + commit_hash = revision + if not REGEX_COMMIT_HASH.match(revision): + ref_path = os.path.join(storage_folder, "refs", revision) + with open(ref_path) as f: + commit_hash = f.read() + + pointer_path = os.path.join( + storage_folder, "snapshots", commit_hash, relative_filename + ) + if os.path.exists(pointer_path): + return pointer_path + + # If we couldn't find an appropriate file on disk, + # raise an error. + # If files cannot be found and local_files_only=True, + # the models might've been found if local_files_only=False + # Notify the user about that + if local_files_only: + raise ValueError( + "Cannot find the requested files in the disk cache and" + " outgoing traffic has been disabled. To enable hf.co look-ups" + " and downloads online, set 'local_files_only' to False." + ) + else: + raise ValueError( + "Connection error, and we cannot find the requested files in" + " the disk cache. Please try again or make sure your Internet" + " connection is on." + ) + + # From now on, etag and commit_hash are not None. + blob_path = os.path.join(storage_folder, "blobs", etag) + pointer_path = os.path.join( + storage_folder, "snapshots", commit_hash, relative_filename ) + + os.makedirs(os.path.dirname(blob_path), exist_ok=True) + os.makedirs(os.path.dirname(pointer_path), exist_ok=True) + # if passed revision is not identical to commit_hash + # then revision has to be a branch name or tag name. + # In that case store a ref. 
+ if revision != commit_hash: + ref_path = os.path.join(storage_folder, "refs", revision) + os.makedirs(os.path.dirname(ref_path), exist_ok=True) + with open(ref_path, "w") as f: + f.write(commit_hash) + + if os.path.exists(pointer_path) and not force_download: + return pointer_path + + if os.path.exists(blob_path) and not force_download: + # we have the blob already, but not the pointer + logger.info("creating pointer to %s from %s", blob_path, pointer_path) + _create_relative_symlink(blob_path, pointer_path) + return pointer_path + + # Prevent parallel downloads of the same file with a lock. + lock_path = blob_path + ".lock" + + # Some Windows versions do not allow for paths longer than 255 characters. + # In this case, we must specify it is an extended path by using the "\\?\" prefix. + if os.name == "nt" and len(os.path.abspath(lock_path)) > 255: + lock_path = "\\\\?\\" + os.path.abspath(lock_path) + + if os.name == "nt" and len(os.path.abspath(blob_path)) > 255: + blob_path = "\\\\?\\" + os.path.abspath(blob_path) + + with FileLock(lock_path): + + # If the download just completed while the lock was activated. + if os.path.exists(pointer_path) and not force_download: + # Even if returning early like here, the lock will be released. + return pointer_path + + if resume_download: + incomplete_path = blob_path + ".incomplete" + + @contextmanager + def _resumable_file_manager() -> "io.BufferedWriter": + with open(incomplete_path, "ab") as f: + yield f + + temp_file_manager = _resumable_file_manager + if os.path.exists(incomplete_path): + resume_size = os.stat(incomplete_path).st_size + else: + resume_size = 0 + else: + temp_file_manager = partial( + tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False + ) + resume_size = 0 + + # Download to temporary file, then copy to cache dir once finished. + # Otherwise you get corrupt cache entries if the download gets interrupted. 
+ with temp_file_manager() as temp_file: + logger.info("downloading %s to %s", url, temp_file.name) + + http_get( + url_to_download, + temp_file, + proxies=proxies, + resume_size=resume_size, + headers=headers, + ) + + logger.info("storing %s in cache at %s", url, blob_path) + os.replace(temp_file.name, blob_path) + + logger.info("creating pointer to %s from %s", blob_path, pointer_path) + _create_relative_symlink(blob_path, pointer_path) + + try: + os.remove(lock_path) + except OSError: + pass + + return pointer_path diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index e5ad61b..626a23c 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -89,7 +89,9 @@ def _validate_repo_id_deprecation(repo_id, name, organization): return name, organization -def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None): +def repo_type_and_id_from_hf_id( + hf_id: str, hub_url: Optional[str] = None +) -> Tuple[Optional[str], Optional[str], str]: """ Returns the repo type and ID from a huggingface.co URL linking to a repository diff --git a/tests/test_cache_layout.py b/tests/test_cache_layout.py new file mode 100644 index 0000000..442fa6a --- /dev/null +++ b/tests/test_cache_layout.py @@ -0,0 +1,390 @@ +import os +import tempfile +import time +import unittest +import uuid +from io import BytesIO + +from huggingface_hub import ( + HfApi, + create_repo, + delete_repo, + hf_hub_download, + snapshot_download, + upload_file, +) +from huggingface_hub.utils import logging + +from .testing_constants import ENDPOINT_STAGING, TOKEN, USER +from .testing_utils import with_production_testing + + +logger = logging.get_logger(__name__) +MODEL_IDENTIFIER = "hf-internal-testing/hfh-cache-layout" + + +def repo_name(id=uuid.uuid4().hex[:6]): + return "repo-{0}-{1}".format(id, int(time.time() * 10e3)) + + +def get_file_contents(path): + with open(path) as f: + content = f.read() + + return content + + +@with_production_testing +class 
CacheFileLayoutHfHubDownload(unittest.TestCase): + def test_file_downloaded_in_cache(self): + with tempfile.TemporaryDirectory() as cache: + hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + + expected_reference = "main" + + # Only reference should be `main`. + self.assertListEqual(refs, [expected_reference]) + + with open(os.path.join(expected_path, "refs", expected_reference)) as f: + snapshot_name = f.readline().strip() + + # The `main` reference should point to the only snapshot we have downloaded + self.assertListEqual(snapshots, [snapshot_name]) + + snapshot_path = os.path.join(expected_path, "snapshots", snapshot_name) + snapshot_content = os.listdir(snapshot_path) + + # Only a single file in the snapshot + self.assertEqual(len(snapshot_content), 1) + + snapshot_content_path = os.path.join(snapshot_path, snapshot_content[0]) + + # The snapshot content should link to a blob + self.assertTrue(os.path.islink(snapshot_content_path)) + + resolved_blob_relative = os.readlink(snapshot_content_path) + resolved_blob_absolute = os.path.normpath( + os.path.join(snapshot_path, resolved_blob_relative) + ) + + with open(resolved_blob_absolute) as f: + blob_contents = f.read().strip() + + # The contents of the file should be 'File 0'. 
+ self.assertEqual(blob_contents, "File 0") + + def test_file_downloaded_in_cache_with_revision(self): + with tempfile.TemporaryDirectory() as cache: + hf_hub_download( + MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2" + ) + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + + expected_reference = "file-2" + + # Only reference should be `file-2`. + self.assertListEqual(refs, [expected_reference]) + + with open(os.path.join(expected_path, "refs", expected_reference)) as f: + snapshot_name = f.read().strip() + + # The `main` reference should point to the only snapshot we have downloaded + self.assertListEqual(snapshots, [snapshot_name]) + + snapshot_path = os.path.join(expected_path, "snapshots", snapshot_name) + snapshot_content = os.listdir(snapshot_path) + + # Only a single file in the snapshot + self.assertEqual(len(snapshot_content), 1) + + snapshot_content_path = os.path.join(snapshot_path, snapshot_content[0]) + + # The snapshot content should link to a blob + self.assertTrue(os.path.islink(snapshot_content_path)) + + resolved_blob_relative = os.readlink(snapshot_content_path) + resolved_blob_absolute = os.path.normpath( + os.path.join(snapshot_path, resolved_blob_relative) + ) + + with open(resolved_blob_absolute) as f: + blob_contents = f.readline().strip() + + # The contents of the file should be 'File 0'. + self.assertEqual(blob_contents, "File 0") + + def test_file_download_happens_once(self): + # Tests that a file is only downloaded once if it's not updated. 
+ + with tempfile.TemporaryDirectory() as cache: + path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + creation_time_0 = os.path.getmtime(path) + + time.sleep(2) + + path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + creation_time_1 = os.path.getmtime(path) + + self.assertEqual(creation_time_0, creation_time_1) + + def test_file_download_happens_once_intra_revision(self): + # Tests that a file is only downloaded once if it's not updated, even across different revisions. + + with tempfile.TemporaryDirectory() as cache: + path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + creation_time_0 = os.path.getmtime(path) + + time.sleep(2) + + path = hf_hub_download( + MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2" + ) + creation_time_1 = os.path.getmtime(path) + + self.assertEqual(creation_time_0, creation_time_1) + + def test_multiple_refs_for_same_file(self): + with tempfile.TemporaryDirectory() as cache: + hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) + hf_hub_download( + MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2" + ) + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + refs.sort() + + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + snapshots.sort() + + # Directory should contain two revisions + self.assertListEqual(refs, ["file-2", "main"]) + + refs_contents = [ + get_file_contents(os.path.join(expected_path, "refs", f)) for f in refs + ] + refs_contents.sort() + + # snapshots directory should contain two snapshots + self.assertListEqual(refs_contents, snapshots) + + snapshot_links = [ + os.readlink( + os.path.join(expected_path, "snapshots", filename, "file_0.txt") + ) + for filename in snapshots + ] + + # All snapshot links should point to the same file. 
+ self.assertEqual(*snapshot_links) + + +@with_production_testing +class CacheFileLayoutSnapshotDownload(unittest.TestCase): + def test_file_downloaded_in_cache(self): + with tempfile.TemporaryDirectory() as cache: + snapshot_download(MODEL_IDENTIFIER, cache_dir=cache) + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + snapshots.sort() + + # Directory should contain two revisions + self.assertListEqual(refs, ["main"]) + + ref_content = get_file_contents( + os.path.join(expected_path, "refs", refs[0]) + ) + + # snapshots directory should contain two snapshots + self.assertListEqual([ref_content], snapshots) + + snapshot_path = os.path.join(expected_path, "snapshots", snapshots[0]) + + files_in_snapshot = os.listdir(snapshot_path) + + snapshot_links = [ + os.readlink(os.path.join(snapshot_path, filename)) + for filename in files_in_snapshot + ] + + resolved_snapshot_links = [ + os.path.normpath(os.path.join(snapshot_path, link)) + for link in snapshot_links + ] + + self.assertTrue(all([os.path.isfile(l) for l in resolved_snapshot_links])) + + def test_file_downloaded_in_cache_several_revisions(self): + with tempfile.TemporaryDirectory() as cache: + snapshot_download(MODEL_IDENTIFIER, cache_dir=cache, revision="file-3") + snapshot_download(MODEL_IDENTIFIER, cache_dir=cache, revision="file-2") + + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + refs.sort() + + snapshots = os.listdir(os.path.join(expected_path, "snapshots")) + snapshots.sort() + + # Directory should contain two revisions + self.assertListEqual(refs, ["file-2", "file-3"]) + + refs_content = [ + get_file_contents(os.path.join(expected_path, 
"refs", ref)) + for ref in refs + ] + refs_content.sort() + + # snapshots directory should contain two snapshots + self.assertListEqual(refs_content, snapshots) + + snapshots_paths = [ + os.path.join(expected_path, "snapshots", s) for s in snapshots + ] + + files_in_snapshots = {s: os.listdir(s) for s in snapshots_paths} + links_in_snapshots = { + k: [os.readlink(os.path.join(k, _v)) for _v in v] + for k, v in files_in_snapshots.items() + } + + resolved_snapshots_links = { + k: [os.path.normpath(os.path.join(k, link)) for link in v] + for k, v in links_in_snapshots.items() + } + + all_links = [b for a in resolved_snapshots_links.values() for b in a] + all_unique_links = set(all_links) + + # [ 100] . + # ├── [ 140] blobs + # │ ├── [ 7] 4475433e279a71203927cbe80125208a3b5db560 + # │ ├── [ 7] 50fcd26d6ce3000f9d5f12904e80eccdc5685dd1 + # │ ├── [ 7] 80146afc836c60e70ba67933fec439ab05b478f6 + # │ ├── [ 7] 8cf9e18f080becb674b31c21642538269fe886a4 + # │ └── [1.1K] ac481c8eb05e4d2496fbe076a38a7b4835dd733d + # ├── [ 80] refs + # │ ├── [ 40] file-2 + # │ └── [ 40] file-3 + # └── [ 80] snapshots + # ├── [ 120] 5e23cb3ae7f904919a442e1b27dcddae6c6bc292 + # │ ├── [ 52] file_0.txt -> ../../blobs/80146afc836c60e70ba67933fec439ab05b478f6 + # │ ├── [ 52] file_1.txt -> ../../blobs/50fcd26d6ce3000f9d5f12904e80eccdc5685dd1 + # │ ├── [ 52] file_2.txt -> ../../blobs/4475433e279a71203927cbe80125208a3b5db560 + # │ └── [ 52] .gitattributes -> ../../blobs/ac481c8eb05e4d2496fbe076a38a7b4835dd733d + # └── [ 120] 78aa2ebdb60bba086496a8792ba506e58e587b4c + # ├── [ 52] file_0.txt -> ../../blobs/80146afc836c60e70ba67933fec439ab05b478f6 + # ├── [ 52] file_1.txt -> ../../blobs/50fcd26d6ce3000f9d5f12904e80eccdc5685dd1 + # ├── [ 52] file_3.txt -> ../../blobs/8cf9e18f080becb674b31c21642538269fe886a4 + # └── [ 52] .gitattributes -> ../../blobs/ac481c8eb05e4d2496fbe076a38a7b4835dd733d + + # Across the two revisions, there should be 8 total links + self.assertEqual(len(all_links), 8) + + # Across the two 
revisions, there should only be 5 unique files. + self.assertEqual(len(all_unique_links), 5) + + +class ReferenceUpdates(unittest.TestCase): + _api = HfApi(endpoint=ENDPOINT_STAGING) + + @classmethod + def setUpClass(cls): + """ + Share this valid token in all tests below. + """ + cls._token = TOKEN + cls._api.set_access_token(TOKEN) + + def test_update_reference(self): + repo_id = f"{USER}/{repo_name()}" + create_repo(repo_id, token=self._token, exist_ok=True) + + try: + upload_file( + path_or_fileobj=BytesIO(b"Some string"), + path_in_repo="file.txt", + repo_id=repo_id, + token=self._token, + ) + + with tempfile.TemporaryDirectory() as cache: + hf_hub_download(repo_id, "file.txt", cache_dir=cache) + + expected_directory_name = f'models--{repo_id.replace("/", "--")}' + expected_path = os.path.join(cache, expected_directory_name) + + refs = os.listdir(os.path.join(expected_path, "refs")) + + # Directory should contain two revisions + self.assertListEqual(refs, ["main"]) + + initial_ref_content = get_file_contents( + os.path.join(expected_path, "refs", refs[0]) + ) + + # Upload a new file on the same branch + upload_file( + path_or_fileobj=BytesIO(b"Some new string"), + path_in_repo="file.txt", + repo_id=repo_id, + token=self._token, + ) + + hf_hub_download(repo_id, "file.txt", cache_dir=cache) + + final_ref_content = get_file_contents( + os.path.join(expected_path, "refs", refs[0]) + ) + + # The `main` reference should point to two different, but existing snapshots which contain + # a 'file.txt' + self.assertNotEqual(initial_ref_content, final_ref_content) + self.assertTrue( + os.path.isdir( + os.path.join(expected_path, "snapshots", initial_ref_content) + ) + ) + self.assertTrue( + os.path.isfile( + os.path.join( + expected_path, "snapshots", initial_ref_content, "file.txt" + ) + ) + ) + self.assertTrue( + os.path.isdir( + os.path.join(expected_path, "snapshots", final_ref_content) + ) + ) + self.assertTrue( + os.path.isfile( + os.path.join( + expected_path, 
"snapshots", final_ref_content, "file.txt" + ) + ) + ) + except Exception: + raise + finally: + delete_repo(repo_id, token=self._token) diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 05499e3..071478c 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -57,7 +57,7 @@ class CachedDownloadTests(unittest.TestCase): def test_bogus_url(self): url = "https://bogus" with self.assertRaisesRegex(ValueError, "Connection error"): - _ = cached_download(url) + _ = cached_download(url, legacy_cache_layout=True) def test_no_connection(self): invalid_url = hf_hub_url( @@ -68,20 +68,26 @@ def test_no_connection(self): valid_url = hf_hub_url( DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT ) - self.assertIsNotNone(cached_download(valid_url, force_download=True)) + self.assertIsNotNone( + cached_download(valid_url, force_download=True, legacy_cache_layout=True) + ) for offline_mode in OfflineSimulationMode: with offline(mode=offline_mode): with self.assertRaisesRegex(ValueError, "Connection error"): - _ = cached_download(invalid_url) + _ = cached_download(invalid_url, legacy_cache_layout=True) with self.assertRaisesRegex(ValueError, "Connection error"): - _ = cached_download(valid_url, force_download=True) - self.assertIsNotNone(cached_download(valid_url)) + _ = cached_download( + valid_url, force_download=True, legacy_cache_layout=True + ) + self.assertIsNotNone( + cached_download(valid_url, legacy_cache_layout=True) + ) def test_file_not_found(self): # Valid revision (None) but missing file. 
url = hf_hub_url(DUMMY_MODEL_ID, filename="missing.bin") with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"): - _ = cached_download(url) + _ = cached_download(url, legacy_cache_layout=True) def test_revision_not_found(self): # Valid file but missing revision @@ -91,14 +97,14 @@ def test_revision_not_found(self): revision=DUMMY_MODEL_ID_REVISION_INVALID, ) with self.assertRaisesRegex(requests.exceptions.HTTPError, "404 Client Error"): - _ = cached_download(url) + _ = cached_download(url, legacy_cache_layout=True) def test_standard_object(self): url = hf_hub_url( DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT ) - filepath = cached_download(url, force_download=True) - metadata = filename_to_url(filepath) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA1}"')) def test_standard_object_rev(self): @@ -108,8 +114,8 @@ def test_standard_object_rev(self): filename=CONFIG_NAME, revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT, ) - filepath = cached_download(url, force_download=True) - metadata = filename_to_url(filepath) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertNotEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"') # Caution: check that the etag is *not* equal to the one from `test_standard_object` @@ -117,8 +123,8 @@ def test_lfs_object(self): url = hf_hub_url( DUMMY_MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT ) - filepath = cached_download(url, force_download=True) - metadata = filename_to_url(filepath) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA256}"')) def 
test_dataset_standard_object_rev(self): @@ -136,8 +142,8 @@ def test_dataset_standard_object_rev(self): ) self.assertEqual(url, url2) # now let's download - filepath = cached_download(url, force_download=True) - metadata = filename_to_url(filepath) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertNotEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"') def test_dataset_lfs_object(self): @@ -147,19 +153,20 @@ def test_dataset_lfs_object(self): repo_type=REPO_TYPE_DATASET, revision=DATASET_REVISION_ID_ONE_SPECIFIC_COMMIT, ) - filepath = cached_download(url, force_download=True) - metadata = filename_to_url(filepath) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual( metadata, (url, '"95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"'), ) - def test_hf_hub_download(self): + def test_hf_hub_download_legacy(self): filepath = hf_hub_download( DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT, force_download=True, + legacy_cache_layout=True, ) - metadata = filename_to_url(filepath) + metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual(metadata[1], f'"{DUMMY_MODEL_ID_PINNED_SHA1}"') diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 64e48b3..c64cf6c 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -511,7 +511,9 @@ def test_upload_file_path(self): user=USER, repo=REPO_NAME, ) - filepath = cached_download(url, force_download=True) + filepath = cached_download( + url, force_download=True, legacy_cache_layout=True + ) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) @@ -538,7 +540,9 @@ def test_upload_file_fileobj(self): user=USER, repo=REPO_NAME, ) - filepath = cached_download(url, 
force_download=True) + filepath = cached_download( + url, force_download=True, legacy_cache_layout=True + ) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) @@ -565,7 +569,9 @@ def test_upload_file_bytesio(self): user=USER, repo=REPO_NAME, ) - filepath = cached_download(url, force_download=True) + filepath = cached_download( + url, force_download=True, legacy_cache_layout=True + ) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, filecontent.getvalue().decode()) @@ -648,7 +654,9 @@ def test_upload_buffer(self): user=USER, repo=REPO_NAME, ) - filepath = cached_download(url, force_download=True) + filepath = cached_download( + url, force_download=True, legacy_cache_layout=True + ) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) diff --git a/tests/test_snapshot_download.py b/tests/test_snapshot_download.py index b41b438..23d2947 100644 --- a/tests/test_snapshot_download.py +++ b/tests/test_snapshot_download.py @@ -275,14 +275,6 @@ def test_download_model_local_only_multiple(self): cache_dir=tmpdirname, ) - # now load from cache and make sure warning to be raised - with self.assertWarns(Warning): - snapshot_download( - f"{USER}/{REPO_NAME}", - cache_dir=tmpdirname, - local_files_only=True, - ) - # cache multiple commits and make sure correct commit is taken with tempfile.TemporaryDirectory() as tmpdirname: # first download folder to cache it