diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 503ab1c87b..7fcc5c7d96 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -24,8 +24,9 @@ def snapshot_download( repo_id: str, *, - revision: Optional[str] = None, repo_type: Optional[str] = None, + revision: Optional[str] = None, + endpoint: Optional[str] = None, cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None, local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", @@ -71,12 +72,15 @@ def snapshot_download( Args: repo_id (`str`): A user or an organization name and a repo name separated by a `/`. - revision (`str`, *optional*): - An optional Git revision id which can be a branch name, a tag, or a - commit hash. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + endpoint (`str`, *optional*): + Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT` + environment variable. cache_dir (`str`, `Path`, *optional*): Path to the folder where cached files are stored. local_dir (`str` or `Path`, *optional*: @@ -181,7 +185,7 @@ def snapshot_download( ) # if we have internet connection we retrieve the correct folder name from the huggingface api - api = HfApi(library_name=library_name, library_version=library_version, user_agent=user_agent) + api = HfApi(library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint) repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token) assert repo_info.sha is not None, "Repo info returned from server must have a revision sha." 
@@ -212,6 +216,7 @@ def _inner_hf_hub_download(repo_file: str): filename=repo_file, repo_type=repo_type, revision=commit_hash, + endpoint=endpoint, cache_dir=cache_dir, local_dir=local_dir, local_dir_use_symlinks=local_dir_use_symlinks, diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index b45f450d85..173cbc2de8 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -26,6 +26,7 @@ from . import __version__ # noqa: F401 # for backward compatibility from .constants import ( DEFAULT_REVISION, + ENDPOINT, HF_HUB_DISABLE_SYMLINKS_WARNING, HF_HUB_ENABLE_HF_TRANSFER, HUGGINGFACE_CO_URL_TEMPLATE, @@ -176,6 +177,7 @@ def hf_hub_url( subfolder: Optional[str] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, + endpoint: Optional[str] = None, ) -> str: """Construct the URL of a file from the given information. @@ -197,6 +199,9 @@ def hf_hub_url( revision (`str`, *optional*): An optional Git revision id which can be a branch name, a tag, or a commit hash. + endpoint (`str`, *optional*): + Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT` + environment variable. 
Example: @@ -247,11 +252,13 @@ def hf_hub_url( if revision is None: revision = DEFAULT_REVISION - return HUGGINGFACE_CO_URL_TEMPLATE.format( - repo_id=repo_id, - revision=quote(revision, safe=""), - filename=quote(filename), + url = HUGGINGFACE_CO_URL_TEMPLATE.format( + repo_id=repo_id, revision=quote(revision, safe=""), filename=quote(filename) ) + # Update endpoint if provided + if endpoint is not None and url.startswith(ENDPOINT): + url = endpoint + url[len(ENDPOINT) :] + return url def url_to_filename(url: str, etag: Optional[str] = None) -> str: @@ -962,6 +969,7 @@ def hf_hub_download( subfolder: Optional[str] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, + endpoint: Optional[str] = None, library_name: Optional[str] = None, library_version: Optional[str] = None, cache_dir: Union[str, Path, None] = None, @@ -1035,6 +1043,9 @@ def hf_hub_download( revision (`str`, *optional*): An optional Git revision id which can be a branch name, a tag, or a commit hash. + endpoint (`str`, *optional*): + Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT` + environment variable. library_name (`str`, *optional*): The name of the library to which the object corresponds. 
library_version (`str`, *optional*): @@ -1116,6 +1127,7 @@ def hf_hub_download( subfolder=subfolder, repo_type=repo_type, revision=revision, + endpoint=endpoint, ) return cached_download( @@ -1175,7 +1187,7 @@ def hf_hub_download( return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks) return pointer_path - url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision) + url = hf_hub_url(repo_id, filename, repo_type=repo_type, revision=revision, endpoint=endpoint) headers = build_hf_headers( token=token, diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 7331a63607..8424635c90 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -46,6 +46,7 @@ import requests from requests.exceptions import HTTPError +from tqdm.auto import tqdm as base_tqdm from huggingface_hub.utils import ( IGNORE_GIT_FOLDER_PATTERNS, @@ -864,9 +865,8 @@ def __init__( Args: endpoint (`str`, *optional*): - Hugging Face Hub base url. Will default to https://huggingface.co/. To - be set if you are using a private hub. Otherwise, one can set the - `HF_ENDPOINT` environment variable. + Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, + one can set the `HF_ENDPOINT` environment variable. token (`str`, *optional*): Hugging Face token. Will default to the locally saved token if not provided. 
@@ -3695,6 +3695,299 @@ def delete_folder( parent_commit=parent_commit, ) + @validate_hf_hub_args + def hf_hub_download( + self, + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", + force_download: bool = False, + force_filename: Optional[str] = None, + proxies: Optional[Dict] = None, + etag_timeout: float = 10, + resume_download: bool = False, + local_files_only: bool = False, + legacy_cache_layout: bool = False, + ) -> str: + """Download a given file if it's not already present in the local cache. + + The new cache file layout looks like this: + - The cache directory contains one subfolder per repo_id (namespaced by repo type) + - inside each repo folder: + - refs is a list of the latest known revision => commit_hash pairs + - blobs contains the actual file blobs (identified by their git-sha or sha256, depending on + whether they're LFS files or not) + - snapshots contains one subfolder per commit, each "commit" contains the subset of the files + that have been resolved at that particular commit. Each filename is a symlink to the blob + at that particular commit. + + If `local_dir` is provided, the file structure from the repo will be replicated in this location. You can configure + how you want to move those files: + - If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob + files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal + is to be able to manually edit and save small files without corrupting the cache while saving disk space for + binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD` + environment variable. 
+ - If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`. + This is optimal in terms of disk usage but files must not be manually edited. + - If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the + local dir. This means disk usage is not optimized. + - Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the + files are downloaded and directly placed under `local_dir`. This means if you need to download them again later, + they will be re-downloaded entirely. + + ``` + [ 96] . + └── [ 160] models--julien-c--EsperBERTo-small + ├── [ 160] blobs + │ ├── [321M] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + │ ├── [ 398] 7cb18dc9bafbfcf74629a4b760af1b160957a83e + │ └── [1.4K] d7edf6bd2a681fb0175f7735299831ee1b22b812 + ├── [ 96] refs + │ └── [ 40] main + └── [ 128] snapshots + ├── [ 128] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f + │ ├── [ 52] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812 + │ └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + └── [ 128] bbc77c8132af1cc5cf678da3f1ddf2de43606d48 + ├── [ 52] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e + └── [ 76] pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd + ``` + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + filename (`str`): + The name of the file in the repo. + subfolder (`str`, *optional*): + An optional value corresponding to a folder inside the model repo. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. 
+ revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + endpoint (`str`, *optional*): + Hugging Face Hub base url. Will default to https://huggingface.co/. Otherwise, one can set the `HF_ENDPOINT` + environment variable. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. + local_dir (`str` or `Path`, *optional*): + If provided, the downloaded file will be placed under this directory, either as a symlink (default) or + a regular file (see description for more details). + local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`): + To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either + duplicated or symlinked to the local directory depending on its size. If set to `True`, a symlink will be + created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if + already exists) or downloaded from the Hub and not cached. See description for more details. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in + the local cache. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `requests.request`. + resume_download (`bool`, *optional*, defaults to `False`): + If `True`, resume a previously interrupted download. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + legacy_cache_layout (`bool`, *optional*, defaults to `False`): + If `True`, uses the legacy file cache layout i.e. just call [`hf_hub_url`] + then `cached_download`. 
This is deprecated as the new cache layout is + more powerful. + + Returns: + Local path (string) of file or if networking is off, last version of + file cached on disk. + + + + Raises the following errors: + + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) + if ETag cannot be determined. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + - [`~utils.RepositoryNotFoundError`] + If the repository to download from cannot be found. This may be because it doesn't exist, + or because it is set to `private` and you do not have access. + - [`~utils.RevisionNotFoundError`] + If the revision to download from cannot be found. + - [`~utils.EntryNotFoundError`] + If the file to download cannot be found. + - [`~utils.LocalEntryNotFoundError`] + If network is disabled or unavailable and file is not found in cache. 
+ + + """ + from .file_download import hf_hub_download + + return hf_hub_download( + repo_id=repo_id, + filename=filename, + subfolder=subfolder, + repo_type=repo_type, + revision=revision, + endpoint=self.endpoint, + library_name=self.library_name, + library_version=self.library_version, + cache_dir=cache_dir, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + user_agent=self.user_agent, + force_download=force_download, + force_filename=force_filename, + proxies=proxies, + etag_timeout=etag_timeout, + resume_download=resume_download, + token=self.token, + local_files_only=local_files_only, + legacy_cache_layout=legacy_cache_layout, + ) + + @validate_hf_hub_args + def snapshot_download( + self, + repo_id: str, + *, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Union[str, Path, None] = None, + local_dir: Union[str, Path, None] = None, + local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", + proxies: Optional[Dict] = None, + etag_timeout: float = 10, + resume_download: bool = False, + force_download: bool = False, + local_files_only: bool = False, + allow_patterns: Optional[Union[List[str], str]] = None, + ignore_patterns: Optional[Union[List[str], str]] = None, + max_workers: int = 8, + tqdm_class: Optional[base_tqdm] = None, + ) -> str: + """Download repo files. + + Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from + a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order + to keep their actual filename relative to that folder. You can also filter which files to download using + `allow_patterns` and `ignore_patterns`. + + If `local_dir` is provided, the file structure from the repo will be replicated in this location. 
You can configure + how you want to move those files: + - If `local_dir_use_symlinks="auto"` (default), files are downloaded and stored in the cache directory as blob + files. Small files (<5MB) are duplicated in `local_dir` while a symlink is created for bigger files. The goal + is to be able to manually edit and save small files without corrupting the cache while saving disk space for + binary files. The 5MB threshold can be configured with the `HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD` + environment variable. + - If `local_dir_use_symlinks=True`, files are downloaded, stored in the cache directory and symlinked in `local_dir`. + This is optimal in terms of disk usage but files must not be manually edited. + - If `local_dir_use_symlinks=False` and the blob files exist in the cache directory, they are duplicated in the + local dir. This means disk usage is not optimized. + - Finally, if `local_dir_use_symlinks=False` and the blob files do not exist in the cache directory, then the + files are downloaded and directly placed under `local_dir`. This means if you need to download them again later, + they will be re-downloaded entirely. + + An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly + configured. It is also not possible to filter which files to download when cloning a repository using git. + + Args: + repo_id (`str`): + A user or an organization name and a repo name separated by a `/`. + repo_type (`str`, *optional*): + Set to `"dataset"` or `"space"` if downloading from a dataset or space, + `None` or `"model"` if downloading from a model. Default is `None`. + revision (`str`, *optional*): + An optional Git revision id which can be a branch name, a tag, or a + commit hash. + cache_dir (`str`, `Path`, *optional*): + Path to the folder where cached files are stored. 
+ local_dir (`str` or `Path`, *optional*): + If provided, the downloaded files will be placed under this directory, either as symlinks (default) or + regular files (see description for more details). + local_dir_use_symlinks (`"auto"` or `bool`, defaults to `"auto"`): + To be used with `local_dir`. If set to "auto", the cache directory will be used and the file will be either + duplicated or symlinked to the local directory depending on its size. If set to `True`, a symlink will be + created, no matter the file size. If set to `False`, the file will either be duplicated from cache (if + already exists) or downloaded from the Hub and not cached. See description for more details. + proxies (`dict`, *optional*): + Dictionary mapping protocol to the URL of the proxy passed to + `requests.request`. + etag_timeout (`float`, *optional*, defaults to `10`): + When fetching ETag, how many seconds to wait for the server to send + data before giving up which is passed to `requests.request`. + resume_download (`bool`, *optional*, defaults to `False`): + If `True`, resume a previously interrupted download. + force_download (`bool`, *optional*, defaults to `False`): + Whether the file should be downloaded even if it already exists in the local cache. + local_files_only (`bool`, *optional*, defaults to `False`): + If `True`, avoid downloading the file and return the path to the + local cached file if it exists. + allow_patterns (`List[str]` or `str`, *optional*): + If provided, only files matching at least one pattern are downloaded. + ignore_patterns (`List[str]` or `str`, *optional*): + If provided, files matching any of the patterns are not downloaded. + max_workers (`int`, *optional*): + Number of concurrent threads to download files (1 thread = 1 file download). + Defaults to 8. + tqdm_class (`tqdm`, *optional*): + If provided, overwrites the default behavior for the progress bar. Passed + argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. 
+ Note that the `tqdm_class` is not passed to each individual download. + Defaults to the custom HF progress bar that can be disabled by setting + `HF_HUB_DISABLE_PROGRESS_BARS` environment variable. + + Returns: + Local folder path (string) of repo snapshot + + + + Raises the following errors: + + - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) + if `token=True` and the token cannot be found. + - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if + ETag cannot be determined. + - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) + if some parameter value is invalid + + + """ + from ._snapshot_download import snapshot_download + + return snapshot_download( + repo_id=repo_id, + repo_type=repo_type, + revision=revision, + endpoint=self.endpoint, + cache_dir=cache_dir, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + library_name=self.library_name, + library_version=self.library_version, + user_agent=self.user_agent, + proxies=proxies, + etag_timeout=etag_timeout, + resume_download=resume_download, + force_download=force_download, + token=self.token, + local_files_only=local_files_only, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + max_workers=max_workers, + tqdm_class=tqdm_class, + ) + @validate_hf_hub_args def create_branch( self, diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 5f0a536eec..fb1d7629d5 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -361,6 +361,21 @@ def test_hf_hub_url_with_empty_subfolder(self): ) ) + @patch("huggingface_hub.file_download.ENDPOINT", "https://huggingface.co") + @patch( + "huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", + "https://huggingface.co/{repo_id}/resolve/{revision}/{filename}", + ) + def test_hf_hub_url_with_endpoint(self): + self.assertEqual( + hf_hub_url( + DUMMY_MODEL_ID, + filename=CONFIG_NAME, + 
endpoint="https://hf-ci.co", + ), + "https://hf-ci.co/julien-c/dummy-unknown/resolve/main/config.json", + ) + def test_hf_hub_download_legacy(self): filepath = hf_hub_download( DUMMY_MODEL_ID, diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 9c540316b0..3fe74c7b36 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -2512,6 +2512,76 @@ def test_run_as_future(self, repo_url: RepoUrl) -> None: self.assertEqual(info_2.likes, 0) +class TestDownloadHfApiAlias(unittest.TestCase): + def setUp(self) -> None: + self.api = HfApi( + endpoint="https://hf.co", + token="user_token", + library_name="cool_one", + library_version="1.0.0", + user_agent="myself", + ) + return super().setUp() + + @patch("huggingface_hub.file_download.hf_hub_download") + def test_hf_hub_download_alias(self, mock: Mock) -> None: + self.api.hf_hub_download("my_repo_id", "file.txt") + mock.assert_called_once_with( + # Call values + repo_id="my_repo_id", + filename="file.txt", + # HfAPI values + endpoint="https://hf.co", + library_name="cool_one", + library_version="1.0.0", + user_agent="myself", + token="user_token", + # Default values + subfolder=None, + repo_type=None, + revision=None, + cache_dir=None, + local_dir=None, + local_dir_use_symlinks="auto", + force_download=False, + force_filename=None, + proxies=None, + etag_timeout=10, + resume_download=False, + local_files_only=False, + legacy_cache_layout=False, + ) + + @patch("huggingface_hub._snapshot_download.snapshot_download") + def test_snapshot_download_alias(self, mock: Mock) -> None: + self.api.snapshot_download("my_repo_id") + mock.assert_called_once_with( + # Call values + repo_id="my_repo_id", + # HfAPI values + endpoint="https://hf.co", + library_name="cool_one", + library_version="1.0.0", + user_agent="myself", + token="user_token", + # Default values + repo_type=None, + revision=None, + cache_dir=None, + local_dir=None, + local_dir_use_symlinks="auto", + proxies=None, + etag_timeout=10, + resume_download=False, + 
force_download=False, + local_files_only=False, + allow_patterns=None, + ignore_patterns=None, + max_workers=8, + tqdm_class=None, + ) + + class TestSpaceAPIMocked(unittest.TestCase): """ Testing Space hardware requests is resource intensive for the server (need to spawn