From 0570c4b3f85bc38ecfa3a4f850cc32b14fe2f188 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Fri, 6 Sep 2024 15:39:49 +0330 Subject: [PATCH] Add validation for maximum sequence length in modeling_whisper.py (#33196) * Add validation for maximum sequence length in modeling_whisper.py Added a validation check to ensure that the sequence length of labels does not exceed the maximum allowed length of 448 tokens. If the sequence length exceeds this limit, a ValueError is raised with a descriptive error message. This change prevents the model from encountering errors or unexpected behavior due to excessively long sequences during training or fine-tuning, ensuring consistent input dimensions and improving overall robustness. * Change exception message in src/transformers/models/whisper/modeling_whisper.py The exception message is for whisper's label's sequence max length. Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * Change 448 to config.max_target_positions in src/transformers/models/whisper/modeling_whisper.py It's for whisper's config.max_target_positions. Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> * Change method's documentation in src/transformers/models/whisper/modeling_whisper.py * Add test for maximum label's sequence length in test_modeling_whisper.py * Add self to modeling_whisper.py * Update test_modeling_whisper.py with respect to automatic validations * Update modeling_whisper.py with respect to ci/circleci: check_code_quality * Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality * Update test_modeling_whisper.py with respect to ci/circleci: tests_generate * Update test_modeling_whisper.py with respect to ci/circleci: tests_generate * Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality * Separate test_labels_sequence_max_length tests in test_modeling_whisper.py * Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality * Remove assert from test_modeling_whisper.py * Add max_target_positions to WhisperModelTester in test_modeling_whisper.py * Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality * Update test_modeling_whisper.py with respect to ci/circleci: tests_generate * Update test_modeling_whisper.py * Change test_labels_sequence_max_length_error_after_changing_config in test_modeling_whisper.py * Change self.config.max_target_positions to self.max_target_positions modeling_whisper.py * Add new tests in test_modeling_whisper.py * Update test_modeling_whisper.py --------- Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> --- .../models/whisper/modeling_whisper.py | 7 ++- tests/models/whisper/test_modeling_whisper.py | 57 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 81f60edbfa98d8..b82b978e5e6d95 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1671,6 +1671,7 @@ def __init__(self, config: WhisperConfig): super().__init__(config) self.model = WhisperModel(config) self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False) + self.max_target_positions = config.max_target_positions # Initialize weights and apply final processing self.post_init() @@ -1723,7 +1724,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, 
*optional*): Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is - only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`. Returns: @@ -1751,6 +1752,10 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if labels.shape[1] > self.max_target_positions: + raise ValueError( + f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens." + ) if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 66a930499f73d1..e503937458ce90 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1676,6 +1676,63 @@ def test_flash_attn_2_generate_reuse_cache(self): past_key_values=past_key_values, ) + def test_labels_sequence_max_length_correct(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_generative_model_classes: + input_features = input_dict["input_features"] + + labels_length = config.max_target_positions + labels = torch.ones(1, labels_length, dtype=torch.int64) + + model = model_class(config) + model(input_features=input_features, labels=labels) + + def test_labels_sequence_max_length_correct_after_changing_config(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_generative_model_classes: + input_features = input_dict["input_features"] + + config.max_target_positions += 100 + + labels_length = config.max_target_positions + labels = torch.ones(1, labels_length, dtype=torch.int64) + + model = model_class(config) + model(input_features=input_features, labels=labels) + + def test_labels_sequence_max_length_error(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_generative_model_classes: + input_features = input_dict["input_features"] + + labels_length = config.max_target_positions + 1 + labels = torch.ones(1, labels_length, dtype=torch.int64) + + model = model_class(config) + with self.assertRaises(ValueError): + model(input_features=input_features, labels=labels) + + def test_labels_sequence_max_length_error_after_changing_config(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_generative_model_classes: + model = model_class(config) + input_features = input_dict["input_features"] + + labels_length = config.max_target_positions + 1 + labels = torch.ones(1, labels_length, dtype=torch.int64) + + new_max_length = config.max_target_positions + 100 + model.config.max_length = new_max_length + model.generation_config.max_length = new_max_length + config.max_target_positions = new_max_length + + with self.assertRaises(ValueError): + model(input_features=input_features, labels=labels) + @require_torch @require_torchaudio
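
Editor's note: for reference, below is a minimal sketch of the behaviour this patch introduces. It builds a hypothetical tiny WhisperConfig (all sizes and token ids below are made-up illustration values, not taken from the patch) and shows labels at the limit passing through the loss computation while labels one token past config.max_target_positions now trip the new ValueError.

    import torch
    from transformers import WhisperConfig, WhisperForConditionalGeneration

    # Hypothetical tiny configuration; only max_target_positions matters here,
    # the other sizes are kept small so the model builds quickly.
    config = WhisperConfig(
        vocab_size=100,
        d_model=32,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=64,
        decoder_ffn_dim=64,
        num_mel_bins=8,
        max_source_positions=30,
        max_target_positions=16,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        decoder_start_token_id=3,
    )
    model = WhisperForConditionalGeneration(config)

    # The encoder's convolutions downsample the mel input by a factor of 2,
    # so the expected mel length is 2 * max_source_positions.
    input_features = torch.zeros(1, config.num_mel_bins, 2 * config.max_source_positions)

    # Labels exactly at the limit still go through the usual loss computation.
    labels_ok = torch.ones(1, config.max_target_positions, dtype=torch.long)
    model(input_features=input_features, labels=labels_ok)

    # Labels one token past the limit now fail fast with the new ValueError.
    labels_too_long = torch.ones(1, config.max_target_positions + 1, dtype=torch.long)
    try:
        model(input_features=input_features, labels=labels_too_long)
    except ValueError as err:
        print(err)

As the commit message notes, over-long label sequences previously led to errors or unexpected behavior deeper in the decoder; with this check they instead fail up front with a descriptive message that names the offending length and the configured maximum.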