diff --git a/examples/pytorch/summarization/run_summarization.py b/examples/pytorch/summarization/run_summarization.py index 05291a85fe7365..c310cbd4f43ea3 100755 --- a/examples/pytorch/summarization/run_summarization.py +++ b/examples/pytorch/summarization/run_summarization.py @@ -353,6 +353,8 @@ def main(): use_auth_token=True if model_args.use_auth_token else None, ) + model.resize_token_embeddings(len(tokenizer)) + if model.config.decoder_start_token_id is None: raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py index 125ab707103929..56503f98ef3766 100755 --- a/examples/pytorch/translation/run_translation.py +++ b/examples/pytorch/translation/run_translation.py @@ -337,6 +337,8 @@ def main(): use_auth_token=True if model_args.use_auth_token else None, ) + model.resize_token_embeddings(len(tokenizer)) + # Set decoder_start_token_id if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): if isinstance(tokenizer, MBartTokenizer):