Skip to content

Commit

Permalink
Reduce the error log when using core models that need their weights renamed, and provide a step forward (#32656)
Browse files Browse the repository at this point in the history

* Fin

* Modify msg

* Finish up nits
  • Loading branch information
muellerzr authored and ArthurZucker committed Aug 16, 2024
1 parent f77b986 commit 234e3c1
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 10 deletions.
42 changes: 36 additions & 6 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
PARAM_RENAME_WARNING = "A parameter name that contains `{}` will be renamed internally to `{}`. Please use a different name to suppress this warning."


if is_accelerate_available():
Expand Down Expand Up @@ -692,17 +691,30 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
new_keys = []
renamed_keys = {}
renamed_gamma = {}
renamed_beta = {}
warning_msg = f"A pretrained model of type `{model_to_load.__class__.__name__}` "
for key in state_dict.keys():
new_key = None
if "gamma" in key:
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
# We add only the first key as an example
new_key = key.replace("gamma", "weight")
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
if "beta" in key:
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
# We add only the first key as an example
new_key = key.replace("beta", "bias")
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
if new_key:
old_keys.append(key)
new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.items():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

Expand Down Expand Up @@ -818,6 +830,7 @@ def _load_state_dict_into_meta_model(
is_safetensors=False,
keep_in_fp32_modules=None,
unexpected_keys=None, # passing `unexpected` for cleanup from quantization items
pretrained_model_name_or_path=None, # for flagging the user when the model contains renamed keys
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
Expand All @@ -840,18 +853,30 @@ def _load_state_dict_into_meta_model(

old_keys = []
new_keys = []
renamed_gamma = {}
renamed_beta = {}
is_quantized = hf_quantizer is not None
warning_msg = f"This model {type(model)}"
for key in state_dict.keys():
new_key = None
if "gamma" in key:
logger.warning(PARAM_RENAME_WARNING.format("gamma", "weight"))
# We add only the first key as an example
new_key = key.replace("gamma", "weight")
renamed_gamma[key] = new_key if not renamed_gamma else renamed_gamma
if "beta" in key:
logger.warning(PARAM_RENAME_WARNING.format("beta", "bias"))
# We add only the first key as an example
new_key = key.replace("beta", "bias")
renamed_beta[key] = new_key if not renamed_beta else renamed_beta
if new_key:
old_keys.append(key)
new_keys.append(new_key)
renamed_keys = {**renamed_gamma, **renamed_beta}
if renamed_keys:
warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
for old_key, new_key in renamed_keys.items():
warning_msg += f"* `{old_key}` -> `{new_key}`\n"
warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
logger.info_once(warning_msg)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

Expand Down Expand Up @@ -4534,7 +4559,12 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=Fal

@staticmethod
def _load_pretrained_model_low_mem(
model, loaded_state_dict_keys, resolved_archive_file, start_prefix="", hf_quantizer=None
model,
loaded_state_dict_keys,
resolved_archive_file,
start_prefix="",
hf_quantizer=None,
pretrained_model_name_or_path=None,
):
"""
This is an experimental function that loads the model using ~1.x model size CPU memory
Expand Down
15 changes: 15 additions & 0 deletions src/transformers/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,21 @@ def warning_once(self, *args, **kwargs):
logging.Logger.warning_once = warning_once


@functools.lru_cache(None)
def info_once(self, *args, **kwargs):
    """
    Log at INFO level exactly like `logger.info()`, but emit each distinct call only once.

    Deduplication comes from `functools.lru_cache`: the cache key is the full set of call arguments (the logger
    instance `self` included), so a repeated call with the same arguments is answered from the cache and never
    reaches `self.info`. The cache is unbounded (`maxsize=None`), and every argument must be hashable.

    Note: The cache is keyed on the arguments only, so two different call sites using the same logger and the same
    arguments share a single cache entry. The assumption here is that all info messages are unique across the code.
    If they aren't, switch to another type of cache that includes the caller frame information in the key.
    """
    self.info(*args, **kwargs)


# Expose the helper as a method on every standard `logging.Logger` instance.
logging.Logger.info_once = info_once


class EmptyTqdm:
"""Dummy tqdm which doesn't do anything."""

Expand Down
10 changes: 6 additions & 4 deletions tests/utils/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,17 +1640,18 @@ def forward(self):

logger = logging.get_logger("transformers.modeling_utils")
config = PretrainedConfig()
warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally"
warning_msg_gamma = "`gamma_param` -> `weight_param`"
model = TestModelGamma(config)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING):
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl1:
_, loading_info = TestModelGamma.from_pretrained(tmp_dir, config=config, output_loading_info=True)

missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn("`TestModelGamma`", cl1.out)
self.assertIn(warning_msg_gamma, cl1.out)
self.assertIn("gamma_param", missing_keys)
self.assertIn("weight_param", unexpected_keys)
Expand All @@ -1664,17 +1665,18 @@ def __init__(self, config):
def forward(self):
return self.beta_param.sum()

warning_msg_beta = "A parameter name that contains `beta` will be renamed internally"
warning_msg_beta = "`beta_param` -> `bias_param`"
model = TestModelBeta(config)

with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
with LoggingLevel(logging.WARNING):
with LoggingLevel(logging.INFO):
with CaptureLogger(logger) as cl2:
_, loading_info = TestModelBeta.from_pretrained(tmp_dir, config=config, output_loading_info=True)

missing_keys = loading_info["missing_keys"]
unexpected_keys = loading_info["unexpected_keys"]
self.assertIn("`TestModelBeta`", cl2.out)
self.assertIn(warning_msg_beta, cl2.out)
self.assertIn("beta_param", missing_keys)
self.assertIn("bias_param", unexpected_keys)
Expand Down

0 comments on commit 234e3c1

Please sign in to comment.