Skip to content

Commit

Permalink
Fix GPT-2 warnings (#11213)
Browse files Browse the repository at this point in the history
* Fix GPT-2 warnings

* Update src/transformers/models/gpt2/modeling_gpt2.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
  • Loading branch information
LysandreJik and stas00 authored Apr 13, 2021
1 parent 0cd89d8 commit 823df93
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
10 changes: 10 additions & 0 deletions src/transformers/modeling_tf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,16 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
missing_keys += missing_keys_pt

# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if pt_model._keys_to_ignore_on_load_missing is not None:
for pat in pt_model._keys_to_ignore_on_load_missing:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

if pt_model._keys_to_ignore_on_load_unexpected is not None:
for pat in pt_model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the TF 2.0 model were not used when "
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ def custom_forward(*inputs):
GPT2_START_DOCSTRING,
)
class GPT2LMHeadModel(GPT2PreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

def __init__(self, config):
super().__init__(config)
Expand Down

0 comments on commit 823df93

Please sign in to comment.