Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use commit hash to look in cache instead of calling head #18534

Merged
merged 7 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@

from . import __version__
from .dynamic_module_utils import custom_object_save
from .utils import CONFIG_NAME, PushToHubMixin, cached_file, copy_func, is_torch_available, logging
from .utils import (
CONFIG_NAME,
PushToHubMixin,
cached_file,
copy_func,
extract_commit_hash,
is_torch_available,
logging,
)


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -343,6 +351,8 @@ def __init__(self, **kwargs):

# Name or path to the pretrained checkpoint
self._name_or_path = str(kwargs.pop("name_or_path", ""))
# Config hash
self._commit_hash = kwargs.pop("_commit_hash", None)

# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)
Expand Down Expand Up @@ -539,6 +549,8 @@ def get_config_dict(
original_kwargs = copy.deepcopy(kwargs)
# Get config dict associated with the base config file
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
if "_commit_hash" in config_dict:
original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

# That config file may point us toward another config file to use.
if "configuration_files" in config_dict:
Expand All @@ -564,6 +576,7 @@ def _get_config_dict(
subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
Expand Down Expand Up @@ -599,7 +612,9 @@ def _get_config_dict(
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
Expand All @@ -616,6 +631,7 @@ def _get_config_dict(
try:
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
config_dict["_commit_hash"] = commit_hash
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
Expand Down Expand Up @@ -648,6 +664,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
# We remove them so they don't appear in `return_unused_kwargs`.
kwargs.pop("_from_auto", None)
kwargs.pop("_from_pipeline", None)
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]

config = cls(**config_dict)

Expand Down Expand Up @@ -751,6 +770,8 @@ def to_dict(self) -> Dict[str, Any]:
output["model_type"] = self.__class__.model_type
if "_auto_class" in output:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ def from_pretrained(
from_auto_class = kwargs.pop("_from_auto", False)
_do_init = kwargs.pop("_do_init", True)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
Expand Down Expand Up @@ -625,11 +626,15 @@ def from_pretrained(
revision=revision,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
_commit_hash=commit_hash,
**kwargs,
)
else:
model_kwargs = kwargs

if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)

# Add the dtype to model_kwargs
model_kwargs["dtype"] = dtype

Expand Down Expand Up @@ -682,6 +687,7 @@ def from_pretrained(
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

Expand Down Expand Up @@ -748,6 +754,7 @@ def from_pretrained(
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
_commit_hash=commit_hash,
)

# init random models
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2161,6 +2161,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
Expand Down Expand Up @@ -2191,11 +2192,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
revision=revision,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
_commit_hash=commit_hash,
**kwargs,
)
else:
model_kwargs = kwargs

if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)

# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
Expand Down Expand Up @@ -2253,6 +2258,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

Expand Down Expand Up @@ -2320,6 +2326,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
_commit_hash=commit_hash,
)

config.name_or_path = pretrained_model_name_or_path
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
Expand Down Expand Up @@ -1852,6 +1853,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
model_kwargs = kwargs

if commit_hash is None:
commit_hash = getattr(config, "_commit_hash", None)

# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# index of the files.
is_sharded = False
Expand Down Expand Up @@ -1938,6 +1942,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_commit_hash=commit_hash,
)
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)

Expand Down Expand Up @@ -2012,6 +2017,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)

# load pt weights early so that we know which dtype to init the model under
Expand Down
15 changes: 12 additions & 3 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available, logging
from ...utils import cached_file, extract_commit_hash, is_sentencepiece_available, is_tokenizers_available, logging
from ..encoder_decoder import EncoderDecoderConfig
from .auto_factory import _LazyAutoMapping
from .configuration_auto import (
Expand Down Expand Up @@ -389,7 +389,8 @@ def get_tokenizer_config(
tokenizer.save_pretrained("tokenizer-test")
tokenizer_config = get_tokenizer_config("tokenizer-test")
```"""
resolved_config_file = get_file_from_repo(
commit_hash = kwargs.get("_commit_hash", None)
resolved_config_file = cached_file(
pretrained_model_name_or_path,
TOKENIZER_CONFIG_FILE,
cache_dir=cache_dir,
Expand All @@ -399,13 +400,19 @@ def get_tokenizer_config(
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
if resolved_config_file is None:
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
return {}
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)

with open(resolved_config_file, encoding="utf-8") as reader:
return json.load(reader)
result = json.load(reader)
result["_commit_hash"] = commit_hash
return result


class AutoTokenizer:
Expand Down Expand Up @@ -532,6 +539,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):

# Next, let's try to use the tokenizer_config file to get the tokenizer class.
tokenizer_config = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
if "_commit_hash" in tokenizer_config:
kwargs["_commit_hash"] = tokenizer_config["_commit_hash"]
config_tokenizer_class = tokenizer_config.get("tokenizer_class")
tokenizer_auto_map = None
if "auto_map" in tokenizer_config:
Expand Down
11 changes: 10 additions & 1 deletion src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,12 @@ def pipeline(
# Make sure we only pass use_auth_token once as a kwarg (it used to be possible to pass it in model_kwargs,
# this is to keep BC).
use_auth_token = model_kwargs.pop("use_auth_token", use_auth_token)
hub_kwargs = {"revision": revision, "use_auth_token": use_auth_token, "trust_remote_code": trust_remote_code}
hub_kwargs = {
"revision": revision,
"use_auth_token": use_auth_token,
"trust_remote_code": trust_remote_code,
"_commit_hash": None,
}

if task is None and model is None:
raise RuntimeError(
Expand All @@ -583,8 +588,10 @@ def pipeline(
# Instantiate config if needed
if isinstance(config, str):
config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash
elif config is None and isinstance(model, str):
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash

custom_tasks = {}
if config is not None and len(getattr(config, "custom_pipelines", {})) > 0:
Expand Down Expand Up @@ -639,6 +646,7 @@ def pipeline(
)
if config is None and isinstance(model, str):
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash

if device_map is not None:
if "device_map" in model_kwargs:
Expand Down Expand Up @@ -672,6 +680,7 @@ def pipeline(
)

model_config = model.config
hub_kwargs["_commit_hash"] = model.config._commit_hash

load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from typing import Iterator, List, Union
from unittest import mock

import huggingface_hub
from transformers import logging as transformers_logging

from .deepspeed import is_deepspeed_available
Expand Down Expand Up @@ -1588,3 +1589,30 @@ def run_command(command: List[str], return_stdout=False):
raise SubprocessCallException(
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
) from e


class RequestCounter:
    """
    Context manager that counts every HTTP request issued through `huggingface_hub`
    while the `with` block is active.

    Requests are tallied by HTTP verb into three attributes:
    `head_request_count`, `get_request_count` and `other_request_count`.
    """

    def __enter__(self):
        # Reset all tallies, then swap the counting wrapper in place of the
        # real `requests.request` used by huggingface_hub's file download code.
        self.head_request_count = 0
        self.get_request_count = 0
        self.other_request_count = 0
        self.old_request = huggingface_hub.file_download.requests.request
        huggingface_hub.file_download.requests.request = self.new_request
        return self

    def __exit__(self, *args, **kwargs):
        # Restore the original request function on exit.
        huggingface_hub.file_download.requests.request = self.old_request

    def new_request(self, method, **kwargs):
        # Tally the request by its HTTP verb, then forward it unchanged to the
        # original request function so behavior is unaffected.
        if method == "HEAD":
            self.head_request_count += 1
        elif method == "GET":
            self.get_request_count += 1
        else:
            self.other_request_count += 1

        return self.old_request(method=method, **kwargs)
Comment on lines +1594 to +1618
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OMG Python 🤯

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Black Python magic ;-) 🪄

Loading