Harmonize HF environment variables + other cleaning #27564

Merged 4 commits, Nov 21, 2023

Changes from all commits
3 changes: 2 additions & 1 deletion src/transformers/dynamic_module_utils.py
@@ -25,14 +25,15 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from huggingface_hub import try_to_load_from_cache

from .utils import (
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_file,
extract_commit_hash,
is_offline_mode,
logging,
try_to_load_from_cache,
)


157 changes: 59 additions & 98 deletions src/transformers/utils/hub.py
@@ -31,13 +31,16 @@
import huggingface_hub
import requests
from huggingface_hub import (
_CACHED_NO_EXIST,
CommitOperationAdd,
constants,
create_branch,
create_commit,
create_repo,
get_hf_file_metadata,
hf_hub_download,
hf_hub_url,
try_to_load_from_cache,
)
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
from huggingface_hub.utils import (
@@ -49,7 +52,9 @@
RevisionNotFoundError,
build_hf_headers,
hf_raise_for_status,
send_telemetry,
)
from huggingface_hub.utils._deprecation import _deprecate_method
Member:
Should we make this a public method to leverage?
Not ideal to depend on private attributes from hfh 🫣

@Wauplin (Contributor Author), Nov 20, 2023:
Yeah, I realized that while working on this PR 😭
I'm fine with either making them public (huggingface_hub.utils.deprecate_method?), which would require waiting for the next hfh release, or making an exception for this one since it's a quite internal method anyway (i.e. using a private attribute in transformers).

@Wauplin (Contributor Author), Nov 20, 2023:
As mentioned in #27564 (comment), I suggest still using the private attribute until it's made more official on the hfh side. Worst-case scenario, the plan is to remove this in 3 transformers releases.
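For reference, a minimal sketch of what a deprecation decorator of this kind typically does (an illustration only, not the actual huggingface_hub implementation):

    import functools
    import warnings

    def deprecate_method(*, version: str, message: str = ""):
        # Decorator factory: emit a FutureWarning whenever the wrapped
        # function is called, announcing the planned removal version.
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                warnings.warn(
                    f"'{func.__name__}' is deprecated and will be removed in "
                    f"version {version}. {message}",
                    FutureWarning,
                )
                return func(*args, **kwargs)

            return wrapper

        return decorator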

from requests.exceptions import HTTPError

from . import __version__, logging
@@ -75,17 +80,25 @@ def is_offline_mode():


torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
default_cache_path = constants.default_cache_path
Wauplin marked this conversation as resolved.
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser(
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
default_cache_path = os.path.join(hf_cache_home, "hub")

# Determine default cache directory. Lots of legacy environment variables to ensure backward compatibility.
# The best way to set the cache path is with the environment variable HF_HOME. For more details, checkout this
# documentation page: https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
#
# In code, use `HF_HUB_CACHE` as the default cache path. This variable is set by the library and is guaranteed
# to be set to the right value.
#
# TODO: clean this for v5?
Collaborator:
Yes!

PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

# Onetime move from the old location to the new one if no ENV variable has been set.
if (
os.path.isdir(old_default_cache_path)
and not os.path.isdir(default_cache_path)
and not os.path.isdir(constants.HF_HUB_CACHE)
and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
and "TRANSFORMERS_CACHE" not in os.environ
@@ -97,16 +110,26 @@ def is_offline_mode():
" '~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should"
" only see this message once."
)
shutil.move(old_default_cache_path, default_cache_path)
shutil.move(old_default_cache_path, constants.HF_HUB_CACHE)

PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", PYTORCH_TRANSFORMERS_CACHE)
amyeroberts marked this conversation as resolved.
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", HUGGINGFACE_HUB_CACHE)
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(constants.HF_HOME, "modules"))
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
SESSION_ID = uuid4().hex
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", constants.HF_HUB_DISABLE_TELEMETRY) in ENV_VARS_TRUE_VALUES

# Add deprecation warning for old environment variables.
for key in ("PYTORCH_PRETRAINED_BERT_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "TRANSFORMERS_CACHE"):
if os.getenv(key) is not None:
warnings.warn(
f"Using `{key}` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.",
FutureWarning,
)
if os.getenv("DISABLE_TELEMETRY") is not None:
warnings.warn(
"Using `DISABLE_TELEMETRY` is deprecated and will be removed in v5 of Transformers. Use `HF_HUB_DISABLE_TELEMETRY` instead.",
FutureWarning,
)
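
In practice, the recommended setup after this change is a single HF_HOME variable, set before import (a usage sketch; the path is illustrative):

    import os

    # Cache constants are resolved at import time, so set this first.
    os.environ["HF_HOME"] = "/data/hf"

    from huggingface_hub import constants

    print(constants.HF_HOME)       # /data/hf
    print(constants.HF_HUB_CACHE)  # /data/hf/hub (absent other overrides)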


S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
@@ -126,15 +149,16 @@
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"

# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()


def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")


# TODO: remove this once fully deprecated
# TODO? remove from './examples/research_projects/lxmert/utils.py' as well
# TODO? remove from './examples/research_projects/visual_bert/utils.py' as well
@_deprecate_method(version="4.39.0", message="This method is outdated and does not support the new cache system.")
Wauplin marked this conversation as resolved.
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
"""
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
@@ -219,7 +243,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua


def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
"""
Extracts the commit hash from a resolved filename toward a cache file.
"""
@@ -233,73 +257,6 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None


def try_to_load_from_cache(
repo_id: str,
filename: str,
cache_dir: Union[str, Path, None] = None,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
) -> Optional[str]:
"""
Explores the cache to return the latest cached file for a given revision if found.

This function will not raise any exception if the file is not cached.

Args:
cache_dir (`str` or `os.PathLike`):
The folder where the cached files lie.
repo_id (`str`):
The ID of the repo on huggingface.co.
filename (`str`):
The filename to look for inside `repo_id`.
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
repo_type (`str`, *optional*):
The type of the repo.

Returns:
`Optional[str]` or `_CACHED_NO_EXIST`:
Will return `None` if the file was not cached. Otherwise:
- The exact path to the cached file if it's found in the cache
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
cached.
"""
if revision is None:
revision = "main"

if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE

object_id = repo_id.replace("/", "--")
if repo_type is None:
repo_type = "model"
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
if not os.path.isdir(repo_cache):
# No cache for this model
return None
for subfolder in ["refs", "snapshots"]:
if not os.path.isdir(os.path.join(repo_cache, subfolder)):
return None

# Resolve refs (for instance to convert main to the associated commit sha)
cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
if revision in cached_refs:
with open(os.path.join(repo_cache, "refs", revision)) as f:
revision = f.read()

if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
return _CACHED_NO_EXIST

cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None

cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
return cached_file if os.path.isfile(cached_file) else None
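
With the local copy gone, call sites use the huggingface_hub implementation, which keeps the same three-way return contract (a usage sketch):

    from huggingface_hub import _CACHED_NO_EXIST, try_to_load_from_cache

    filepath = try_to_load_from_cache("bert-base-uncased", "config.json")
    if filepath is _CACHED_NO_EXIST:
        print("cache records that the file does not exist on the Hub")
    elif filepath is None:
        print("nothing cached for this repo/revision")
    else:
        print(f"cached copy at {filepath}")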


def cached_file(
path_or_repo_id: Union[str, os.PathLike],
filename: str,
@@ -317,7 +274,7 @@
_raise_exceptions_for_connection_errors: bool = True,
_commit_hash: Optional[str] = None,
**deprecated_kwargs,
):
) -> Optional[str]:
"""
Tries to locate a file in a local folder and repo, downloads and caches it if necessary.

@@ -369,7 +326,8 @@ def cached_file(
```python
# Download a model weight from the Hub and cache it.
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
```"""
```
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
@@ -499,6 +457,10 @@ def cached_file(
return resolved_file


# TODO: deprecate `get_file_from_repo` or document it differently?
# Docstring is exactly the same as `cached_file`, but the behavior is slightly different: if the file is missing or
# if there is a connection error, `get_file_from_repo` will return None while `cached_file` will raise an error.
# IMO we should keep only one method and have a single `raise_error` argument (to be discussed).
def get_file_from_repo(
Comment on lines +460 to 464
Contributor Author:
To be discussed. I tend to prefer having a single cached_file method with a raise_on_error argument (and with the correct typing, to let the user know when to expect a None value and when it's not possible). I can do it here or in a separate PR.

Member:
Yes, I'd vote to uniformize everything under hfh's banner.

Contributor Author:
Actually, I'd be OK with moving the logic to huggingface_hub, but this method doesn't exist there at the moment.

In hfh we have hf_hub_download, which does the job of downloading/retrieving a file given a repo_id + filename. In transformers, cached_file is slightly different: it accepts either a repo_id or a local path (in the path_or_repo_id parameter), making it more versatile.

So for now I'd be more in favor of unifying within transformers and then later moving it to huggingface_hub.

Collaborator:
get_file_from_repo calls cached_file with the following arguments (while cached_file itself has True as the default values):

        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,

Not against keeping only one of them, but the changes might get larger. I would personally do this in a separate PR.

Contributor Author:
Good idea, let's wait for a separate PR 👍
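
One possible shape of the unified helper discussed in this thread (purely hypothetical: the name unified_cached_file and the raise_on_error flag are assumptions, not part of this PR):

    import os
    from typing import Optional, Union

    from transformers.utils import cached_file

    def unified_cached_file(
        path_or_repo_id: Union[str, os.PathLike],
        filename: str,
        *,
        raise_on_error: bool = True,  # hypothetical flag merging both behaviors
        **kwargs,
    ) -> Optional[str]:
        # raise_on_error=True mimics cached_file (raise on missing files or
        # connection errors); False mimics get_file_from_repo (return None).
        return cached_file(
            path_or_repo_id,
            filename,
            _raise_exceptions_for_missing_entries=raise_on_error,
            _raise_exceptions_for_connection_errors=raise_on_error,
            **kwargs,
        )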

path_or_repo: Union[str, os.PathLike],
filename: str,
@@ -564,7 +526,8 @@ def get_file_from_repo(
tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
# This model does not have a tokenizer config so the result will be None.
tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
```"""
```
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
if use_auth_token is not None:
warnings.warn(
@@ -609,10 +572,11 @@ def download_url(url, proxies=None):
f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
" v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
" that this is not compatible with the caching system (your file will be downloaded at each execution) or"
" multiple processes (each process will download the file in a different temporary file)."
" multiple processes (each process will download the file in a different temporary file).",
FutureWarning,
)
tmp_file = tempfile.mkstemp()[1]
with open(tmp_file, "wb") as f:
tmp_fd, tmp_file = tempfile.mkstemp()
with os.fdopen(tmp_fd, "wb") as f:
Wauplin marked this conversation as resolved.
http_get(url, f, proxies=proxies)
return tmp_file
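
The tempfile change above also closes a file-descriptor leak: mkstemp() returns an already-open descriptor, and the old code discarded it and opened the path a second time. A minimal sketch of the corrected pattern:

    import os
    import tempfile

    fd, path = tempfile.mkstemp()
    # Wrap the existing descriptor instead of re-opening the path,
    # so exactly one descriptor is created and then closed.
    with os.fdopen(fd, "wb") as f:
        f.write(b"example payload")
    os.remove(path)  # cleanup for this demo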

@@ -947,13 +911,10 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
script_name = script_name.replace("_no_trainer", "")
data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"

headers = {"user-agent": http_user_agent(data)}
try:
r = requests.head(HUGGINGFACE_CO_EXAMPLES_TELEMETRY, headers=headers)
r.raise_for_status()
except Exception:
# We don't want to error in case of connection errors of any kind.
pass
# Send telemetry in the background
send_telemetry(
topic="examples", library_name="transformers", library_version=__version__, user_agent=http_user_agent(data)
)
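
send_telemetry from huggingface_hub.utils runs in a background thread and swallows connection errors, so example scripts neither block on nor fail because of telemetry. A usage sketch with illustrative values:

    from huggingface_hub.utils import send_telemetry

    # Fire-and-forget: never raises on network errors, and is a no-op when
    # HF_HUB_DISABLE_TELEMETRY is set or offline mode is enabled.
    send_telemetry(
        topic="examples",
        library_name="transformers",
        library_version="4.36.0",  # illustrative version string
        user_agent="transformers/examples",
    )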


def convert_file_size_to_int(size: Union[int, str]):
@@ -1258,7 +1219,7 @@ def cancel(self) -> None:
"`transformers.utils.move_cache()`."
)
try:
if TRANSFORMERS_CACHE != default_cache_path:
if TRANSFORMERS_CACHE != constants.HF_HUB_CACHE:
# Users set some env variable to customize cache storage
move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
else: