Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve snapshot_download offline mode #1913

Merged
merged 6 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 94 additions & 25 deletions src/huggingface_hub/_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union

import requests
from tqdm.auto import tqdm as base_tqdm
from tqdm.contrib.concurrent import thread_map

Expand All @@ -10,12 +11,20 @@
DEFAULT_REVISION,
HF_HUB_CACHE,
HF_HUB_ENABLE_HF_TRANSFER,
HF_HUB_OFFLINE,
REPO_TYPES,
)
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
from .hf_api import HfApi
from .utils import filter_repo_objects, logging, validate_hf_hub_args
from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
from .utils import (
GatedRepoError,
LocalEntryNotFoundError,
OfflineModeIsEnabled,
RepositoryNotFoundError,
RevisionNotFoundError,
filter_repo_objects,
logging,
validate_hf_hub_args,
)
from .utils import tqdm as hf_tqdm


Expand Down Expand Up @@ -158,37 +167,97 @@ def snapshot_download(

storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

# if we have no internet connection we will look for an
# appropriate folder in the cache
# If the specified revision is a commit hash, look inside "snapshots".
# If the specified revision is a branch or tag, look inside "refs".
if local_files_only or HF_HUB_OFFLINE:
repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
api_call_error: Optional[Exception] = None
if not local_files_only:
# try/except logic to handle different errors => taken from `hf_hub_download`
try:
# if we have internet connection we want to list files to download
api = HfApi(
library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint
)
repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
# Actually raise for those subclasses of ConnectionError
raise
except (
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
OfflineModeIsEnabled,
) as error:
# Internet connection is down
# => will try to use local files only
api_call_error = error
pass
except RevisionNotFoundError:
# The repo was found but the revision doesn't exist on the Hub (never existed or got deleted)
raise
except requests.HTTPError as error:
# Multiple reasons for an http error:
# - Repository is private and invalid/missing token sent
# - Repository is gated and invalid/missing token sent
# - Hub is down (error 500 or 504)
# => let's switch to 'local_files_only=True' to check if the files are already cached.
# (if it's not the case, the error will be re-raised)
api_call_error = error
pass

# At this stage, if `repo_info` is None it means either:
# - internet connection is down
# - internet connection is deactivated (local_files_only=True or HF_HUB_OFFLINE=True)
# - repo is private/gated and invalid/missing token sent
# - Hub is down
# => let's look if we can find the appropriate folder in the cache:
# - if the specified revision is a commit hash, look inside "snapshots".
# - if the specified revision is a branch or tag, look inside "refs".
if repo_info is None:
# Try to get which commit hash corresponds to the specified revision
commit_hash = None
if REGEX_COMMIT_HASH.match(revision):
commit_hash = revision
else:
# retrieve commit_hash from file
ref_path = os.path.join(storage_folder, "refs", revision)
with open(ref_path) as f:
commit_hash = f.read()

snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
if os.path.exists(ref_path):
# retrieve commit_hash from refs file
with open(ref_path) as f:
commit_hash = f.read()

if os.path.exists(snapshot_folder):
return snapshot_folder
# Try to locate snapshot folder for this commit hash
if commit_hash is not None:
snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
if os.path.exists(snapshot_folder):
# Snapshot folder exists => let's return it
# (but we can't check if all the files are actually there)
return snapshot_folder

raise ValueError(
"Cannot find an appropriate cached snapshot folder for the specified"
" revision on the local disk and outgoing traffic has been disabled. To"
" enable repo look-ups and downloads online, set 'local_files_only' to"
" False."
)
# If we couldn't find the appropriate folder on disk, raise an error.
if local_files_only:
raise LocalEntryNotFoundError(
"Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
"outgoing traffic has been disabled. To enable repo look-ups and downloads online, pass "
"'local_files_only=False' as input."
)
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(api_call_error, OfflineModeIsEnabled):
raise LocalEntryNotFoundError(
"Cannot find an appropriate cached snapshot folder for the specified revision on the local disk and "
"outgoing traffic has been disabled. To enable repo look-ups and downloads online, set "
"'HF_HUB_OFFLINE' as environment variable."
Wauplin marked this conversation as resolved.
Show resolved Hide resolved
) from api_call_error
elif isinstance(api_call_error, RepositoryNotFoundError) or isinstance(api_call_error, GatedRepoError):
# Repo not found => let's raise the actual error
raise api_call_error
else:
# Otherwise: most likely a connection issue or Hub downtime => let's warn the user
raise LocalEntryNotFoundError(
"An error happened while trying to locate the files on the Hub and we cannot find the appropriate"
" snapshot folder for the specified revision on the local disk. Please check your internet connection"
" and try again."
) from api_call_error

# if we have internet connection we retrieve the correct folder name from the huggingface api
api = HfApi(library_name=library_name, library_version=library_version, user_agent=user_agent, endpoint=endpoint)
repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
# At this stage, internet connection is up and running
# => let's download the files!
assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."

filtered_repo_files = list(
filter_repo_objects(
items=[f.rfilename for f in repo_info.siblings],
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/utils/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class EntryNotFoundError(HfHubHTTPError):

class LocalEntryNotFoundError(EntryNotFoundError, FileNotFoundError, ValueError):
"""
Raised when trying to access a file that is not on the disk when network is
Raised when trying to access a file or snapshot that is not on the disk when network is
disabled or unavailable (connection issue). The entry may exist on the Hub.

Note: `ValueError` type is to ensure backward compatibility.
Expand Down
20 changes: 15 additions & 5 deletions tests/test_snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from pathlib import Path
from unittest.mock import patch

import requests

from huggingface_hub import CommitOperationAdd, HfApi, snapshot_download
from huggingface_hub.utils import SoftTemporaryDirectory
from huggingface_hub.utils import LocalEntryNotFoundError, RepositoryNotFoundError, SoftTemporaryDirectory

from .testing_constants import TOKEN
from .testing_utils import repo_name
from .testing_utils import OfflineSimulationMode, offline, repo_name


class SnapshotDownloadTests(unittest.TestCase):
Expand Down Expand Up @@ -101,7 +99,7 @@ def test_download_private_model(self):

# Test download fails without token
with SoftTemporaryDirectory() as tmpdir:
with self.assertRaisesRegex(requests.exceptions.HTTPError, "401 Client Error"):
with self.assertRaises(RepositoryNotFoundError):
_ = snapshot_download(self.repo_id, revision="main", cache_dir=tmpdir)

# Test we can download with token from cache
Expand Down Expand Up @@ -144,6 +142,18 @@ def test_download_model_local_only(self):
)
self.assertTrue(self.first_commit_hash in storage_folder) # has expected revision

def test_download_model_offline_mode_not_cached(self):
    """Check that an empty cache raises `LocalEntryNotFoundError` when the Hub is not queried or not reachable.

    Exercises the explicit `local_files_only=True` call first, then every member of
    `OfflineSimulationMode` through the `offline` context manager.
    """
    # Explicit local-only request against a brand-new (empty) cache directory.
    with SoftTemporaryDirectory() as cache_dir, self.assertRaises(LocalEntryNotFoundError):
        snapshot_download(self.repo_id, cache_dir=cache_dir, local_files_only=True)

    # Same expectation for each simulated offline mode, each with its own empty cache.
    for simulated_mode in OfflineSimulationMode:
        with offline(mode=simulated_mode), SoftTemporaryDirectory() as cache_dir:
            with self.assertRaises(LocalEntryNotFoundError):
                snapshot_download(self.repo_id, cache_dir=cache_dir)

def test_download_model_local_only_multiple(self):
# cache multiple commits and make sure correct commit is taken
with SoftTemporaryDirectory() as tmpdir:
Expand Down
Loading