diff --git a/optimum/configuration_utils.py b/optimum/configuration_utils.py
index c0589bfb17..f004049c2c 100644
--- a/optimum/configuration_utils.py
+++ b/optimum/configuration_utils.py
@@ -1,3 +1,4 @@
+# coding=utf-8
 # Copyright 2021 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,19 +18,31 @@
 import json
 import os
 import re
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union
 
 from packaging import version
 from transformers import PretrainedConfig
-from transformers import __version__ as transformers_version
-from transformers.utils import is_offline_mode
-
-from huggingface_hub import hf_hub_download
+from transformers import __version__ as transformers_version_str
 
 from .utils import logging
 from .version import __version__
 
+# TODO: remove once transformers release version is way above 4.22.
+_transformers_version = version.parse(transformers_version_str)
+_transformers_version_threshold = (4, 22)
+_transformers_version_is_below_threshold = (
+    _transformers_version.major,
+    _transformers_version.minor,
+) < _transformers_version_threshold
+
+if _transformers_version_is_below_threshold:
+    from transformers.utils import cached_path, hf_bucket_url
+else:
+    from transformers.dynamic_module_utils import custom_object_save
+    from transformers.utils import cached_file, download_url, extract_commit_hash, is_remote_url
+
 
 logger = logging.get_logger(__name__)
@@ -49,34 +62,45 @@ def _re_configuration_file(cls):
     # Adapted from transformers.configuration_utils.PretrainedConfig.save_pretrained
     def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
         """
-        Save a configuration object to the directory ``save_directory``, so that it can be re-loaded using the
-        :func:`~transformers.PretrainedConfig.from_pretrained` class method.
+        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+        [`~PretrainedConfig.from_pretrained`] class method.
 
         Args:
-            save_directory (:obj:`str` or :obj:`os.PathLike`):
+            save_directory (`str` or `os.PathLike`):
                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
-            push_to_hub (:obj:`bool`, `optional`, defaults to :obj:`False`):
-                Whether to push your model to the Hugging Face model hub after saving it.
-
-                .. warning::
-
-                    Using :obj:`push_to_hub=True` will synchronize the repository you are pushing to with
-                    :obj:`save_directory`, which requires :obj:`save_directory` to be a local clone of the repo you are
-                    pushing to if it's an existing folder. Pass along :obj:`temp_dir=True` to use a temporary directory
-                    instead.
-
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
             kwargs:
-                Additional key word arguments passed along to the
-                :meth:`~transformers.file_utils.PushToHubMixin.push_to_hub` method.
+                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
 
+        # TODO: remove condition once transformers release version is way above 4.22.
+        if not _transformers_version_is_below_threshold:
+            os.makedirs(save_directory, exist_ok=True)
+
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            repo = self._create_or_get_repo(save_directory, **kwargs)
+            # TODO: remove once transformers release version is way above 4.22.
+            if _transformers_version_is_below_threshold:
+                repo = self._create_or_get_repo(save_directory, **kwargs)
+            else:
+                repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+                repo_id, token = self._create_repo(repo_id, **kwargs)
+                files_timestamps = self._get_files_timestamps(save_directory)
+
+        # TODO: remove once transformers release version is way above 4.22.
+        if _transformers_version_is_below_threshold:
+            os.makedirs(save_directory, exist_ok=True)
+
+        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=self)
 
-        os.makedirs(save_directory, exist_ok=True)
         # If we save using the predefined names, we can load using `from_pretrained`
         output_config_file = os.path.join(save_directory, self.CONFIG_NAME)
@@ -84,16 +108,24 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
         logger.info(f"Configuration saved in {output_config_file}")
 
         if push_to_hub:
-            url = self._push_to_hub(repo, commit_message=commit_message)
-            logger.info(f"Configuration pushed to the hub in this commit: {url}")
+            # TODO: remove once transformers release version is way above 4.22.
+            if _transformers_version_is_below_threshold:
+                url = self._push_to_hub(repo, commit_message=commit_message)
+                logger.info(f"Configuration pushed to the hub in this commit: {url}")
+            else:
+                self._upload_modified_files(
+                    save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
+                )
 
     # Adapted from transformers.configuration_utils.PretrainedConfig.get_configuration_file
     @classmethod
     def get_configuration_file(cls, configuration_files: List[str]) -> str:
         """
         Get the configuration file to use for this version of transformers.
+
         Args:
             configuration_files (`List[str]`):
                 The list of available configuration files.
+
         Returns:
             `str`: The configuration file to use.
         """
@@ -106,6 +138,7 @@ def get_configuration_file(cls, configuration_files: List[str]) -> str:
                 configuration_files_map[v] = file_name
         available_versions = sorted(configuration_files_map.keys())
 
+        # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
         configuration_file = cls.CONFIG_NAME
         optimum_version = version.parse(__version__)
         for v in available_versions:
@@ -125,11 +158,14 @@ def get_config_dict(
         """
         From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
         [`PretrainedConfig`] using `from_dict`.
+
         Parameters:
             pretrained_model_name_or_path (`str` or `os.PathLike`):
                 The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+
         Returns:
             `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
+ """ original_kwargs = copy.deepcopy(kwargs) # Get config dict associated with the base config file @@ -151,18 +187,6 @@ def get_config_dict( def _get_config_dict( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """ - From a ``pretrained_model_name_or_path``, resolve to a dictionary of parameters, to be used for instantiating a - :class:`~transformers.PretrainedConfig` using ``from_dict``. - - Args: - pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`): - The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. - - Returns: - :obj:`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object. - - """ cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -170,43 +194,82 @@ def _get_config_dict( use_auth_token = kwargs.pop("use_auth_token", None) local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + subfolder = kwargs.pop("subfolder", "") from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) + commit_hash = kwargs.pop("_commit_hash", None) + + if trust_remote_code is True: + logger.warning( + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" + " ignored." + ) user_agent = {"file_type": "config", "from_auto_class": from_auto_class} if from_pipeline is not None: user_agent["using_pipeline"] = from_pipeline - if is_offline_mode() and not local_files_only: - logger.info("Offline mode: forcing local_files_only=True") - local_files_only = True - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - configuration_file = kwargs.pop("_configuration_file", cls.CONFIG_NAME) - is_local = os.path.isdir(pretrained_model_name_or_path) - if os.path.isfile(pretrained_model_name_or_path): + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + # Special case when pretrained_model_name_or_path is a local file resolved_config_file = pretrained_model_name_or_path is_local = True - elif os.path.isdir(pretrained_model_name_or_path): + # TODO: remove once transformers release version is way above 4.22. + elif _transformers_version_is_below_threshold and os.path.isdir(pretrained_model_name_or_path): + configuration_file = kwargs.pop("_configuration_file", cls.CONFIG_NAME) resolved_config_file = os.path.join(pretrained_model_name_or_path, configuration_file) if not os.path.isfile(resolved_config_file): raise EnvironmentError( f"Could not locate {configuration_file} inside {pretrained_model_name_or_path}." ) + # TODO: remove condition once transformers release version is way above 4.22. 
+        elif not _transformers_version_is_below_threshold and is_remote_url(pretrained_model_name_or_path):
+            configuration_file = pretrained_model_name_or_path
+            resolved_config_file = download_url(pretrained_model_name_or_path)
         else:
+            configuration_file = kwargs.pop("_configuration_file", cls.CONFIG_NAME)
+
             try:
-                resolved_config_file = hf_hub_download(
-                    repo_id=pretrained_model_name_or_path,
-                    filename=configuration_file,
-                    revision=revision,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    resume_download=resume_download,
-                    proxies=proxies,
-                    use_auth_token=use_auth_token,
-                    local_files_only=local_files_only,
-                )
+                # TODO: remove once transformers release version is way above 4.22.
+                if _transformers_version_is_below_threshold:
+                    config_file = hf_bucket_url(
+                        pretrained_model_name_or_path,
+                        filename=configuration_file,
+                        revision=revision,
+                        subfolder=subfolder if len(subfolder) > 0 else None,
+                        mirror=None,
+                    )
+                    # Load from URL or cache if already cached
+                    resolved_config_file = cached_path(
+                        config_file,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        resume_download=resume_download,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        user_agent=user_agent,
+                    )
+                else:
+                    # Load from local folder or from cache or download from model Hub and cache
+                    resolved_config_file = cached_file(
+                        pretrained_model_name_or_path,
+                        configuration_file,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        resume_download=resume_download,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        user_agent=user_agent,
+                        revision=revision,
+                        subfolder=subfolder,
+                        _commit_hash=commit_hash,
+                    )
+                    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
             except EnvironmentError:
                 # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
                 # the original exception.
@@ -214,23 +277,27 @@ def _get_config_dict(
             except Exception:
                 # For any other exception, we throw a generic error.
                 raise EnvironmentError(
-                    f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it "
-                    "from 'https://huggingface.co/models', make sure you don't have a local directory with the same "
-                    f"name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                    f"containing a {configuration_file} file"
+                    f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
+                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
+                    f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
+                    f" containing a {configuration_file} file"
                 )
 
         try:
             # Load config dict
             config_dict = cls._dict_from_json_file(resolved_config_file)
+            # TODO: remove once transformers release version is way above 4.22.
+            if _transformers_version_is_below_threshold:
+                config_dict["_commit_hash"] = commit_hash
         except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(
                 f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
             )
+
         if is_local:
-            logger.info(f"Loading configuration file {resolved_config_file}")
+            logger.info(f"loading configuration file {resolved_config_file}")
         else:
-            logger.info(f"Loading configuration file from cache at {resolved_config_file}")
+            logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
 
         return config_dict, kwargs
 
@@ -238,20 +305,26 @@ def _get_config_dict(
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
         """
-        Instantiates a :class:`~transformers.PretrainedConfig` from a Python dictionary of parameters.
+        Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
 
         Args:
-            config_dict (:obj:`Dict[str, Any]`):
+            config_dict (`Dict[str, Any]`):
                 Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
-                retrieved from a pretrained checkpoint by leveraging the
-                :func:`~transformers.PretrainedConfig.get_config_dict` method.
-            kwargs (:obj:`Dict[str, Any]`):
+                retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
+            kwargs (`Dict[str, Any]`):
                 Additional parameters from which to initialize the configuration object.
 
         Returns:
-            :class:`PretrainedConfig`: The configuration object instantiated from those parameters.
+            [`PretrainedConfig`]: The configuration object instantiated from those parameters.
         """
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+        # Those arguments may be passed along for our internal telemetry.
+        # We remove them so they don't appear in `return_unused_kwargs`.
+        kwargs.pop("_from_auto", None)
+        kwargs.pop("_from_pipeline", None)
+        # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
+        if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
+            kwargs["_commit_hash"] = config_dict["_commit_hash"]
 
         config = cls(**config_dict)
 
@@ -259,6 +332,15 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
             config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
 
         # Update config with kwargs if needed
+        if "num_labels" in kwargs and "id2label" in kwargs:
+            num_labels = kwargs["num_labels"]
+            id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
+            if len(id2label) != num_labels:
+                raise ValueError(
+                    f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
+                    f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
+                    "one of them."
+                )
         to_remove = []
         for key, value in kwargs.items():
             if hasattr(config, key):
@@ -280,14 +362,18 @@ def to_dict(self) -> Dict[str, Any]:
         Serializes this instance to a Python dictionary.
 
         Returns:
-            :obj:`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
""" output = copy.deepcopy(self.__dict__) if hasattr(self.__class__, "model_type"): output["model_type"] = self.__class__.model_type + if "_auto_class" in output: + del output["_auto_class"] + if "_commit_hash" in output: + del output["_commit_hash"] # Transformers version when serializing the model - output["transformers_version"] = transformers_version + output["transformers_version"] = transformers_version_str output["optimum_version"] = __version__ self.dict_torch_dtype_to_str(output) diff --git a/tests/test_configuration_utils.py b/tests/test_configuration_utils.py index 1da48005f1..764a53a91c 100644 --- a/tests/test_configuration_utils.py +++ b/tests/test_configuration_utils.py @@ -76,7 +76,7 @@ def test_push_to_hub(self): ) new_config = FakeConfig.from_pretrained(f"{USER}/optimum-test-base-config") - for k, v in config.__dict__.items(): + for k, v in config.to_dict().items(): if k != "optimum_version" and k != "transformers_version": self.assertEqual(v, getattr(new_config, k)) @@ -92,6 +92,6 @@ def test_push_to_hub_in_organization(self): ) new_config = FakeConfig.from_pretrained("valid_org/optimum-test-base-config-org") - for k, v in config.__dict__.items(): + for k, v in config.to_dict().items(): if k != "optimum_version" and k != "transformers_version": self.assertEqual(v, getattr(new_config, k))