From 234e3c10cce32c25348aa3718d91f0f80fdd2b59 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Fri, 16 Aug 2024 13:05:57 -0400
Subject: [PATCH] Reduce the error log when using core models that need their weights renamed, and provide a step forward (#32656)

* Fin

* Modify msg

* Finish up nits
---
 src/transformers/modeling_utils.py | 42 +++++++++++++++++++++++++-----
 src/transformers/utils/logging.py  | 15 +++++++++++
 tests/utils/test_modeling_utils.py | 10 ++++---
 3 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 369788e73698..e7bdf4ddaa04 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -104,7 +104,6 @@
 
 XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
 XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
-PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."
 
 
 if is_accelerate_available():
@@ -692,17 +691,30 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
     new_keys = []
+    renamed_keys = {}
+    renamed_gamma = {}
+    renamed_beta = {}
+    warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
     for key in state_dict.keys():
         new_key = None
         if "gamma" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
+            # We add only the first key as an example
             new_key = key.replace("gamma", "weight")
+            renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
         if "beta" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
+            # We add only the first key as an example
            new_key = key.replace("beta", "bias")
+            renamed_beta[key] = new_key if not renamed_beta else renamed_beta
         if new_key:
             old_keys.append(key)
             new_keys.append(new_key)
+    renamed_keys = {**renamed_gamma, **renamed_beta}
+    if renamed_keys:
+        warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
+        for old_key, new_key in renamed_keys.items():
+            warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+        warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+        logger.info_once(warning_msg)
 
     for old_key, new_key in zip(old_keys, new_keys):
         state_dict[new_key] = state_dict.pop(old_key)
@@ -818,6 +830,7 @@ def _load_state_dict_into_meta_model(
     is_safetensors=False,
     keep_in_fp32_modules=None,
     unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
+    pretrained_model_name_or_path=None,  # for flagging the user when the model contains renamed keys
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -840,18 +853,30 @@ def _load_state_dict_into_meta_model(
     old_keys = []
     new_keys = []
+    renamed_gamma = {}
+    renamed_beta = {}
     is_quantized = hf_quantizer is not None
+    warning_msg = f"This model {type(model)}"
     for key in state_dict.keys():
         new_key = None
         if "gamma" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
+            # We add only the first key as an example
             new_key = key.replace("gamma", "weight")
+            renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
         if "beta" in key:
-            logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
+            # We add only the first key as an example
             new_key = key.replace("beta", "bias")
+            renamed_beta[key] = new_key if not renamed_beta else renamed_beta
         if new_key:
             old_keys.append(key)
             new_keys.append(new_key)
+    renamed_keys = {**renamed_gamma, **renamed_beta}
+    if renamed_keys:
+        warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
+        for old_key, new_key in renamed_keys.items():
+            warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+        warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+        logger.info_once(warning_msg)
 
     for old_key, new_key in zip(old_keys, new_keys):
         state_dict[new_key] = state_dict.pop(old_key)
 
@@ -4534,7 +4559,12 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=Fal
 
     @staticmethod
     def _load_pretrained_model_low_mem(
-        model, loaded_state_dict_keys, resolved_archive_file, start_prefix="", hf_quantizer=None
+        model,
+        loaded_state_dict_keys,
+        resolved_archive_file,
+        start_prefix="",
+        hf_quantizer=None,
+        pretrained_model_name_or_path=None,
     ):
         """
         This is an experimental function that loads the model using ~1.x model size CPU memory
diff --git a/src/transformers/utils/logging.py b/src/transformers/utils/logging.py
index f2fbe393f724..a304e9d29f46 100644
--- a/src/transformers/utils/logging.py
+++ b/src/transformers/utils/logging.py
@@ -331,6 +331,21 @@ def warning_once(self, *args, **kwargs):
 logging.Logger.warning_once = warning_once
 
 
+@functools.lru_cache(None)
+def info_once(self, *args, **kwargs):
+    """
+    This method is identical to `logger.info()`, but will emit the info with the same message only once
+
+    Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
+    The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
+    another type of cache that includes the caller frame information in the hashing function.
+ """ + self.info(*args, **kwargs) + + +logging.Logger.info_once = info_once + + class EmptyTqdm: """Dummy tqdm which doesn't do anything.""" diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 71c72f9212e5..521624f992e6 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -1640,17 +1640,18 @@ def forward(self): logger = logging.get_logger("transformers.modeling_utils") config = PretrainedConfig() - warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally" + warning_msg_gamma = "`gamma_param` -> `weight_param`" model = TestModelGamma(config) with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) - with LoggingLevel(logging.WARNING): + with LoggingLevel(logging.INFO): with CaptureLogger(logger) as cl1: _, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True) missing_keys = loading_info["missing_keys"] unexpected_keys = loading_info["unexpected_keys"] + self.assertIn("`TestModelGamma`", cl1.out) self.assertIn(warning_msg_gamma, cl1.out) self.assertIn("gamma_param", missing_keys) self.assertIn("weight_param", unexpected_keys) @@ -1664,17 +1665,18 @@ def __init__(self, config): def forward(self): return self.beta_param.sum() - warning_msg_beta = "A parameter name that contains `beta` will be renamed internally" + warning_msg_beta = "`beta_param` -> `bias_param`" model = TestModelBeta(config) with tempfile.TemporaryDirectory() as tmp_dir: model.save_pretrained(tmp_dir) - with LoggingLevel(logging.WARNING): + with LoggingLevel(logging.INFO): with CaptureLogger(logger) as cl2: _, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True) missing_keys = loading_info["missing_keys"] unexpected_keys = loading_info["unexpected_keys"] + self.assertIn("`TestModelBeta`", cl2.out) self.assertIn(warning_msg_beta, cl2.out) self.assertIn("beta_param", missing_keys) self.assertIn("bias_param", unexpected_keys)