BaseConfig update for transformers>=4.22.0 #386

Merged
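
In short, this PR updates `BaseConfig` so it works with both older and newer `transformers` releases: on `transformers >= 4.22` it uses the new Hub utilities (`cached_file`, `download_url`, `extract_commit_hash`, `custom_object_save`) and the `repo_id`-based push flow, while the legacy `hf_bucket_url`/`cached_path` path is kept behind a version check for older installs.
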
224 changes: 155 additions & 69 deletions optimum/configuration_utils.py
@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,19 +18,31 @@
import json
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Tuple, Union

from packaging import version
from transformers import PretrainedConfig
from transformers import __version__ as transformers_version
from transformers.utils import is_offline_mode

from huggingface_hub import hf_hub_download
from transformers import __version__ as transformers_version_str

from .utils import logging
from .version import __version__


# TODO: remove once transformers release version is way above 4.22.
_transformers_version = version.parse(transformers_version_str)
_transformers_version_threshold = (4, 22)
_transformers_version_is_below_threshold = (
_transformers_version.major,
_transformers_version.minor,
) < _transformers_version_threshold

if _transformers_version_is_below_threshold:
from transformers.utils import cached_path, hf_bucket_url
else:
from transformers.dynamic_module_utils import custom_object_save
from transformers.utils import cached_file, download_url, extract_commit_hash, is_remote_url


logger = logging.get_logger(__name__)
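
As a quick illustration (not part of the diff), the `(major, minor)` threshold comparison introduced above behaves as follows for a few transformers versions:

from packaging import version

# Illustrative check of the (major, minor) threshold used by the module above.
for v in ("4.21.3", "4.22.0", "4.23.1"):
    parsed = version.parse(v)
    below = (parsed.major, parsed.minor) < (4, 22)
    print(v, "-> legacy hf_bucket_url/cached_path path" if below else "-> new cached_file path")
# 4.21.3 -> legacy hf_bucket_url/cached_path path
# 4.22.0 -> new cached_file path
# 4.23.1 -> new cached_file path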


@@ -49,51 +62,70 @@ def _re_configuration_file(cls):
# Adapted from transformers.configuration_utils.PretrainedConfig.save_pretrained
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
:func:`~transformers.PretrainedConfig.from_pretrained` class method.
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~PretrainedConfig.from_pretrained`] class method.

Args:
save_directory (:obj:`str` or :obj:`os.PathLike`):
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to push your model to the Hugging Face model hub after saving it.

.. warning::

Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
:obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
instead.

push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
kwargs:
Additional key word arguments passed along to the
:meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

# TODO: remove condition once transformers release version is way above 4.22.
if not _transformers_version_is_below_threshold:
os.makedirs(save_directory, exist_ok=True)

if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo = self._create_or_get_repo(save_directory, **kwargs)
# TODO: remove once transformers release version is way above 4.22.
if _transformers_version_is_below_threshold:
repo = self._create_or_get_repo(save_directory, **kwargs)
else:
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

# TODO: remove once transformers release version is way above 4.22.
if _transformers_version_is_below_threshold:
os.makedirs(save_directory, exist_ok=True)

# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
custom_object_save(self, save_directory, config=self)

os.makedirs(save_directory, exist_ok=True)
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, self.CONFIG_NAME)

self.to_json_file(output_config_file, use_diff=True)
logger.info(f"Configuration saved in {output_config_file}")

if push_to_hub:
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Configuration pushed to the hub in this commit: {url}")
# TODO: remove once transformers release version is way above 4.22.
if _transformers_version_is_below_threshold:
url = self._push_to_hub(repo, commit_message=commit_message)
logger.info(f"Configuration pushed to the hub in this commit: {url}")
else:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
)
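
A minimal usage sketch of the updated `save_pretrained` (the `MyConfig` subclass, its `CONFIG_NAME`, and the repository id are hypothetical and only illustrate the new `repo_id` keyword available on transformers >= 4.22):

from optimum.configuration_utils import BaseConfig

class MyConfig(BaseConfig):
    # Hypothetical subclass used only for this example.
    CONFIG_NAME = "my_config.json"

    def __init__(self, hidden_size=64, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size

config = MyConfig(hidden_size=128)
# Writes my_config.json into ./my-model (the directory is created if missing).
config.save_pretrained("./my-model")
# On transformers >= 4.22, pushing goes through repo_id instead of a local repo clone:
# config.save_pretrained("./my-model", push_to_hub=True, repo_id="username/my-model")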

# Adapted from transformers.configuration_utils.PretrainedConfig.get_configuration_file
@classmethod
def get_configuration_file(cls, configuration_files: List[str]) -> str:
"""
Get the configuration file to use for this version of transformers.

Args:
configuration_files (`List[str]`): The list of available configuration files.

Returns:
`str`: The configuration file to use.
"""
@@ -106,6 +138,7 @@ def get_configuration_file(cls, configuration_files: List[str]) -> str:
configuration_files_map[v] = file_name
available_versions = sorted(configuration_files_map.keys())

# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
configuration_file = cls.CONFIG_NAME
optimum_version = version.parse(__version__)
for v in available_versions:
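
To make the selection loop above concrete, a hedged sketch with made-up file names and versions (it mirrors the logic but is not code from the diff):

from packaging import version

# Made-up candidates of the form config.<version>.json, keyed by the optimum
# version they were introduced in.
configuration_files_map = {"1.2.0": "config.1.2.0.json", "1.4.0": "config.1.4.0.json"}
available_versions = sorted(configuration_files_map.keys())

configuration_file = "config.json"        # default: cls.CONFIG_NAME
optimum_version = version.parse("1.3.1")  # pretend this optimum version is installed
for v in available_versions:
    if version.parse(v) <= optimum_version:
        configuration_file = configuration_files_map[v]
    else:
        # Versions are sorted, so no later candidate can apply either.
        break

print(configuration_file)  # -> config.1.2.0.json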
@@ -125,11 +158,14 @@ def get_config_dict(
"""
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
[`PretrainedConfig`] using `from_dict`.

Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

Returns:
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.

"""
original_kwargs = copy.deepcopy(kwargs)
# Get config dict associated with the base config file
@@ -151,114 +187,160 @@
def _get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a
:class:`~transformers.PretrainedConfig` using ``from_dict``.

Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.

Returns:
:obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.

"""
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
subfolder = kwargs.pop("subfolder", "")
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
commit_hash = kwargs.pop("_commit_hash", None)

if trust_remote_code is True:
logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
" ignored."
)

user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline

if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True

pretrained_model_name_or_path = str(pretrained_model_name_or_path)
configuration_file = kwargs.pop("_configuration_file", cls.CONFIG_NAME)
is_local = os.path.isdir(pretrained_model_name_or_path)

if os.path.isfile(pretrained_model_name_or_path):
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
# Special case when pretrained_model_name_or_path is a local file
resolved_config_file = pretrained_model_name_or_path
is_local = True
elif os.path.isdir(pretrained_model_name_or_path):
# TODO: remove once transformers release version is way above 4.22.
elif _transformers_version_is_below_threshold and os.path.isdir(pretrained_model_name_or_path):
configuration_file = kwargs.pop("_configuration_file", cls.CONFIG_NAME)
resolved_config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
if not os.path.isfile(resolved_config_file):
raise EnvironmentError(
f"Could not locate {configuration_file} inside {pretrained_model_name_or_path}."
)
# TODO: remove condition once transformers release version is way above 4.22.
elif not _transformers_version_is_below_threshold and is_remote_url(pretrained_model_name_or_path):
configuration_file = pretrained_model_name_or_path
resolved_config_file = download_url(pretrained_model_name_or_path)
else:
configuration_file = kwargs.pop("_configuration_file", cls.CONFIG_NAME)

try:
resolved_config_file = hf_hub_download(
repo_id=pretrained_model_name_or_path,
filename=configuration_file,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
# TODO: remove once transformers release version is way above 4.22.
if _transformers_version_is_below_threshold:
config_file = hf_bucket_url(
pretrained_model_name_or_path,
filename=configuration_file,
revision=revision,
subfolder=subfolder if len(subfolder) > 0 else None,
mirror=None,
)
# Load from URL or cache if already cached
resolved_config_file = cached_path(
config_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
)
else:
# Load from local folder or from cache or download from model Hub and cache
resolved_config_file = cached_file(
pretrained_model_name_or_path,
configuration_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
_commit_hash=commit_hash,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
except EnvironmentError:
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
# the original exception.
raise
except Exception:
# For any other exception, we throw a generic error.
raise EnvironmentError(
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it "
"from 'https://huggingface.co/models', make sure you don't have a local directory with the same "
f"name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {configuration_file} file"
f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
f" containing a {configuration_file} file"
)

try:
# Load config dict
config_dict = cls._dict_from_json_file(resolved_config_file)
# TODO: remove once transformers release version is way above 4.22.
if _transformers_version_is_below_threshold:
config_dict["_commit_hash"] = commit_hash
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
)

if is_local:
logger.info(f"Loading configuration file {resolved_config_file}")
logger.info(f"loading configuration file {resolved_config_file}")
else:
logger.info(f"Loading configuration file from cache at {resolved_config_file}")
logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")

return config_dict, kwargs
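
For context, a hedged sketch of the post-4.22 resolution path through `cached_file` (the repository id and file name are just examples; the call downloads from the Hub or reuses the local cache):

from transformers.utils import cached_file

# Resolves config.json for a Hub repository to a path inside the local cache,
# downloading it first if necessary.
resolved_config_file = cached_file(
    "distilbert-base-uncased",  # example repo id
    "config.json",              # example file name
    force_download=False,
    local_files_only=False,
)
print(resolved_config_file)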

# Adapted from transformers.configuration_utils.PretrainedConfig.from_dict
@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
"""
Instantiates a :class:`~transformers.PretrainedConfig` from a Python dictionary of parameters.
Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.

Args:
config_dict (:obj:`Dict[str, Any]`):
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
retrieved from a pretrained checkpoint by leveraging the
:func:`~transformers.PretrainedConfig.get_config_dict` method.
kwargs (:obj:`Dict[str, Any]`):
retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
kwargs (`Dict[str, Any]`):
Additional parameters from which to initialize the configuration object.

Returns:
:class:`PretrainedConfig`: The configuration object instantiated from those parameters.
[`PretrainedConfig`]: The configuration object instantiated from those parameters.
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# Those arguments may be passed along for our internal telemetry.
# We remove them so they don't appear in `return_unused_kwargs`.
kwargs.pop("_from_auto", None)
kwargs.pop("_from_pipeline", None)
# The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
kwargs["_commit_hash"] = config_dict["_commit_hash"]

config = cls(**config_dict)

if hasattr(config, "pruned_heads"):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())

# Update config with kwargs if needed
if "num_labels" in kwargs and "id2label" in kwargs:
num_labels = kwargs["num_labels"]
id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
if len(id2label) != num_labels:
raise ValueError(
f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
"one of them."
)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
@@ -280,14 +362,18 @@ def to_dict(self) -> Dict[str, Any]:
Serializes this instance to a Python dictionary.

Returns:
:obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type
if "_auto_class" in output:
del output["_auto_class"]
if "_commit_hash" in output:
del output["_commit_hash"]

# Transformers version when serializing the model
output["transformers_version"] = transformers_version
output["transformers_version"] = transformers_version_str
output["optimum_version"] = __version__

self.dict_torch_dtype_to_str(output)
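
A short sketch of what the updated serialization produces, reusing the hypothetical `MyConfig` from the `save_pretrained` example above:

config = MyConfig(hidden_size=128)
serialized = config.to_dict()

# Internal bookkeeping such as `_commit_hash` is stripped, while the library
# versions in use at serialization time are recorded.
assert "_commit_hash" not in serialized
print(serialized["optimum_version"], serialized["transformers_version"])

# Round-tripping through from_dict restores an equivalent configuration.
restored = MyConfig.from_dict(serialized)
assert restored.hidden_size == 128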
4 changes: 2 additions & 2 deletions tests/test_configuration_utils.py
@@ -76,7 +76,7 @@ def test_push_to_hub(self):
)

new_config = FakeConfig.from_pretrained(f"{USER}/optimum-test-base-config")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "optimum_version" and k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))

@@ -92,6 +92,6 @@ def test_push_to_hub_in_organization(self):
)

new_config = FakeConfig.from_pretrained("valid_org/optimum-test-base-config-org")
for k, v in config.__dict__.items():
for k, v in config.to_dict().items():
if k != "optimum_version" and k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
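
Side note on the test change: comparing `config.to_dict()` instead of `config.__dict__` keeps the assertion on the serialized view, so internal attributes that `to_dict` strips (for example `_commit_hash`, introduced by newer transformers) no longer cause spurious mismatches against the reloaded configuration.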