From fb32b553fd27fa0ce37078e03937c3d39bbab5b9 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Thu, 24 Oct 2024 18:42:03 +0200
Subject: [PATCH] Correct the new defaults (#34377)

* Correct the new defaults

* CIs

* add check

* Update utils.py

* Update utils.py

* Add the max_length in generate test checking shape without passing length

* style

* CIs

* fix fx CI issue
---
 src/transformers/generation/utils.py                 | 5 ++++-
 .../encoder_decoder/test_modeling_encoder_decoder.py | 4 +++-
 .../test_modeling_speech_encoder_decoder.py          | 4 +++-
 .../test_modeling_vision_encoder_decoder.py          | 7 ++++++-
 4 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 3938457155d83f..efe953db051cb3 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1440,8 +1440,11 @@ def _prepare_generated_length(
             and not self.config.is_encoder_decoder
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
-        else:  # by default let's always generate 10 new tokens
+        elif has_default_max_length:  # by default let's always generate 20 new tokens
             generation_config.max_length = generation_config.max_length + input_ids_length
+            max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
+            if max_position_embeddings is not None:
+                generation_config.max_length = min(generation_config.max_length, max_position_embeddings)
 
         # same for min length
         if generation_config.min_new_tokens is not None:
diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
index 0ee4b75ed803e3..64ebedcb45984b 100644
--- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
@@ -488,7 +488,9 @@ def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config
 
         # Bert does not have a bos token id, so use pad_token_id instead
         generated_output = enc_dec_model.generate(
-            input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
+            input_ids,
+            decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
+            max_length=decoder_config.max_length,
         )
         self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
 
diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
index 6e0b7fa9782fbc..7dcb7c406ae287 100644
--- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
+++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
@@ -362,7 +362,9 @@ def check_encoder_decoder_model_generate(
 
         # Bert does not have a bos token id, so use pad_token_id instead
         generated_output = enc_dec_model.generate(
-            inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
+            inputs,
+            decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
+            max_length=decoder_config.max_length,
         )
         self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
 
diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
index 7def8a9ac96507..77e2a19fea4861 100644
--- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -306,7 +306,9 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
 
         # Bert does not have a bos token id, so use pad_token_id instead
         generated_output = enc_dec_model.generate(
-            inputs, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
+            inputs,
+            decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id,
+            max_length=decoder_config.max_length,
         )
         self.assertEqual(generated_output.shape, (inputs.shape[0],) + (decoder_config.max_length,))
 
@@ -873,6 +875,7 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
         generated_output = enc_dec_model.generate(
             pixel_values=pixel_values,
             decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
+            max_length=decoder_config.max_length,
             **kwargs,
         )
         self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@@ -990,6 +993,7 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
         generated_output = enc_dec_model.generate(
             pixel_values=pixel_values,
             decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
+            max_length=decoder_config.max_length,
             **kwargs,
         )
         self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
@@ -1107,6 +1111,7 @@ def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_val
         generated_output = enc_dec_model.generate(
             pixel_values=pixel_values,
             decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id,
+            max_length=decoder_config.max_length,
             **kwargs,
         )
         self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,))
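In effect, the patched branch now resolves the default length as prompt length plus the default budget of new tokens, clamped to the model's maximum number of positions. The standalone sketch below is not the library code itself; it only mirrors the elif has_default_max_length branch of _prepare_generated_length, and the helper name and signature are made up for illustration.

# Hypothetical helper, for illustration only: it re-implements the patched
# default-length branch outside the library.
def resolve_max_length(
    default_max_length: int,       # GenerationConfig default, e.g. 20
    input_ids_length: int,         # number of prompt tokens passed to generate()
    max_position_embeddings,       # int from the model config, or None if absent
    has_default_max_length: bool,  # True when the user set neither max_length nor max_new_tokens
) -> int:
    max_length = default_max_length
    if has_default_max_length:
        # By default, budget `default_max_length` new tokens on top of the prompt...
        max_length = default_max_length + input_ids_length
        # ...but never exceed the model's maximum number of positions.
        if max_position_embeddings is not None:
            max_length = min(max_length, max_position_embeddings)
    return max_length

# A 15-token prompt with the default of 20 resolves to 35 when the model has
# 512 positions, and is clamped to 30 when it only has 30.
assert resolve_max_length(20, 15, 512, True) == 35
assert resolve_max_length(20, 15, 30, True) == 30

This clamping is also why the encoder-decoder tests above now pass max_length=decoder_config.max_length explicitly: without it, the generated shape would depend on the new default rather than on the decoder config's max_length.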