diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index 53f9697204..9a520b36de 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Dict, List, Literal, Optional, Union +import requests from tqdm.auto import tqdm as base_tqdm from tqdm.contrib.concurrent import thread_map @@ -10,12 +11,20 @@ DEFAULT_REVISION, HF_HUB_CACHE, HF_HUB_ENABLE_HF_TRANSFER, - HF_HUB_OFFLINE, REPO_TYPES, ) from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name -from .hf_api import HfApi -from .utils import filter_repo_objects, logging, validate_hf_hub_args +from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo +from .utils import ( + GatedRepoError, + LocalEntryNotFoundError, + OfflineModeIsEnabled, + RepositoryNotFoundError, + RevisionNotFoundError, + filter_repo_objects, + logging, + validate_hf_hub_args, +) from .utils import tqdm as hf_tqdm @@ -158,37 +167,97 @@ def snapshot_download( storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) - # if we have no internet connection we will look for an - # appropriate folder in the cache - # If the specified revision is a commit hash, look inside "snapshots". - # If the specified revision is a branch or tag, look inside "refs". - if local_files_only or HF_HUB_OFFLINE: + repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None + api_call_error: Optional[Exception] = None + if not local_files_only: + # try/except logic to handle different errors => taken from `hf_hub_download` + try: + # if we have internet connection we want to list files to download + api = HfApi( + library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint + ) + repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token) + except (requests.exceptions.SSLError, requests.exceptions.ProxyError): + # Actually raise for those subclasses of ConnectionError + raise + except ( + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + OfflineModeIsEnabled, + ) as error: + # Internet connection is down + # => will try to use local files only + api_call_error = error + pass + except RevisionNotFoundError: + # The repo was found but the revision doesn't exist on the Hub (never existed or got deleted) + raise + except requests.HTTPError as error: + # Multiple reasons for an http error: + # - Repository is private and invalid/missing token sent + # - Repository is gated and invalid/missing token sent + # - Hub is down (error 500 or 504) + # => let's switch to 'local_files_only=True' to check if the files are already cached. + # (if it's not the case, the error will be re-raised) + api_call_error = error + pass + + # At this stage, if `repo_info` is None it means either: + # - internet connection is down + # - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True) + # - repo is private/gated and invalid/missing token sent + # - Hub is down + # => let's look if we can find the appropriate folder in the cache: + # - if the specified revision is a commit hash, look inside "snapshots". + # - f the specified revision is a branch or tag, look inside "refs". + if repo_info is None: + # Try to get which commit hash corresponds to the specified revision + commit_hash = None if REGEX_COMMIT_HASH.match(revision): commit_hash = revision else: - # retrieve commit_hash from file ref_path = os.path.join(storage_folder, "refs", revision) - with open(ref_path) as f: - commit_hash = f.read() - - snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + if os.path.exists(ref_path): + # retrieve commit_hash from refs file + with open(ref_path) as f: + commit_hash = f.read() - if os.path.exists(snapshot_folder): - return snapshot_folder + # Try to locate snapshot folder for this commit hash + if commit_hash is not None: + snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash) + if os.path.exists(snapshot_folder): + # Snapshot folder exists => let's return it + # (but we can't check if all the files are actually there) + return snapshot_folder - raise ValueError( - "Cannot find an appropriate cached snapshot folder for the specified" - " revision on the local disk and outgoing traffic has been disabled. To" - " enable repo look-ups and downloads online, set 'local_files_only' to" - " False." - ) + # If we couldn't find the appropriate folder on disk, raise an error. + if local_files_only: + raise LocalEntryNotFoundError( + "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and " + "outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass " + "'local_files_only=False' as input." + ) + elif isinstance(api_call_error, OfflineModeIsEnabled): + raise LocalEntryNotFoundError( + "Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and " + "outgoing traffic has been disabled. To enable repo look-ups and downloads online, set " + "'HF_HUB_OFFLINE=0' as environment variable." + ) from api_call_error + elif isinstance(api_call_error, RepositoryNotFoundError) or isinstance(api_call_error, GatedRepoError): + # Repo not found => let's raise the actual error + raise api_call_error + else: + # Otherwise: most likely a connection issue or Hub downtime => let's warn the user + raise LocalEntryNotFoundError( + "An error happened while trying to locate the files on the Hub and we cannot find the appropriate" + " snapshot folder for the specified revision on the local disk. Please check your internet connection" + " and try again." + ) from api_call_error - # if we have internet connection we retrieve the correct folder name from the huggingface api - api = HfApi(library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint) - repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token) + # At this stage, internet connection is up and running + # => let's download the files! assert repo_info.sha is not None, "Repo info returned from server must have a revision sha." assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list." - filtered_repo_files = list( filter_repo_objects( items=[f.rfilename for f in repo_info.siblings], diff --git a/src/huggingface_hub/utils/_errors.py b/src/huggingface_hub/utils/_errors.py index 787495a1c4..37bddf331d 100644 --- a/src/huggingface_hub/utils/_errors.py +++ b/src/huggingface_hub/utils/_errors.py @@ -192,7 +192,7 @@ class EntryNotFoundError(HfHubHTTPError): class LocalEntryNotFoundError(EntryNotFoundError, FileNotFoundError, ValueError): """ - Raised when trying to access a file that is not on the disk when network is + Raised when trying to access a file or snapshot that is not on the disk when network is disabled or unavailable (connection issue). The entry may exist on the Hub. Note: `ValueError` type is to ensure backward compatibility. diff --git a/tests/test_snapshot_download.py b/tests/test_snapshot_download.py index 8e13dc75f7..865c464b59 100644 --- a/tests/test_snapshot_download.py +++ b/tests/test_snapshot_download.py @@ -3,13 +3,11 @@ from pathlib import Path from unittest.mock import patch -import requests - from huggingface_hub import CommitOperationAdd, HfApi, snapshot_download -from huggingface_hub.utils import SoftTemporaryDirectory +from huggingface_hub.utils import LocalEntryNotFoundError, RepositoryNotFoundError, SoftTemporaryDirectory from .testing_constants import TOKEN -from .testing_utils import repo_name +from .testing_utils import OfflineSimulationMode, offline, repo_name class SnapshotDownloadTests(unittest.TestCase): @@ -101,7 +99,7 @@ def test_download_private_model(self): # Test download fails without token with SoftTemporaryDirectory() as tmpdir: - with self.assertRaisesRegex(requests.exceptions.HTTPError, "401 Client Error"): + with self.assertRaises(RepositoryNotFoundError): _ = snapshot_download(self.repo_id, revision="main", cache_dir=tmpdir) # Test we can download with token from cache @@ -144,6 +142,18 @@ def test_download_model_local_only(self): ) self.assertTrue(self.first_commit_hash in storage_folder) # has expected revision + def test_download_model_offline_mode_not_cached(self): + """Test when connection error but cache is empty.""" + with SoftTemporaryDirectory() as tmpdir: + with self.assertRaises(LocalEntryNotFoundError): + snapshot_download(self.repo_id, cache_dir=tmpdir, local_files_only=True) + + for offline_mode in OfflineSimulationMode: + with offline(mode=offline_mode): + with SoftTemporaryDirectory() as tmpdir: + with self.assertRaises(LocalEntryNotFoundError): + snapshot_download(self.repo_id, cache_dir=tmpdir) + def test_download_model_local_only_multiple(self): # cache multiple commits and make sure correct commit is taken with SoftTemporaryDirectory() as tmpdir: