Fix local_files_only for checkpoints with shards
hlky committed Dec 18, 2024
1 parent f35a387 commit 233fc77
Showing 1 changed file with 29 additions and 38 deletions.
src/diffusers/utils/hub_utils.py (67 changes: 29 additions & 38 deletions)
@@ -455,48 +455,39 @@ def _get_checkpoint_shard_files(
         allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]

     ignore_patterns = ["*.json", "*.md"]
-    if not local_files_only:
-        # `model_info` call must guarded with the above condition.
-        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
-        for shard_file in original_shard_filenames:
-            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
-            if not shard_file_present:
-                raise EnvironmentError(
-                    f"{shards_path} does not appear to have a file named {shard_file} which is "
-                    "required according to the checkpoint index."
-                )

-        try:
-            # Load from URL
-            cached_folder = snapshot_download(
-                pretrained_model_name_or_path,
-                cache_dir=cache_dir,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                token=token,
-                revision=revision,
-                allow_patterns=allow_patterns,
-                ignore_patterns=ignore_patterns,
-                user_agent=user_agent,
-            )
-            if subfolder is not None:
-                cached_folder = os.path.join(cached_folder, subfolder)
-
-        # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
-        # we don't have to catch them here. We have also dealt with EntryNotFoundError.
-        except HTTPError as e:
+    # `model_info` call must guarded with the above condition.
+    model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+    for shard_file in original_shard_filenames:
+        shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+        if not shard_file_present:
             raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
-                " again after checking your internet connection."
-            ) from e
+                f"{shards_path} does not appear to have a file named {shard_file} which is "
+                "required according to the checkpoint index."
+            )

-    # If `local_files_only=True`, `cached_folder` may not contain all the shard files.
-    elif local_files_only:
-        _check_if_shards_exist_locally(
-            local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
+    try:
+        # Load from URL
+        cached_folder = snapshot_download(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
+            user_agent=user_agent,
         )
-
         if subfolder is not None:
-            cached_folder = os.path.join(cache_dir, subfolder)
+            cached_folder = os.path.join(cached_folder, subfolder)
+
+    # We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
+    # we don't have to catch them here. We have also dealt with EntryNotFoundError.
+    except HTTPError as e:
+        raise EnvironmentError(
+            f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
+            " again after checking your internet connection."
+        ) from e

     return cached_folder, sharded_metadata
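
For context, a minimal sketch of the behaviour this change targets (not part of the commit; the model class and repo id below are illustrative assumptions, and any sharded checkpoint, i.e. one shipping a *.safetensors.index.json plus shard files, exercises the same path). With the shards already in the Hugging Face cache, a later offline load should now be resolved by snapshot_download(..., local_files_only=True) instead of the removed _check_if_shards_exist_locally(local_dir=cache_dir, ...) branch, which looked for the shard files directly under cache_dir rather than inside the cached snapshot folder.

# Illustrative sketch only, not from the commit: the repo id is an assumed
# example of a sharded diffusers checkpoint.
from diffusers import SD3Transformer2DModel

repo_id = "stabilityai/stable-diffusion-3.5-large"

# Online run: downloads the shards listed in the checkpoint index into the HF cache.
model = SD3Transformer2DModel.from_pretrained(repo_id, subfolder="transformer")

# Offline run: with this fix the shards are resolved from the cached snapshot via
# snapshot_download(..., local_files_only=True) rather than the old
# _check_if_shards_exist_locally(local_dir=cache_dir, ...) check.
model = SD3Transformer2DModel.from_pretrained(
    repo_id, subfolder="transformer", local_files_only=True
)

Leaning on snapshot_download for both the online and local-only cases keeps the cache-layout knowledge inside huggingface_hub: cached shards live under models--*/snapshots/<revision>/ rather than directly under cache_dir, which is presumably why the old local-only check could miss shards that were in fact cached.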

