diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py
index ebf772a959e9d1..7cdc0ad93d5268 100644
--- a/src/transformers/dynamic_module_utils.py
+++ b/src/transformers/dynamic_module_utils.py
@@ -25,6 +25,8 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
+from huggingface_hub import try_to_load_from_cache
+
 from .utils import (
     HF_MODULES_CACHE,
     TRANSFORMERS_DYNAMIC_MODULE_NAME,
@@ -32,7 +34,6 @@
     extract_commit_hash,
     is_offline_mode,
     logging,
-    try_to_load_from_cache,
 )
diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py
index f6cf0a852ed752..8d2f77da6845c8 100644
--- a/src/transformers/utils/hub.py
+++ b/src/transformers/utils/hub.py
@@ -31,13 +31,16 @@
 import huggingface_hub
 import requests
 from huggingface_hub import (
+    _CACHED_NO_EXIST,
     CommitOperationAdd,
+    constants,
     create_branch,
     create_commit,
     create_repo,
     get_hf_file_metadata,
     hf_hub_download,
     hf_hub_url,
+    try_to_load_from_cache,
 )
 from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
 from huggingface_hub.utils import (
@@ -49,7 +52,9 @@
     RevisionNotFoundError,
     build_hf_headers,
     hf_raise_for_status,
+    send_telemetry,
 )
+from huggingface_hub.utils._deprecation import _deprecate_method
 from requests.exceptions import HTTPError
 
 from . import __version__, logging
@@ -75,17 +80,25 @@ def is_offline_mode():
 
 torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
 
+default_cache_path = constants.default_cache_path
 old_default_cache_path = os.path.join(torch_cache_home, "transformers")
-# New default cache, shared with the Datasets library
-hf_cache_home = os.path.expanduser(
-    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
-)
-default_cache_path = os.path.join(hf_cache_home, "hub")
+
+# Determine default cache directory. Lots of legacy environment variables to ensure backward compatibility.
+# The best way to set the cache path is with the environment variable HF_HOME. For more details, check out this
+# documentation page: https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
+#
+# In code, use `HF_HUB_CACHE` as the default cache path. This variable is set by the library and is guaranteed
+# to be set to the right value.
+#
+# TODO: clean this for v5?
+PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE)
+PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
+TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
 
 # Onetime move from the old location to the new one if no ENV variable has been set.
 if (
     os.path.isdir(old_default_cache_path)
-    and not os.path.isdir(default_cache_path)
+    and not os.path.isdir(constants.HF_HUB_CACHE)
     and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
     and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
     and "TRANSFORMERS_CACHE" not in os.environ
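Reviewer note: the hunk above replaces the hand-rolled cache-path logic with `huggingface_hub.constants`, keeping the legacy variables as a fallback chain that bottoms out at `constants.HF_HUB_CACHE`. A minimal sketch of how the new resolution behaves, assuming a `huggingface_hub` version that defines `constants.HF_HUB_CACHE` (which this diff requires); `/tmp/hf-demo` is a hypothetical path, and `HF_HOME` must be set before `huggingface_hub` is imported because its constants are resolved at import time:

```python
import os

# Hypothetical demo path; set HF_HOME *before* importing huggingface_hub,
# because its constants are computed once at import time.
os.environ["HF_HOME"] = "/tmp/hf-demo"

from huggingface_hub import constants

print(constants.HF_HOME)       # /tmp/hf-demo
print(constants.HF_HUB_CACHE)  # /tmp/hf-demo/hub (unless HF_HUB_CACHE is set)

# The legacy chain from the hunk above: each deprecated variable defaults to
# the next one, ending at the huggingface_hub cache.
cache = os.getenv(
    "TRANSFORMERS_CACHE",
    os.getenv(
        "PYTORCH_TRANSFORMERS_CACHE",
        os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE),
    ),
)
print(cache)  # /tmp/hf-demo/hub when none of the legacy variables are set
```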
@@ -97,16 +110,26 @@ def is_offline_mode():
         " '~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should"
         " only see this message once."
     )
-    shutil.move(old_default_cache_path, default_cache_path)
+    shutil.move(old_default_cache_path, constants.HF_HUB_CACHE)
 
-PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
-PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
-HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", PYTORCH_TRANSFORMERS_CACHE)
-TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", HUGGINGFACE_HUB_CACHE)
-HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(constants.HF_HOME, "modules"))
 TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
 SESSION_ID = uuid4().hex
-DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES
+DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", constants.HF_HUB_DISABLE_TELEMETRY) in ENV_VARS_TRUE_VALUES
+
+# Add deprecation warning for old environment variables.
+for key in ("PYTORCH_PRETRAINED_BERT_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "TRANSFORMERS_CACHE"):
+    if os.getenv(key) is not None:
+        warnings.warn(
+            f"Using `{key}` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.",
+            FutureWarning,
+        )
+if os.getenv("DISABLE_TELEMETRY") is not None:
+    warnings.warn(
+        "Using `DISABLE_TELEMETRY` is deprecated and will be removed in v5 of Transformers. Use `HF_HUB_DISABLE_TELEMETRY` instead.",
+        FutureWarning,
+    )
+
 
 S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
 CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
@@ -126,15 +149,16 @@ def is_offline_mode():
 HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
 HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
 
-# Return value when trying to load a file from cache but the file does not exist in the distant repo.
-_CACHED_NO_EXIST = object()
-
 
 def is_remote_url(url_or_filename):
     parsed = urlparse(url_or_filename)
     return parsed.scheme in ("http", "https")
 
 
+# TODO: remove this once fully deprecated
+# TODO? remove from './examples/research_projects/lxmert/utils.py' as well
+# TODO? remove from './examples/research_projects/visual_bert/utils.py' as well
+@_deprecate_method(version="4.39.0", message="This method is outdated and does not support the new cache system.")
 def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
     """
     Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
@@ -219,7 +243,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
     return ua
 
 
-def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
+def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
     """
     Extracts the commit hash from a resolved filename toward a cache file.
     """
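Reviewer note: the next hunk deletes the vendored copy of `try_to_load_from_cache`; the canonical implementation now comes from `huggingface_hub` (see the import changes above) with the same contract. A minimal usage sketch of the replacement; the repo and filename are illustrative, and the call only inspects the local cache, so it works offline:

```python
from huggingface_hub import _CACHED_NO_EXIST, try_to_load_from_cache

# Same contract as the removed helper: a filepath if the file is cached,
# None if nothing is known, and the `_CACHED_NO_EXIST` sentinel if the cache
# has recorded that the file does not exist on the repo.
filepath = try_to_load_from_cache("bert-base-uncased", "config.json")
if isinstance(filepath, str):
    print(f"Cached at {filepath}")
elif filepath is _CACHED_NO_EXIST:
    print("File is known not to exist on the Hub")
else:  # None
    print("Cache status unknown; a download may be needed")
```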
@@ -233,73 +257,6 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
     return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
 
 
-def try_to_load_from_cache(
-    repo_id: str,
-    filename: str,
-    cache_dir: Union[str, Path, None] = None,
-    revision: Optional[str] = None,
-    repo_type: Optional[str] = None,
-) -> Optional[str]:
-    """
-    Explores the cache to return the latest cached file for a given revision if found.
-
-    This function will not raise any exception if the file in not cached.
-
-    Args:
-        cache_dir (`str` or `os.PathLike`):
-            The folder where the cached files lie.
-        repo_id (`str`):
-            The ID of the repo on huggingface.co.
-        filename (`str`):
-            The filename to look for inside `repo_id`.
-        revision (`str`, *optional*):
-            The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
-            provided either.
-        repo_type (`str`, *optional*):
-            The type of the repo.
-
-    Returns:
-        `Optional[str]` or `_CACHED_NO_EXIST`:
-            Will return `None` if the file was not cached. Otherwise:
-            - The exact path to the cached file if it's found in the cache
-            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
-              cached.
-    """
-    if revision is None:
-        revision = "main"
-
-    if cache_dir is None:
-        cache_dir = TRANSFORMERS_CACHE
-
-    object_id = repo_id.replace("/", "--")
-    if repo_type is None:
-        repo_type = "model"
-    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
-    if not os.path.isdir(repo_cache):
-        # No cache for this model
-        return None
-    for subfolder in ["refs", "snapshots"]:
-        if not os.path.isdir(os.path.join(repo_cache, subfolder)):
-            return None
-
-    # Resolve refs (for instance to convert main to the associated commit sha)
-    cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
-    if revision in cached_refs:
-        with open(os.path.join(repo_cache, "refs", revision)) as f:
-            revision = f.read()
-
-    if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
-        return _CACHED_NO_EXIST
-
-    cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
-    if revision not in cached_shas:
-        # No cache for this revision and we won't try to return a random revision
-        return None
-
-    cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
-    return cached_file if os.path.isfile(cached_file) else None
-
-
 def cached_file(
     path_or_repo_id: Union[str, os.PathLike],
     filename: str,
@@ -317,7 +274,7 @@ def cached_file(
     _raise_exceptions_for_connection_errors: bool = True,
     _commit_hash: Optional[str] = None,
     **deprecated_kwargs,
-):
+) -> Optional[str]:
     """
     Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
@@ -369,7 +326,8 @@ def cached_file(
     ```python
     # Download a model weight from the Hub and cache it.
     model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
-    ```"""
+    ```
+    """
     use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
     if use_auth_token is not None:
         warnings.warn(
@@ -499,6 +457,10 @@ def cached_file(
     return resolved_file
 
 
+# TODO: deprecate `get_file_from_repo` or document it differently?
+# Docstring is exactly the same as `cached_file` but behavior is slightly different. If the file is missing or if
+# there is a connection error, `get_file_from_repo` will return None while `cached_file` will raise an error.
+# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
 def get_file_from_repo(
     path_or_repo: Union[str, os.PathLike],
     filename: str,
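Reviewer note: as the TODO above says, `get_file_from_repo` and `cached_file` share a docstring but differ in error handling. A hedged illustration of the difference; the repo and filename are taken from the docstrings in this diff, and the exact exception message may vary:

```python
from transformers.utils.hub import cached_file, get_file_from_repo

# `get_file_from_repo` swallows missing-file and connection errors and
# returns None; per the docstring below, xlm-roberta-base has no
# tokenizer_config.json.
print(get_file_from_repo("xlm-roberta-base", "tokenizer_config.json"))  # None

# `cached_file` raises for the same missing entry by default.
try:
    cached_file("xlm-roberta-base", "tokenizer_config.json")
except OSError as err:  # transformers raises EnvironmentError here
    print(f"cached_file raised: {err}")
```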
@@ -564,7 +526,8 @@ def get_file_from_repo(
     tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
 
     # This model does not have a tokenizer config so the result will be None.
     tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
-    ```"""
+    ```
+    """
     use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
     if use_auth_token is not None:
         warnings.warn(
@@ -609,10 +572,11 @@ def download_url(url, proxies=None):
         f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
         " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
         " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
-        " multiple processes (each process will download the file in a different temporary file)."
+        " multiple processes (each process will download the file in a different temporary file).",
+        FutureWarning,
     )
-    tmp_file = tempfile.mkstemp()[1]
-    with open(tmp_file, "wb") as f:
+    tmp_fd, tmp_file = tempfile.mkstemp()
+    with os.fdopen(tmp_fd, "wb") as f:
         http_get(url, f, proxies=proxies)
     return tmp_file
@@ -947,13 +911,10 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
             script_name = script_name.replace("_no_trainer", "")
             data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"
 
-    headers = {"user-agent": http_user_agent(data)}
-    try:
-        r = requests.head(HUGGINGFACE_CO_EXAMPLES_TELEMETRY, headers=headers)
-        r.raise_for_status()
-    except Exception:
-        # We don't want to error in case of connection errors of any kind.
-        pass
+    # Send telemetry in the background
+    send_telemetry(
+        topic="examples", library_name="transformers", library_version=__version__, user_agent=http_user_agent(data)
+    )
 
 
 def convert_file_size_to_int(size: Union[int, str]):
@@ -1258,7 +1219,7 @@ def cancel(self) -> None:
                 "`transformers.utils.move_cache()`."
             )
             try:
-                if TRANSFORMERS_CACHE != default_cache_path:
+                if TRANSFORMERS_CACHE != constants.HF_HUB_CACHE:
                     # Users set some env variable to customize cache storage
                     move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
                 else:
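Reviewer note: the `download_url` change above also fixes a file-descriptor leak: `tempfile.mkstemp()` returns an already-open OS-level descriptor that the old code discarded before re-opening the path by name. A standalone sketch of the pattern, stdlib only, with the `payload` write standing in for `http_get`:

```python
import os
import tempfile

# mkstemp() returns (os_level_fd, path). Ignoring the fd and calling
# open(path) leaks one descriptor per call; wrapping the fd with os.fdopen
# reuses it and closes it when the `with` block exits.
tmp_fd, tmp_file = tempfile.mkstemp()
with os.fdopen(tmp_fd, "wb") as f:
    f.write(b"payload")  # stand-in for http_get(url, f, proxies=proxies)
print(os.path.getsize(tmp_file))  # 7
os.remove(tmp_file)  # caller is responsible for cleanup, as in download_url
```

Likewise, `send_example_telemetry` now delegates to `huggingface_hub`'s `send_telemetry`, which reports from a background thread and respects `HF_HUB_DISABLE_TELEMETRY`, instead of blocking the example script on a synchronous `requests.head` round-trip.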