Skip to content

Commit

Permalink
HF Hub download (#137)
Browse files Browse the repository at this point in the history
* HF Hub download

* Update src/huggingface_hub/file_download.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
  • Loading branch information
LysandreJik and julien-c authored Jun 22, 2021
1 parent 11d9ab6 commit cc91e1d
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
)
from .file_download import cached_download, hf_hub_url
from .file_download import cached_download, hf_hub_download, hf_hub_url
from .hf_api import HfApi, HfFolder
from .hub_mixin import ModelHubMixin
from .repository import Repository
Expand Down
61 changes: 61 additions & 0 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,3 +468,64 @@ def _resumable_file_manager() -> "io.BufferedWriter":
json.dump(meta, meta_file)

return cache_path


def hf_hub_download(
    repo_id: str,
    filename: str,
    subfolder: Optional[str] = None,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    user_agent: Union[Dict, str, None] = None,
    force_download=False,
    force_filename: Optional[str] = None,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    use_auth_token: Union[bool, str, None] = None,
    local_files_only=False,
):
    """
    Resolve a file of a huggingface.co repository to its URL and fetch it through the local cache.

    This is a convenience wrapper around two calls: the URL is built with `hf_hub_url` from the
    repository coordinates, and that URL is then handed to `cached_download` together with every
    remaining option, unchanged.

    Args:
        repo_id, filename:
            Passed positionally to `hf_hub_url` to identify the repository and the file.
        subfolder, repo_type, revision:
            Passed by keyword to `hf_hub_url`; all default to `None`.
        library_name, library_version, cache_dir, user_agent, force_download, force_filename,
        proxies, etag_timeout, resume_download, use_auth_token, local_files_only:
            Forwarded by keyword to `cached_download` without modification; see that function
            for their exact meaning.

    Return:
        The value returned by `cached_download` for the resolved URL: the local path (string) of
        the file or, if networking is off, the last version of the file cached on disk.

    Raises:
        Whatever `hf_hub_url` or `cached_download` raise, e.g. for a non-recoverable file
        (non-existent or inaccessible url + no cache on disk).
    """
    # Everything that is not needed to build the URL belongs to the download step.
    download_options = dict(
        library_name=library_name,
        library_version=library_version,
        cache_dir=cache_dir,
        user_agent=user_agent,
        force_download=force_download,
        force_filename=force_filename,
        proxies=proxies,
        etag_timeout=etag_timeout,
        resume_download=resume_download,
        use_auth_token=use_auth_token,
        local_files_only=local_files_only,
    )

    file_url = hf_hub_url(
        repo_id,
        filename,
        subfolder=subfolder,
        repo_type=repo_type,
        revision=revision,
    )

    return cached_download(file_url, **download_options)
17 changes: 16 additions & 1 deletion tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
PYTORCH_WEIGHTS_NAME,
REPO_TYPE_DATASET,
)
from huggingface_hub.file_download import cached_download, filename_to_url, hf_hub_url
from huggingface_hub.file_download import (
cached_download,
filename_to_url,
hf_hub_download,
hf_hub_url,
)

from .testing_utils import (
DUMMY_MODEL_ID,
Expand Down Expand Up @@ -146,3 +151,13 @@ def test_dataset_lfs_object(self):
metadata,
(url, '"95aa6a52d5d6a735563366753ca50492a658031da74f301ac5238b03966972c9"'),
)

def test_hf_hub_download(self):
    """Check that `hf_hub_download` yields a cached file whose metadata matches the pinned SHA-1."""
    # force_download=True is passed so the call does not rely on a previously cached copy.
    cached_path = hf_hub_download(
        DUMMY_MODEL_ID,
        filename=CONFIG_NAME,
        revision=REVISION_ID_DEFAULT,
        force_download=True,
    )

    # Second field of the cache metadata (presumably the ETag -- confirm against
    # `filename_to_url`); it is expected to be the pinned SHA-1 wrapped in double quotes.
    recorded_value = filename_to_url(cached_path)[1]
    expected_value = f'"{DUMMY_MODEL_ID_PINNED_SHA1}"'
    self.assertEqual(recorded_value, expected_value)

0 comments on commit cc91e1d

Please sign in to comment.