diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py
index 3a36b41f9eb925..fad7e865b9d6b9 100644
--- a/src/transformers/modeling_tf_pytorch_utils.py
+++ b/src/transformers/modeling_tf_pytorch_utils.py
@@ -380,6 +380,16 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
     missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
     missing_keys += missing_keys_pt
 
+    # Some models may have keys that are not in the state by design, removing them before needlessly warning
+    # the user.
+    if pt_model._keys_to_ignore_on_load_missing is not None:
+        for pat in pt_model._keys_to_ignore_on_load_missing:
+            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+
+    if pt_model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in pt_model._keys_to_ignore_on_load_unexpected:
+            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
     if len(unexpected_keys) > 0:
         logger.warning(
             f"Some weights of the TF 2.0 model were not used when "
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 881b17b2d8760b..a7c4ba1277e6a2 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -802,7 +802,7 @@ def custom_forward(*inputs):
     GPT2_START_DOCSTRING,
 )
 class GPT2LMHeadModel(GPT2PreTrainedModel):
-    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
 
     def __init__(self, config):
         super().__init__(config)
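
For context, a minimal sketch of the filtering logic introduced in the first hunk, run outside the model loading code: each pattern in _keys_to_ignore_on_load_missing is applied with re.search to prune the list of keys before the warning is emitted. The key names below are made-up examples, not taken from a real checkpoint; the patterns are the GPT-2 ones from the second hunk.

# Minimal sketch of the key-filtering behaviour added above; the keys are
# illustrative examples, not taken from an actual state dict.
import re

missing_keys = ["transformer.h.0.attn.masked_bias", "lm_head.weight", "transformer.wte.weight"]
keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

for pat in keys_to_ignore_on_load_missing:
    # Any key matching an ignore pattern is dropped so it never triggers a warning.
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

print(missing_keys)  # ["transformer.wte.weight"]

Note that the patterns are treated as regular expressions by re.search, so the unescaped dots in the new GPT-2 patterns match any single character rather than a literal ".", which still matches the intended keys.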