From d47966536cd5ac1ed7e140edac65f00f471f656f Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 12 Mar 2024 18:58:12 +0000
Subject: [PATCH] Examples: check `max_position_embeddings` in the translation example (#29600)

check max_position_embeddings
---
 examples/pytorch/translation/run_translation.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/examples/pytorch/translation/run_translation.py b/examples/pytorch/translation/run_translation.py
index 37bb37b3d86180..04a05fc477f46c 100755
--- a/examples/pytorch/translation/run_translation.py
+++ b/examples/pytorch/translation/run_translation.py
@@ -469,6 +469,19 @@ def main():
     source_lang = data_args.source_lang.split("_")[0]
     target_lang = data_args.target_lang.split("_")[0]
 
+    # Check whether the source length fits in the model, if it has absolute positional embeddings
+    if (
+        hasattr(model.config, "max_position_embeddings")
+        and not hasattr(model.config, "relative_attention_max_distance")
+        and model.config.max_position_embeddings < data_args.max_source_length
+    ):
+        raise ValueError(
+            f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
+            f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
+            f" `--max_source_length` to {model.config.max_position_embeddings} or using a model with larger position "
+            "embeddings"
+        )
+
     # Temporarily set max_target_length for training.
     max_target_length = data_args.max_target_length
     padding = "max_length" if data_args.pad_to_max_length else False
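
For context, the guard added above can also be reproduced outside run_translation.py to vet a checkpoint before launching a job. The sketch below is illustrative only: it assumes the `transformers` `AutoConfig` API, and the checkpoint name and length value are placeholders, not part of the patch.

# Minimal standalone sketch of the same check (illustrative, not part of the patch).
from transformers import AutoConfig

max_source_length = 2048  # the value you intend to pass as --max_source_length
config = AutoConfig.from_pretrained("facebook/bart-large-cnn")  # placeholder checkpoint

# Models with relative attention (e.g. T5) expose `relative_attention_max_distance`
# and are not limited by `max_position_embeddings`, so they are skipped, as in the patch.
if (
    hasattr(config, "max_position_embeddings")
    and not hasattr(config, "relative_attention_max_distance")
    and config.max_position_embeddings < max_source_length
):
    print(
        f"max_source_length={max_source_length} exceeds the model's "
        f"{config.max_position_embeddings} position embeddings; "
        "reduce the length or pick another model."
    )

With the placeholder BART checkpoint (1024 absolute position embeddings), this prints the warning; a T5-style checkpoint would pass because it uses relative attention.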