From 27909c8368d7d46b7dece0d369cbbd9eda209154 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Fri, 5 Aug 2022 10:12:40 -0400
Subject: [PATCH] Use new huggingface_hub tools for downloading models (#18438)
* Draft new cached_file
* Initial draft for config and model
* Small fixes
* Fix first batch of tests
* Look in cache when internet is down
* Fix last tests
* Bad black, not fixing all quality errors
* Make diff less
* Implement change for TF and Flax models
* Add tokenizer and feature extractor
* For compatibility with main
* Add utils to move the cache and auto-do it at first use.
* Quality
* Deal with empty commit shas
* Deal with empty etag
* Address review comments
---
src/transformers/configuration_utils.py | 118 ++---
src/transformers/feature_extraction_utils.py | 103 ++---
src/transformers/modeling_flax_utils.py | 139 +++---
src/transformers/modeling_tf_utils.py | 139 +++---
src/transformers/modeling_utils.py | 136 ++----
src/transformers/tokenization_utils_base.py | 94 +---
src/transformers/utils/__init__.py | 2 +
src/transformers/utils/hub.py | 456 +++++++++++++++++--
tests/test_configuration_common.py | 4 +-
tests/test_feature_extraction_common.py | 4 +-
tests/test_modeling_common.py | 4 +-
tests/test_modeling_tf_common.py | 4 +-
tests/test_tokenization_common.py | 4 +-
13 files changed, 662 insertions(+), 545 deletions(-)
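For review purposes, the headline change: every `hf_bucket_url` + `cached_path` pair below collapses into a single `cached_file` call that handles local folders, the cache, and Hub downloads in one place. A minimal sketch of the new entry point, using `bert-base-uncased` purely for illustration:

```python
from transformers.utils import cached_file

# Old flow (removed in this patch):
#   url = hf_bucket_url("bert-base-uncased", filename="config.json")
#   resolved = cached_path(url)
# New flow, one call:
resolved = cached_file("bert-base-uncased", "config.json")  # returns a local file path
```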
diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py
index fe2d6b3aaef637..b10475127b4fce 100755
--- a/src/transformers/configuration_utils.py
+++ b/src/transformers/configuration_utils.py
@@ -25,25 +25,9 @@
from packaging import version
-from requests import HTTPError
-
from . import __version__
from .dynamic_module_utils import custom_object_save
-from .utils import (
- CONFIG_NAME,
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
- EntryNotFoundError,
- PushToHubMixin,
- RepositoryNotFoundError,
- RevisionNotFoundError,
- cached_path,
- copy_func,
- hf_bucket_url,
- is_offline_mode,
- is_remote_url,
- is_torch_available,
- logging,
-)
+from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
logger = logging.get_logger(__name__)
@@ -591,77 +575,43 @@ def _get_config_dict(
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
-
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
- pretrained_model_name_or_path
- ):
- config_file = pretrained_model_name_or_path
+
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+ # Special case when pretrained_model_name_or_path is a local file
+ resolved_config_file = pretrained_model_name_or_path
+ is_local = True
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
- if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
- config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
- else:
- config_file = hf_bucket_url(
+ try:
+ # Load from a local folder, from the cache, or download from the model Hub and cache
+ resolved_config_file = cached_file(
pretrained_model_name_or_path,
- filename=configuration_file,
+ configuration_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
revision=revision,
- subfolder=subfolder if len(subfolder) > 0 else None,
- mirror=None,
+ subfolder=subfolder,
+ )
+ except EnvironmentError:
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise EnvironmentError(
+ f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
+ f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
+ f" containing a {configuration_file} file"
)
-
- try:
- # Load from URL or cache if already cached
- resolved_config_file = cached_path(
- config_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- )
-
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
- "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
- "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
- "`use_auth_token=True`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
- f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
- "available revisions."
- )
- except EntryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
- )
- except HTTPError as err:
- raise EnvironmentError(
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
- )
- except ValueError:
- raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
- f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
- f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
- " library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
- )
- except EnvironmentError:
- raise EnvironmentError(
- f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
- f"containing a {configuration_file} file"
- )
try:
# Load config dict
@@ -671,10 +621,10 @@ def _get_config_dict(
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
)
- if resolved_config_file == config_file:
- logger.info(f"loading configuration file {config_file}")
+ if is_local:
+ logger.info(f"loading configuration file {resolved_config_file}")
else:
- logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
+ logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
return config_dict, kwargs
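The rewrite above leans on `cached_file` raising `EnvironmentError` with a message already adapted to the failure mode (missing repo, bad revision, missing file, connection problem), so `_get_config_dict` re-raises it as-is and only wraps unexpected exceptions generically. A hedged sketch from the caller's side, with a hypothetical repo id:

```python
from transformers.utils import cached_file

try:
    resolved_config_file = cached_file("some-user/some-model", "config.json")  # hypothetical repo id
except EnvironmentError as err:
    print(err)  # message pre-adapted by cached_file, e.g. "... is not a valid model identifier ..."
```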
diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index b411d744284665..ec68f355191c1d 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -24,23 +24,15 @@
import numpy as np
-from requests import HTTPError
-
from .dynamic_module_utils import custom_object_save
from .utils import (
FEATURE_EXTRACTOR_NAME,
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
- EntryNotFoundError,
PushToHubMixin,
- RepositoryNotFoundError,
- RevisionNotFoundError,
TensorType,
- cached_path,
+ cached_file,
copy_func,
- hf_bucket_url,
is_flax_available,
is_offline_mode,
- is_remote_url,
is_tf_available,
is_torch_available,
logging,
@@ -388,64 +380,40 @@ def get_feature_extractor_dict(
local_files_only = True
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
- elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
- feature_extractor_file = pretrained_model_name_or_path
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_feature_extractor_file = pretrained_model_name_or_path
+ is_local = True
else:
- feature_extractor_file = hf_bucket_url(
- pretrained_model_name_or_path, filename=FEATURE_EXTRACTOR_NAME, revision=revision, mirror=None
- )
-
- try:
- # Load from URL or cache if already cached
- resolved_feature_extractor_file = cached_path(
- feature_extractor_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- )
-
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
- "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
- "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
- "`use_auth_token=True`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
- f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
- "available revisions."
- )
- except EntryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named {FEATURE_EXTRACTOR_NAME}."
- )
- except HTTPError as err:
- raise EnvironmentError(
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
- )
- except ValueError:
- raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
- f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
- f" containing a {FEATURE_EXTRACTOR_NAME} file.\nCheckout your internet connection or see how to run"
- " the library in offline mode at"
- " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
- )
- except EnvironmentError:
- raise EnvironmentError(
- f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load it "
- "from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
- f"containing a {FEATURE_EXTRACTOR_NAME} file"
- )
+ feature_extractor_file = FEATURE_EXTRACTOR_NAME
+ try:
+ # Load from a local folder, from the cache, or download from the model Hub and cache
+ resolved_feature_extractor_file = cached_file(
+ pretrained_model_name_or_path,
+ feature_extractor_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ revision=revision,
+ )
+ except EnvironmentError:
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
+ # the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise EnvironmentError(
+ f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
+ )
try:
# Load feature_extractor dict
@@ -458,12 +426,11 @@ def get_feature_extractor_dict(
f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
)
- if resolved_feature_extractor_file == feature_extractor_file:
- logger.info(f"loading feature extractor configuration file {feature_extractor_file}")
+ if is_local:
+ logger.info(f"loading configuration file {resolved_feature_extractor_file}")
else:
logger.info(
- f"loading feature extractor configuration file {feature_extractor_file} from cache at"
- f" {resolved_feature_extractor_file}"
+ f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
)
return feature_extractor_dict, kwargs
diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py
index 0dcb3bc959e83d..af75b418cad23e 100644
--- a/src/transformers/modeling_flax_utils.py
+++ b/src/transformers/modeling_flax_utils.py
@@ -32,7 +32,6 @@
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey
-from requests import HTTPError
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
@@ -41,20 +40,14 @@
from .utils import (
FLAX_WEIGHTS_INDEX_NAME,
FLAX_WEIGHTS_NAME,
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
- EntryNotFoundError,
PushToHubMixin,
- RepositoryNotFoundError,
- RevisionNotFoundError,
add_code_sample_docstrings,
add_start_docstrings_to_model_forward,
- cached_path,
+ cached_file,
copy_func,
has_file,
- hf_bucket_url,
is_offline_mode,
- is_remote_url,
logging,
replace_return_docstrings,
)
@@ -557,6 +550,9 @@ def from_pretrained(
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -598,6 +594,7 @@ def from_pretrained(
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True)
+ subfolder = kwargs.pop("subfolder", "")
if trust_remote_code is True:
logger.warning(
@@ -642,6 +639,8 @@ def from_pretrained(
# Load model
if pretrained_model_name_or_path is not None:
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint
@@ -665,65 +664,44 @@ def from_pretrained(
f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
)
- elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+ elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
+ is_local = True
else:
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
- archive_file = hf_bucket_url(
- pretrained_model_name_or_path,
- filename=filename,
- revision=revision,
- )
-
- # redirect to the cache, if necessary
+ try:
+ # Load from URL or cache if already cached
+ cached_file_kwargs = dict(
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
- try:
- resolved_archive_file = cached_path(
- archive_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- )
-
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
- "login` and pass `use_auth_token=True`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
- "this model name. Check the model page at "
- f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
- )
- except EntryNotFoundError:
- if filename == FLAX_WEIGHTS_NAME:
- try:
+ # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+ # result when the internet is up, the repo and revision exist, but the file does not.
+ if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
- archive_file = hf_bucket_url(
- pretrained_model_name_or_path,
- filename=FLAX_WEIGHTS_INDEX_NAME,
- revision=revision,
- )
- resolved_archive_file = cached_path(
- archive_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
- is_sharded = True
- except EntryNotFoundError:
- has_file_kwargs = {"revision": revision, "proxies": proxies, "use_auth_token": use_auth_token}
+ if resolved_archive_file is not None:
+ is_sharded = True
+ if resolved_archive_file is None:
+ # Otherwise, maybe there is a PyTorch model file. We try it to give a helpful error
+ # message.
+ has_file_kwargs = {
+ "revision": revision,
+ "proxies": proxies,
+ "use_auth_token": use_auth_token,
+ }
if has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
@@ -735,35 +713,24 @@ def from_pretrained(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
- else:
+ except EnvironmentError:
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
+ # to the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
- except HTTPError as err:
- raise EnvironmentError(
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
- f"{err}"
- )
- except ValueError:
- raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
- f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your"
- " internet connection or see how to run the library in offline mode at"
- " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
- )
- except EnvironmentError:
- raise EnvironmentError(
- f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
- f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
- )
-
- if resolved_archive_file == archive_file:
+
+ if is_local:
logger.info(f"loading weights file {archive_file}")
+ resolved_archive_file = archive_file
else:
- logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
+ logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
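The sharded-checkpoint fallback above hinges on the new private flag: with `_raise_exceptions_for_missing_entries=False`, `cached_file` returns `None` instead of raising when the repo and revision exist but the file does not. A sketch of the pattern, with a hypothetical repo id:

```python
from transformers.utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, cached_file

repo_id = "some-user/sharded-flax-model"  # hypothetical
resolved = cached_file(repo_id, FLAX_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
if resolved is None:
    # Repo exists but the single weights file doesn't: maybe the checkpoint is sharded.
    resolved = cached_file(repo_id, FLAX_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
```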
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index e1d8f5b7957be5..1a63d32e4196a0 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -37,7 +37,6 @@
from huggingface_hub import Repository, list_repo_files
from keras.saving.hdf5_format import save_attributes_to_hdf5_group
-from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from . import DataCollatorWithPadding, DefaultDataCollator
@@ -48,22 +47,16 @@
from .tf_utils import shape_list
from .utils import (
DUMMY_INPUTS,
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
- EntryNotFoundError,
ModelOutput,
PushToHubMixin,
- RepositoryNotFoundError,
- RevisionNotFoundError,
- cached_path,
+ cached_file,
find_labels,
has_file,
- hf_bucket_url,
is_offline_mode,
- is_remote_url,
logging,
requires_backends,
working_or_temp_dir,
@@ -2112,6 +2105,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -2164,6 +2160,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
+ subfolder = kwargs.pop("subfolder", "")
if trust_remote_code is True:
logger.warning(
@@ -2202,9 +2199,10 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
- sharded_metadata = None
# Load model
if pretrained_model_name_or_path is not None:
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint in priority if from_pt
@@ -2232,68 +2230,43 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
f"Error no file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
f"{pretrained_model_name_or_path}."
)
- elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+ elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
+ is_local = True
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
+ is_local = True
else:
+ # set correct filename
filename = WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME
- archive_file = hf_bucket_url(
- pretrained_model_name_or_path,
- filename=filename,
- revision=revision,
- mirror=mirror,
- )
- try:
- # Load from URL or cache if already cached
- resolved_archive_file = cached_path(
- archive_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- )
+ try:
+ # Load from URL or cache if already cached
+ cached_file_kwargs = dict(
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
- "login` and pass `use_auth_token=True`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
- "this model name. Check the model page at "
- f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
- )
- except EntryNotFoundError:
- if filename == TF2_WEIGHTS_NAME:
- try:
+ # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+ # result when the internet is up, the repo and revision exist, but the file does not.
+ if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
- archive_file = hf_bucket_url(
- pretrained_model_name_or_path,
- filename=TF2_WEIGHTS_INDEX_NAME,
- revision=revision,
- mirror=mirror,
- )
- resolved_archive_file = cached_path(
- archive_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
- is_sharded = True
- except EntryNotFoundError:
- # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
+ if resolved_archive_file is not None:
+ is_sharded = True
+ if resolved_archive_file is None:
+ # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
"revision": revision,
@@ -2312,42 +2285,32 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
- else:
+
+ except EnvironmentError:
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
+ # to the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
)
- except HTTPError as err:
- raise EnvironmentError(
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
- f"{err}"
- )
- except ValueError:
- raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
- f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}.\nCheckout your internet"
- " connection or see how to run the library in offline mode at"
- " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
- )
- except EnvironmentError:
- raise EnvironmentError(
- f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
- f"containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
- )
-
- if resolved_archive_file == archive_file:
+ if is_local:
logger.info(f"loading weights file {archive_file}")
+ resolved_archive_file = archive_file
else:
- logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
+ logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
- resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+ resolved_archive_file, _ = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
cache_dir=cache_dir,
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 5cd458d1f9d587..8709ec66365c66 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -31,7 +31,6 @@
from torch import Tensor, device, nn
from torch.nn import CrossEntropyLoss
-from requests import HTTPError
from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from transformers.utils.import_utils import is_sagemaker_mp_enabled
@@ -51,24 +50,18 @@
from .utils import (
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
- HUGGINGFACE_CO_RESOLVE_ENDPOINT,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
ContextManagers,
- EntryNotFoundError,
ModelOutput,
PushToHubMixin,
- RepositoryNotFoundError,
- RevisionNotFoundError,
- cached_path,
+ cached_file,
copy_func,
has_file,
- hf_bucket_url,
is_accelerate_available,
is_offline_mode,
- is_remote_url,
logging,
replace_return_docstrings,
)
@@ -1868,7 +1861,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- if os.path.isdir(pretrained_model_name_or_path):
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if is_local:
if from_tf and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
):
@@ -1911,10 +1905,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
)
- elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
- pretrained_model_name_or_path
- ):
+ elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
archive_file = pretrained_model_name_or_path
+ is_local = True
elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
if not from_tf:
raise ValueError(
@@ -1922,6 +1915,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"from_tf to True to load from this checkpoint."
)
archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
+ is_local = True
else:
# set correct filename
if from_tf:
@@ -1931,63 +1925,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
filename = WEIGHTS_NAME
- archive_file = hf_bucket_url(
- pretrained_model_name_or_path,
- filename=filename,
- revision=revision,
- mirror=mirror,
- subfolder=subfolder if len(subfolder) > 0 else None,
- )
-
- try:
- # Load from URL or cache if already cached
- resolved_archive_file = cached_path(
- archive_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- )
+ try:
+ # Load from URL or cache if already cached
+ cached_file_kwargs = dict(
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
- "login` and pass `use_auth_token=True`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
- "this model name. Check the model page at "
- f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
- )
- except EntryNotFoundError:
- if filename == WEIGHTS_NAME:
- try:
+ # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+ # result when the internet is up, the repo and revision exist, but the file does not.
+ if resolved_archive_file is None and filename == WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
- archive_file = hf_bucket_url(
- pretrained_model_name_or_path,
- filename=WEIGHTS_INDEX_NAME,
- revision=revision,
- mirror=mirror,
- subfolder=subfolder if len(subfolder) > 0 else None,
- )
- resolved_archive_file = cached_path(
- archive_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
+ resolved_archive_file = cached_file(
+ pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
- is_sharded = True
- except EntryNotFoundError:
+ if resolved_archive_file is not None:
+ is_sharded = True
+ if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.
has_file_kwargs = {
@@ -2013,42 +1976,31 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
)
- else:
+ except EnvironmentError:
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
+ # to the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named {filename}."
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
+ f" {FLAX_WEIGHTS_NAME}."
)
- except HTTPError as err:
- raise EnvironmentError(
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
- f"{err}"
- )
- except ValueError:
- raise EnvironmentError(
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
- f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or"
- f" {FLAX_WEIGHTS_NAME}.\nCheckout your internet connection or see how to run the library in"
- " offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
- )
- except EnvironmentError:
- raise EnvironmentError(
- f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
- f"containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or "
- f"{FLAX_WEIGHTS_NAME}."
- )
- if resolved_archive_file == archive_file:
+ if is_local:
logger.info(f"loading weights file {archive_file}")
+ resolved_archive_file = archive_file
else:
- logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
+ logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
else:
resolved_archive_file = None
# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
pretrained_model_name_or_path,
resolved_archive_file,
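When both lookups come back `None`, the PyTorch/TF/Flax loaders probe the repo with `has_file` so the error can point at weights from another framework. Roughly:

```python
from transformers.utils import has_file

# bert-base-uncased ships TF weights; a probe like this is how the loader decides
# to suggest from_tf=True (or the TF model class) in its error message.
print(has_file("bert-base-uncased", "tf_model.h5"))  # True
```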
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 8d24baf05bdb86..fc1c0ff8da3b32 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -35,21 +35,16 @@
from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import (
- EntryNotFoundError,
ExplicitEnum,
PaddingStrategy,
PushToHubMixin,
- RepositoryNotFoundError,
- RevisionNotFoundError,
TensorType,
add_end_docstrings,
- cached_path,
+ cached_file,
copy_func,
get_file_from_repo,
- hf_bucket_url,
is_flax_available,
is_offline_mode,
- is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
@@ -1669,7 +1664,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
vocab_files = {}
init_configuration = {}
- if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isfile(pretrained_model_name_or_path):
if len(cls.vocab_files_names) > 1:
raise ValueError(
f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
@@ -1689,9 +1685,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
}
- vocab_files_target = {**cls.vocab_files_names, **additional_files_names}
+ vocab_files = {**cls.vocab_files_names, **additional_files_names}
- if "tokenizer_file" in vocab_files_target:
+ if "tokenizer_file" in vocab_files:
# Try to get the tokenizer config to see if there are versioned tokenizer files.
fast_tokenizer_file = FULL_TOKENIZER_FILE
resolved_config_file = get_file_from_repo(
@@ -1704,80 +1700,38 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
+ subfolder=subfolder,
)
if resolved_config_file is not None:
with open(resolved_config_file, encoding="utf-8") as reader:
tokenizer_config = json.load(reader)
if "fast_tokenizer_files" in tokenizer_config:
fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
- vocab_files_target["tokenizer_file"] = fast_tokenizer_file
-
- # Look for the tokenizer files
- for file_id, file_name in vocab_files_target.items():
- if os.path.isdir(pretrained_model_name_or_path):
- if subfolder is not None:
- full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name)
- else:
- full_file_name = os.path.join(pretrained_model_name_or_path, file_name)
- if not os.path.exists(full_file_name):
- logger.info(f"Didn't find file {full_file_name}. We won't load it.")
- full_file_name = None
- else:
- full_file_name = hf_bucket_url(
- pretrained_model_name_or_path,
- filename=file_name,
- subfolder=subfolder,
- revision=revision,
- mirror=None,
- )
-
- vocab_files[file_id] = full_file_name
+ vocab_files["tokenizer_file"] = fast_tokenizer_file
# Get files from url, cache, or disk depending on the case
resolved_vocab_files = {}
unresolved_files = []
for file_id, file_path in vocab_files.items():
if file_path is None:
resolved_vocab_files[file_id] = None
else:
- try:
- resolved_vocab_files[file_id] = cached_path(
- file_path,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- user_agent=user_agent,
- )
-
- except FileNotFoundError as error:
- if local_files_only:
- unresolved_files.append(file_id)
- else:
- raise error
-
- except RepositoryNotFoundError:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
- "pass a token having permission to this repo with `use_auth_token` or log in with "
- "`huggingface-cli login` and pass `use_auth_token=True`."
- )
- except RevisionNotFoundError:
- raise EnvironmentError(
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
- "for this model name. Check the model page at "
- f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
- )
- except EntryNotFoundError:
- logger.debug(f"{pretrained_model_name_or_path} does not contain a file named {file_path}.")
- resolved_vocab_files[file_id] = None
-
- except ValueError:
- logger.debug(f"Connection problem to access {file_path} and it wasn't found in the cache.")
- resolved_vocab_files[file_id] = None
+ resolved_vocab_files[file_id] = cached_file(
+ pretrained_model_name_or_path,
+ file_path,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
if len(unresolved_files) > 0:
logger.info(
@@ -1797,7 +1751,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
if file_id not in resolved_vocab_files:
continue
- if file_path == resolved_vocab_files[file_id]:
+ if is_local:
logger.info(f"loading file {file_path}")
else:
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 377932e2d490e7..023dffc27a703b 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -60,6 +60,7 @@
PushToHubMixin,
RepositoryNotFoundError,
RevisionNotFoundError,
+ cached_file,
cached_path,
default_cache_path,
define_sagemaker_information,
@@ -76,6 +77,7 @@
is_local_clone,
is_offline_mode,
is_remote_url,
+ move_cache,
send_example_telemetry,
url_to_filename,
)
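These two exports are the user-visible surface of the patch; once it lands they should be importable directly:

```python
from transformers.utils import cached_file, move_cache
```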
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index 1fd22d7a7cb70f..9e81654cda7e1a 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -19,11 +19,13 @@
import io
import json
import os
+import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
+import traceback
import warnings
from contextlib import contextmanager
from functools import partial
@@ -34,9 +36,20 @@
from uuid import uuid4
from zipfile import ZipFile, is_zipfile
+import huggingface_hub
import requests
from filelock import FileLock
-from huggingface_hub import CommitOperationAdd, HfFolder, create_commit, create_repo, list_repo_files, whoami
+from huggingface_hub import (
+ CommitOperationAdd,
+ HfFolder,
+ create_commit,
+ create_repo,
+ hf_hub_download,
+ list_repo_files,
+ whoami,
+)
+from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests.exceptions import HTTPError
from requests.models import Response
from transformers.utils.logging import tqdm
@@ -385,21 +398,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
-class RepositoryNotFoundError(HTTPError):
- """
- Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
- not have access to.
- """
-
-
-class EntryNotFoundError(HTTPError):
- """Raised when trying to access a hf.co URL with a valid repository and revision but an invalid filename."""
-
-
-class RevisionNotFoundError(HTTPError):
- """Raised when trying to access a hf.co URL with a valid repository but an invalid revision."""
-
-
def _raise_for_status(response: Response):
"""
Internal version of `request.raise_for_status()` that will refine a potential HTTPError.
@@ -628,8 +626,58 @@ def _resumable_file_manager() -> "io.BufferedWriter":
return cache_path
-def get_file_from_repo(
- path_or_repo: Union[str, os.PathLike],
+def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
+ """
+ Explores the cache to return the latest cached file for a given revision.
+ """
+ if revision is None:
+ revision = "main"
+
+ model_id = repo_id.replace("/", "--")
+ model_cache = os.path.join(cache_dir, f"models--{model_id}")
+ if not os.path.isdir(model_cache):
+ # No cache for this model
+ return None
+
+ # Resolve refs (for instance to convert main to the associated commit sha)
+ cached_refs = os.listdir(os.path.join(model_cache, "refs"))
+ if revision in cached_refs:
+ with open(os.path.join(model_cache, "refs", revision)) as f:
+ revision = f.read()
+
+ cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
+ if revision not in cached_shas:
+ # No cache for this revision and we won't try to return a random revision
+ return None
+
+ cached_file = os.path.join(model_cache, "snapshots", revision, filename)
+ return cached_file if os.path.isfile(cached_file) else None
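A hypothetical use of the helper above: it only inspects the new-style layout (`models--<org>--<name>` with `refs/` mapping symbolic revisions to commit shas and `snapshots/<sha>/` holding the files), never the network:

```python
from transformers.utils.hub import TRANSFORMERS_CACHE, try_to_load_from_cache

path = try_to_load_from_cache(TRANSFORMERS_CACHE, "bert-base-uncased", "config.json")
print(path)  # snapshot path if previously downloaded at revision "main", else None
```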
+
+
+# If huggingface_hub changes the class of error for this to FileNotFoundError, we will be able to avoid that in the
+# future.
+LOCAL_FILES_ONLY_HF_ERROR = (
+ "Cannot find the requested files in the disk cache and outgoing traffic has been disabled. To enable hf.co "
+ "look-ups and downloads online, set 'local_files_only' to False."
+)
+
+
+ # In the future, this ugly contextmanager can be removed when huggingface_hub has a released version where we can
+# activate/deactivate progress bars.
+@contextmanager
+def _patch_hf_hub_tqdm():
+ """
+ A context manager to make huggingface hub use the tqdm version of Transformers (which is controlled by some utils)
+ in logging.
+ """
+ old_tqdm = huggingface_hub.file_download.tqdm
+ huggingface_hub.file_download.tqdm = tqdm
+ yield
+ huggingface_hub.file_download.tqdm = old_tqdm
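Intended usage of the private helper above, as a sketch: any `hf_hub_download` inside the `with` block reports progress through Transformers' own `tqdm`, and the original is swapped back on exit. Note the swap is not wrapped in `try`/`finally`, so an exception during download would leave the patch in place; presumably acceptable for a transitional shim.

```python
from huggingface_hub import hf_hub_download
from transformers.utils.hub import _patch_hf_hub_tqdm

with _patch_hf_hub_tqdm():
    path = hf_hub_download("bert-base-uncased", "config.json")
```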
+
+
+def cached_file(
+ path_or_repo_id: Union[str, os.PathLike],
filename: str,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
@@ -638,12 +686,16 @@ def get_file_from_repo(
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
+ subfolder: str = "",
+ user_agent: Optional[Union[str, Dict[str, str]]] = None,
+ _raise_exceptions_for_missing_entries=True,
+ _raise_exceptions_for_connection_errors=True,
):
"""
Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
Args:
- path_or_repo (`str` or `os.PathLike`):
+ path_or_repo_id (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a model repo on huggingface.co.
@@ -670,6 +722,9 @@ def get_file_from_repo(
identifier allowed by git.
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
@@ -678,43 +733,56 @@ def get_file_from_repo(
Returns:
- `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
- file does not exist.
+ `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo).
Examples:
```python
- # Download a tokenizer configuration from huggingface.co and cache.
- tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
- # This model does not have a tokenizer config so the result will be None.
- tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
+ # Download a model weight from the Hub and cache it.
+ model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
```"""
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
+ if subfolder is None:
+ subfolder = ""
+
+ path_or_repo_id = str(path_or_repo_id)
+ full_filename = os.path.join(subfolder, filename)
+ if os.path.isdir(path_or_repo_id):
+ resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
+ if not os.path.isfile(resolved_file):
+ if _raise_exceptions_for_missing_entries:
+ raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
+ else:
+ return None
+ return resolved_file
- path_or_repo = str(path_or_repo)
- if os.path.isdir(path_or_repo):
- resolved_file = os.path.join(path_or_repo, filename)
- return resolved_file if os.path.isfile(resolved_file) else None
- else:
- resolved_file = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=None)
-
+ if cache_dir is None:
+ cache_dir = TRANSFORMERS_CACHE
+ if isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+ user_agent = http_user_agent(user_agent)
try:
# Load from URL or cache if already cached
- resolved_file = cached_path(
- resolved_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- use_auth_token=use_auth_token,
- )
+ with _patch_hf_hub_tqdm():
+ resolved_file = hf_hub_download(
+ path_or_repo_id,
+ filename,
+ subfolder=None if len(subfolder) == 0 else subfolder,
+ revision=revision,
+ cache_dir=cache_dir,
+ user_agent=user_agent,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ use_auth_token=use_auth_token,
+ local_files_only=local_files_only,
+ )
except RepositoryNotFoundError:
raise EnvironmentError(
- f"{path_or_repo} is not a local folder and is not a valid model identifier "
+ f"{path_or_repo_id} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
"pass a token having permission to this repo with `use_auth_token` or log in with "
"`huggingface-cli login` and pass `use_auth_token=True`."
@@ -723,15 +791,129 @@ def get_file_from_repo(
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
"for this model name. Check the model page at "
- f"'https://huggingface.co/{path_or_repo}' for available revisions."
+ f"'https://huggingface.co/{path_or_repo_id}' for available revisions."
+ )
+ except EntryNotFoundError:
+ if not _raise_exceptions_for_missing_entries:
+ return None
+ if revision is None:
+ revision = "main"
+ raise EnvironmentError(
+ f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
+ f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
+ )
+ except HTTPError as err:
+ # First we try to see if we have a cached version (not up to date):
+ resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
+ if resolved_file is not None:
+ return resolved_file
+ if not _raise_exceptions_for_connection_errors:
+ return None
+
+ raise EnvironmentError(f"There was a specific connection error when trying to load {path_or_repo_id}:\n{err}")
+ except ValueError as err:
+ # HuggingFace Hub returns a ValueError for a missing file when local_files_only=True, so we need to catch it here.
+ # This could be caught above in the `EntryNotFoundError` branch if hf_hub sent a different error message here.
+ if LOCAL_FILES_ONLY_HF_ERROR in err.args[0] and local_files_only and not _raise_exceptions_for_missing_entries:
+ return None
+
+ # Otherwise we try to see if we have a cached version (not up to date):
+ resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, revision=revision)
+ if resolved_file is not None:
+ return resolved_file
+ if not _raise_exceptions_for_connection_errors:
+ return None
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this file, couldn't find it in the"
+ f" cached files and it looks like {path_or_repo_id} is not the path to a directory containing a file named"
+ f" {full_filename}.\nCheckout your internet connection or see how to run the library in offline mode at"
+ " 'https://huggingface.co/docs/transformers/installation#offline-mode'."
)
- except EnvironmentError:
- # The repo and revision exist, but the file does not or there was a connection error fetching it.
- return None
return resolved_file
+def get_file_from_repo(
+ path_or_repo: Union[str, os.PathLike],
+ filename: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ subfolder: str = "",
+):
+ """
+ Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
+
+ Args:
+ path_or_repo (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a model repo on huggingface.co.
+ - a path to a *directory* potentially containing the file.
+ filename (`str`):
+ The name of the file to locate in `path_or_repo`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. Since models and other artifacts on huggingface.co are stored with a
+ git-based system, `revision` can be any identifier allowed by git: a branch name, a tag name, or a
+ commit id.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the file from local files (no download attempt).
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+
+
+ <Tip>
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+ </Tip>
+
+ Returns:
+ `Optional[str]`: Returns the path to the resolved file (in the cache folder if it was downloaded from a
+ repo), or `None` if the file does not exist.
+
+ Examples:
+
+ ```python
+ # Download a tokenizer configuration from huggingface.co and cache.
+ tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
+ # This model does not have a tokenizer config, so the result will be None.
+ tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
+ ```"""
+ return cached_file(
+ path_or_repo_id=path_or_repo,
+ filename=filename,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
+
+
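In other words, `get_file_from_repo` is `cached_file` with both exception escapes disabled, so callers can probe for optional files with a plain `None` check rather than try/except:

```python
from transformers.utils import get_file_from_repo

tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
if tokenizer_config is not None:
    with open(tokenizer_config, encoding="utf-8") as f:
        ...  # parse the resolved JSON file
```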
def has_file(
path_or_repo: Union[str, os.PathLike],
filename: str,
@@ -766,7 +948,7 @@ def has_file(
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
try:
- _raise_for_status(r)
+ huggingface_hub.utils._errors._raise_for_status(r)
return True
except RepositoryNotFoundError as e:
logger.error(e)
@@ -1196,3 +1378,183 @@ def get_checkpoint_shard_files(
cached_filenames.append(cached_filename)
return cached_filenames, sharded_metadata
+
+
+# Everything below handles conversion from the old cache format to the new cache format.
+
+
+def get_all_cached_files(cache_dir=None):
+ """
+ Returns a list of all cached files, along with their metadata (file name, url and etag).
+ """
+ if cache_dir is None:
+ cache_dir = TRANSFORMERS_CACHE
+ else:
+ cache_dir = str(cache_dir)
+
+ cached_files = []
+ for file in os.listdir(cache_dir):
+ meta_path = os.path.join(cache_dir, f"{file}.json")
+ if not os.path.isfile(meta_path):
+ continue
+
+ with open(meta_path, encoding="utf-8") as meta_file:
+ metadata = json.load(meta_file)
+ url = metadata["url"]
+ etag = metadata["etag"].replace('"', "")
+ cached_files.append({"file": file, "url": url, "etag": etag})
+
+ return cached_files
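This walks the old flat cache layout, in which every downloaded blob sits next to a `.json` sidecar recording where it came from. A sketch of that layout with an illustrative (hypothetical) blob name:

```python
import json
import os
import tempfile

from transformers.utils.hub import get_all_cached_files

# Fake one old-style cache entry: an opaque blob plus its metadata sidecar.
cache_dir = tempfile.mkdtemp()
blob = os.path.join(cache_dir, "abcd1234.efgh5678")  # hypothetical hashed name
with open(blob, "w") as f:
    f.write("{}")
with open(f"{blob}.json", "w") as f:
    json.dump(
        {"url": "https://huggingface.co/bert-base-uncased/resolve/main/config.json", "etag": '"some-etag"'}, f
    )

print(get_all_cached_files(cache_dir))
# [{'file': 'abcd1234.efgh5678', 'url': 'https://...', 'etag': 'some-etag'}]
```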
+
+
+def get_hub_metadata(url, token=None):
+ """
+ Returns the etag and commit hash for a given url.
+ """
+ if token is None:
+ token = HfFolder.get_token()
+ headers = {"user-agent": http_user_agent()}
+ headers["authorization"] = f"Bearer {token}"
+
+ r = huggingface_hub.file_download._request_with_retry(
+ method="HEAD", url=url, headers=headers, allow_redirects=False
+ )
+ huggingface_hub.file_download._raise_for_status(r)
+ commit_hash = r.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT)
+ etag = r.headers.get(HUGGINGFACE_HEADER_X_LINKED_ETAG) or r.headers.get("ETag")
+ if etag is not None:
+ etag = huggingface_hub.file_download._normalize_etag(etag)
+ return etag, commit_hash
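The helper leans on private `huggingface_hub` internals; a rough `requests`-only equivalent (assuming the `X-Repo-Commit` and `X-Linked-Etag` headers the Hub sets on `resolve` URLs) looks like:

```python
import requests

url = "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
r = requests.head(url, allow_redirects=False, timeout=10)

commit_hash = r.headers.get("X-Repo-Commit")
# LFS files report their real etag in X-Linked-Etag on the redirect response.
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
print(etag, commit_hash)
```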
+
+
+def extract_info_from_url(url):
+ """
+ Extract the repo name, revision and filename from a url.
+ """
+ search = re.search(r"^https://huggingface\.co/(.*)/resolve/([^/]*)/(.*)$", url)
+ if search is None:
+ return None
+ repo, revision, filename = search.groups()
+ cache_repo = "--".join(["models"] + repo.split("/"))
+ return {"repo": cache_repo, "revision": revision, "filename": filename}
+
+
+def clean_files_for(file):
+ """
+ Remove `file`, `file.json` and `file.lock` if they exist.
+ """
+ for f in [file, f"{file}.json", f"{file}.lock"]:
+ if os.path.isfile(f):
+ os.remove(f)
+
+
+def move_to_new_cache(file, repo, filename, revision, etag, commit_hash):
+ """
+ Move `file` into `repo`, following the new huggingface_hub cache organization.
+ """
+ os.makedirs(repo, exist_ok=True)
+
+ # refs
+ os.makedirs(os.path.join(repo, "refs"), exist_ok=True)
+ if revision != commit_hash:
+ ref_path = os.path.join(repo, "refs", revision)
+ with open(ref_path, "w") as f:
+ f.write(commit_hash)
+
+ # blobs
+ os.makedirs(os.path.join(repo, "blobs"), exist_ok=True)
+ # The old cached file becomes a blob named after its etag.
+ blob_path = os.path.join(repo, "blobs", etag)
+ shutil.move(file, blob_path)
+
+ # snapshots
+ os.makedirs(os.path.join(repo, "snapshots"), exist_ok=True)
+ os.makedirs(os.path.join(repo, "snapshots", commit_hash), exist_ok=True)
+ pointer_path = os.path.join(repo, "snapshots", commit_hash, filename)
+ huggingface_hub.file_download._create_relative_symlink(blob_path, pointer_path)
+ clean_files_for(file)
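The end result for a single migrated file is the layout `huggingface_hub` expects, shown here schematically for a `config.json` cached at revision `main`:

```python
# <cache_dir>/models--<namespace>--<name>/
# ├── refs/
# │   └── main                  # contains <commit_hash>; written only when revision != commit_hash
# ├── blobs/
# │   └── <etag>                # the file content, moved from the old flat cache
# └── snapshots/
#     └── <commit_hash>/
#         └── config.json       # relative symlink to ../../blobs/<etag>
```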
+
+
+def move_cache(cache_dir=None, token=None):
+ if cache_dir is None:
+ cache_dir = TRANSFORMERS_CACHE
+ if token is None:
+ token = HfFolder.get_token()
+ cached_files = get_all_cached_files(cache_dir=cache_dir)
+ print(f"Moving {len(cached_files)} files to the new cache system")
+
+ hub_metadata = {}
+ for file_info in tqdm(cached_files):
+ url = file_info.pop("url")
+ if url not in hub_metadata:
+ try:
+ hub_metadata[url] = get_hub_metadata(url, token=token)
+ except requests.HTTPError:
+ continue
+
+ etag, commit_hash = hub_metadata[url]
+ if etag is None or commit_hash is None:
+ continue
+
+ if file_info["etag"] != etag:
+ # Cached file is not up to date, we just discard it, as a new version will be downloaded anyway.
+ clean_files_for(os.path.join(cache_dir, file_info["file"]))
+ continue
+
+ url_info = extract_info_from_url(url)
+ if url_info is None:
+ # Not a file from huggingface.co
+ continue
+
+ repo = os.path.join(cache_dir, url_info["repo"])
+ move_to_new_cache(
+ file=os.path.join(cache_dir, file_info["file"]),
+ repo=repo,
+ filename=url_info["filename"],
+ revision=url_info["revision"],
+ etag=etag,
+ commit_hash=commit_hash,
+ )
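`move_cache` can also be called by hand, e.g. to migrate a cache living in a non-default location (the path below is illustrative):

```python
from transformers.utils.hub import move_cache

# The token is only needed to resolve metadata for files downloaded from
# private repos; None falls back to the token stored by `transformers-cli login`.
move_cache(cache_dir="/mnt/models/transformers_cache", token=None)
```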
+
+
+cache_version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
+if not os.path.isfile(cache_version_file):
+ cache_version = 0
+else:
+ with open(cache_version_file) as f:
+ cache_version = int(f.read())
+
+
+if cache_version < 1:
+ if is_offline_mode():
+ logger.warning(
+ "You are offline and the cache for model files in Transformers v4.22.0 has been updated while your local "
+ "cache seems to be that of a previous version. It is very likely that all your calls to any "
+ "`from_pretrained()` method will fail. Disable offline mode and enable an internet connection to have "
+ "your cache updated automatically, then you can go back to offline mode."
+ )
+ else:
+ logger.warning(
+ "The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a "
+ "one-time only operation. You can interrupt this and resume the migration later on by calling "
+ "`transformers.utils.move_cache()`."
+ )
+ try:
+ move_cache()
+ except Exception as e:
+ trace = "\n".join(traceback.format_tb(e.__traceback__))
+ logger.error(
+ f"There was a problem when trying to move your cache:\n\n{trace}\n\nPlease file an issue at "
+ "https://github.com/huggingface/transformers/issues/new/choose and copy paste this whole message and we "
+ "will do our best to help."
+ )
+
+ try:
+ os.makedirs(TRANSFORMERS_CACHE, exist_ok=True)
+ with open(cache_version_file, "w") as f:
+ f.write("1")
+ except Exception:
+ logger.warning(
+ f"There was a problem when trying to write in your cache folder ({TRANSFORMERS_CACHE}). You should set "
+ "the environment variable TRANSFORMERS_CACHE to a writable directory."
+ )
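Since the sentinel is just a `version.txt` containing `1`, removing it re-triggers the one-time migration on the next import; a sketch, assuming write access to the default cache:

```python
import os

from transformers.utils.hub import TRANSFORMERS_CACHE

# Delete the version sentinel so the next `import transformers` re-runs move_cache().
version_file = os.path.join(TRANSFORMERS_CACHE, "version.txt")
if os.path.isfile(version_file):
    os.remove(version_file)
```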
diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py
index b6c8ed77dc3571..397346c7deec77 100644
--- a/tests/test_configuration_common.py
+++ b/tests/test_configuration_common.py
@@ -345,14 +345,14 @@ def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py
index a822b75cc5eb62..3ecf89a908672f 100644
--- a/tests/test_feature_extraction_common.py
+++ b/tests/test_feature_extraction_common.py
@@ -170,13 +170,13 @@ def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
# This check we did call the fake head request
mock_head.assert_called()
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index c05771336e6365..8f80d7fa42f791 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2925,14 +2925,14 @@ def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py
index 15855e6a1f40e6..abf26af2b65116 100644
--- a/tests/test_modeling_tf_common.py
+++ b/tests/test_modeling_tf_common.py
@@ -1922,14 +1922,14 @@ def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index e1ed8530fdbdea..5941a571189960 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -3829,14 +3829,14 @@ def test_cached_files_are_used_when_internet_is_down(self):
# A mock response for an HTTP head request to emulate server down
response_mock = mock.Mock()
response_mock.status_code = 500
- response_mock.headers = []
+ response_mock.headers = {}
response_mock.raise_for_status.side_effect = HTTPError
# Download this model to make sure it's in the cache.
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# Under the mock environment we get a 500 error when trying to reach the model.
- with mock.patch("transformers.utils.hub.requests.head", return_value=response_mock) as mock_head:
+ with mock.patch("requests.request", return_value=response_mock) as mock_head:
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
# This check we did call the fake head request
mock_head.assert_called()