From 3056095a45d2fbbd8ab818a68dd1a00828c45158 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Thu, 29 Aug 2024 15:21:08 +0330 Subject: [PATCH 01/23] Add validation for maximum sequence length in modeling_whisper.py Added a validation check to ensure that the sequence length of labels does not exceed the maximum allowed length of 448 tokens. If the sequence length exceeds this limit, a ValueError is raised with a descriptive error message. This change prevents the model from encountering errors or unexpected behavior due to excessively long sequences during training or fine-tuning, ensuring consistent input dimensions and improving overall robustness. --- src/transformers/models/whisper/modeling_whisper.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 81f60edbfa98d8..0d20c63e2c51d2 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1721,6 +1721,7 @@ def forward( ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + `sequence_length` should be smaller than or equal to 448 which is the Whisper's decoder's output limitation. Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. @@ -1751,6 +1752,10 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: + if labels.shape[1] > 448: + raise ValueError( + f"Sequence length {labels.shape[1]} exceeds the maximum allowed length of 448 tokens." + ) if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id From 4330cc3d7009fa7dbdb0a8580fdacef6caf1c9aa Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 14:26:31 +0330 Subject: [PATCH 02/23] Change exception message in src/transformers/models/whisper/modeling_whisper.py The exception message is for whisper's label's sequence max length. Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> --- src/transformers/models/whisper/modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 0d20c63e2c51d2..67f447a25a10fe 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1754,7 +1754,7 @@ def forward( if labels is not None: if labels.shape[1] > 448: raise ValueError( - f"Sequence length {labels.shape[1]} exceeds the maximum allowed length of 448 tokens." + f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {config.max_target_positions} tokens." 
) if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( From f12a3c264fdb07a5f308861bf227e743f0a492b9 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 14:36:45 +0330 Subject: [PATCH 03/23] Change 448 to config.max_target_positions in src/transformers/models/whisper/modeling_whisper.py It's for whisper's config.max_target_positions. Co-authored-by: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> --- src/transformers/models/whisper/modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 67f447a25a10fe..5eef06e97924ca 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1752,7 +1752,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if labels.shape[1] > 448: + if labels.shape[1] > config.max_target_positions: raise ValueError( f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {config.max_target_positions} tokens." ) From 264cd5b896971342d36794bb691a7378aa1bbe75 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 16:51:51 +0330 Subject: [PATCH 04/23] Change method's documentation in src/transformers/models/whisper/modeling_whisper.py --- src/transformers/models/whisper/modeling_whisper.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 5eef06e97924ca..4cb35b31ea942c 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1721,10 +1721,9 @@ def forward( ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - `sequence_length` should be smaller than or equal to 448 which is the Whisper's decoder's output limitation. Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is - only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`. 
Returns: From cd4db32887acac4f9e4562f273dab145464e4a7d Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 17:42:20 +0330 Subject: [PATCH 05/23] Add test for maximum label's sequence length in test_modeling_whisper.py --- tests/models/whisper/test_modeling_whisper.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index f3d191b4d3c4c6..6daa71980c0804 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -242,6 +242,7 @@ def __init__( self.decoder_start_token_id = decoder_start_token_id self.num_conv_layers = num_conv_layers self.suppress_tokens = suppress_tokens + self.max_target_positions = 448 def prepare_config_and_inputs(self): input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size) @@ -278,6 +279,7 @@ def get_config(self): encoder_ffn_dim=self.hidden_size, decoder_start_token_id=self.decoder_start_token_id, suppress_tokens=self.suppress_tokens, + max_target_positions=self.max_target_positions, ) def prepare_config_and_inputs_for_common(self): @@ -1676,6 +1678,35 @@ def test_flash_attn_2_generate_reuse_cache(self): past_key_values=past_key_values, ) + def test_labels_sequence_max_length(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_generative_model_classes: + model = model_class(config) + dummy_input_features = torch.ones(1, config.num_mel_bins, 3000, dtype=torch.float32) + + correct_labels_length = 448 + assert correct_labels_length <= config.max_target_positions + dummy_labels = torch.ones(1, config.max_target_positions, dtype=torch.int64) + + # The following model should run without any problem + model(input_features=dummy_input_features, labels=dummy_labels) + + error_labels_length = 449 + assert error_labels_length > config.max_target_positions + dummy_labels = torch.ones(1, config.max_target_positions, dtype=torch.int64) + + with self.assertRaises(ValueError): + model(input_features=dummy_input_features, labels=dummy_labels) + + new_max_length = 500 + assert new_max_length > config.max_target_positions + model.config.max_length = 500 + model.generation_config.max_length = 500 + + with self.assertRaises(ValueError): + model(input_features=dummy_input_features, labels=dummy_labels) + @require_torch @require_torchaudio From 53dea493b977bfbd784d605f619f7cf753aa7e7d Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 17:53:56 +0330 Subject: [PATCH 06/23] Add self to modeling_whisper.py --- src/transformers/models/whisper/modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 4cb35b31ea942c..d3e324942f465b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1751,7 +1751,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if labels.shape[1] > config.max_target_positions: + if labels.shape[1] > self.config.max_target_positions: raise ValueError( f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {config.max_target_positions} tokens." 
) From 933a0584cf95a3040a6541903e0880bf61122700 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 17:57:43 +0330 Subject: [PATCH 07/23] Update test_modeling_whisper.py with respect to automatic validations --- tests/models/whisper/test_modeling_whisper.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 6daa71980c0804..d6ef641f81b049 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -210,7 +210,6 @@ def __init__( attention_probs_dropout_prob=0.1, max_position_embeddings=20, max_source_positions=30, - max_target_positions=40, bos_token_id=98, eos_token_id=98, pad_token_id=0, @@ -235,14 +234,13 @@ def __init__( self.num_mel_bins = num_mel_bins self.max_position_embeddings = max_position_embeddings self.max_source_positions = max_source_positions - self.max_target_positions = max_target_positions + self.max_target_positions = 448 self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id self.decoder_start_token_id = decoder_start_token_id self.num_conv_layers = num_conv_layers self.suppress_tokens = suppress_tokens - self.max_target_positions = 448 def prepare_config_and_inputs(self): input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size) @@ -279,7 +277,6 @@ def get_config(self): encoder_ffn_dim=self.hidden_size, decoder_start_token_id=self.decoder_start_token_id, suppress_tokens=self.suppress_tokens, - max_target_positions=self.max_target_positions, ) def prepare_config_and_inputs_for_common(self): @@ -1680,7 +1677,7 @@ def test_flash_attn_2_generate_reuse_cache(self): def test_labels_sequence_max_length(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - + for model_class in self.all_generative_model_classes: model = model_class(config) dummy_input_features = torch.ones(1, config.num_mel_bins, 3000, dtype=torch.float32) From 2be9b27cd6df410cb4c11fd980a8bfd3affa3d87 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 18:00:55 +0330 Subject: [PATCH 08/23] Update modeling_whisper.py with respect to ci/circleci: check_code_quality --- src/transformers/models/whisper/modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d3e324942f465b..f1c3d57b32a036 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1753,7 +1753,7 @@ def forward( if labels is not None: if labels.shape[1] > self.config.max_target_positions: raise ValueError( - f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {config.max_target_positions} tokens." + f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.config.max_target_positions} tokens." 
) if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( From 9c15e68cd17fcc5a13cc6997448d77af2617d008 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 18:02:08 +0330 Subject: [PATCH 09/23] Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality --- tests/models/whisper/test_modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index d6ef641f81b049..491a1b249b4813 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1675,7 +1675,7 @@ def test_flash_attn_2_generate_reuse_cache(self): past_key_values=past_key_values, ) - def test_labels_sequence_max_length(self): + def test_labels_sequence_max_length(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_generative_model_classes: From 81d704ec5efa0a562dd5de5719f340f278110d85 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 18:08:16 +0330 Subject: [PATCH 10/23] Update test_modeling_whisper.py with respect to ci/circleci: tests_generate --- tests/models/whisper/test_modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 491a1b249b4813..02087c687c7230 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1680,7 +1680,7 @@ def test_labels_sequence_max_length(self): for model_class in self.all_generative_model_classes: model = model_class(config) - dummy_input_features = torch.ones(1, config.num_mel_bins, 3000, dtype=torch.float32) + dummy_input_features = torch.ones(1, config.num_mel_bins, config.seq_length, dtype=torch.float32) correct_labels_length = 448 assert correct_labels_length <= config.max_target_positions From f5ae7ec5dc18df64b08bb264e595f8e0c71e55ab Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 18:20:37 +0330 Subject: [PATCH 11/23] Update test_modeling_whisper.py with respect to ci/circleci: tests_generate --- tests/models/whisper/test_modeling_whisper.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 02087c687c7230..df3c1592dc110d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1606,7 +1606,7 @@ def test_custom_4d_attention_mask(self): def test_generate_output_type(self, return_dict_in_generate): expected_output_type = GenerateEncoderDecoderOutput if return_dict_in_generate else torch.Tensor for model_class in self.all_generative_model_classes: - config, inputs = self.model_tester.prepare_config_and_inputs() + config, inputs_dict = self.model_tester.prepare_config_and_inputs() model = model_class(config).to(torch_device).eval() # short-form generation without fallback @@ -1676,25 +1676,25 @@ def test_flash_attn_2_generate_reuse_cache(self): ) def test_labels_sequence_max_length(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_generative_model_classes: model = model_class(config) - dummy_input_features = 
torch.ones(1, config.num_mel_bins, config.seq_length, dtype=torch.float32) + input_features = input_dict["input_features"] correct_labels_length = 448 assert correct_labels_length <= config.max_target_positions - dummy_labels = torch.ones(1, config.max_target_positions, dtype=torch.int64) + labels = torch.ones(1, config.max_target_positions, dtype=torch.int64) # The following model should run without any problem - model(input_features=dummy_input_features, labels=dummy_labels) + model(input_features=input_features, labels=labels) error_labels_length = 449 assert error_labels_length > config.max_target_positions - dummy_labels = torch.ones(1, config.max_target_positions, dtype=torch.int64) + labels = torch.ones(1, config.max_target_positions, dtype=torch.int64) with self.assertRaises(ValueError): - model(input_features=dummy_input_features, labels=dummy_labels) + model(input_features=input_features, labels=labels) new_max_length = 500 assert new_max_length > config.max_target_positions @@ -1702,7 +1702,7 @@ def test_labels_sequence_max_length(self): model.generation_config.max_length = 500 with self.assertRaises(ValueError): - model(input_features=dummy_input_features, labels=dummy_labels) + model(input_features=input_features, labels=labels) @require_torch From b5866868af8d6516266b83fea16ffb15feb69079 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 18:25:04 +0330 Subject: [PATCH 12/23] Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality --- tests/models/whisper/test_modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index df3c1592dc110d..02898e856fa8b5 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1606,7 +1606,7 @@ def test_custom_4d_attention_mask(self): def test_generate_output_type(self, return_dict_in_generate): expected_output_type = GenerateEncoderDecoderOutput if return_dict_in_generate else torch.Tensor for model_class in self.all_generative_model_classes: - config, inputs_dict = self.model_tester.prepare_config_and_inputs() + config, inputs = self.model_tester.prepare_config_and_inputs() model = model_class(config).to(torch_device).eval() # short-form generation without fallback From 8aa48326bfb5fe006527377b9ab23a1ab3806299 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 18:41:24 +0330 Subject: [PATCH 13/23] Separate test_labels_sequence_max_length tests in test_modeling_whisper.py --- tests/models/whisper/test_modeling_whisper.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 02898e856fa8b5..dc9a18c218d01c 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1675,27 +1675,44 @@ def test_flash_attn_2_generate_reuse_cache(self): past_key_values=past_key_values, ) - def test_labels_sequence_max_length(self): + def test_labels_sequence_max_length_correct(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_generative_model_classes: model = model_class(config) input_features = input_dict["input_features"] + + labels_length = 448 + assert labels_length <= config.max_target_positions + labels = torch.ones(1, labels_length, dtype=torch.int64) - 
correct_labels_length = 448 - assert correct_labels_length <= config.max_target_positions - labels = torch.ones(1, config.max_target_positions, dtype=torch.int64) - - # The following model should run without any problem model(input_features=input_features, labels=labels) - error_labels_length = 449 - assert error_labels_length > config.max_target_positions - labels = torch.ones(1, config.max_target_positions, dtype=torch.int64) + def test_labels_sequence_max_length_error(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_generative_model_classes: + model = model_class(config) + input_features = input_dict["input_features"] + + labels_length = 449 + assert labels_length > config.max_target_positions + labels = torch.ones(1, labels_length, dtype=torch.int64) with self.assertRaises(ValueError): model(input_features=input_features, labels=labels) + def test_labels_sequence_max_length_error_after_changing_config(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_generative_model_classes: + model = model_class(config) + input_features = input_dict["input_features"] + + labels_length = 449 + assert labels_length > config.max_target_positions + labels = torch.ones(1, labels_length, dtype=torch.int64) + new_max_length = 500 assert new_max_length > config.max_target_positions model.config.max_length = 500 From 8bf582f176ffc7a367bfb3dca02f5a68f478fec7 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 18:48:45 +0330 Subject: [PATCH 14/23] Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality --- tests/models/whisper/test_modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index dc9a18c218d01c..fc99a36d86b05d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1681,7 +1681,7 @@ def test_labels_sequence_max_length_correct(self): for model_class in self.all_generative_model_classes: model = model_class(config) input_features = input_dict["input_features"] - + labels_length = 448 assert labels_length <= config.max_target_positions labels = torch.ones(1, labels_length, dtype=torch.int64) From 13baea3f44fcbc9e4da2d298a7798136b391c51f Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 20:10:05 +0330 Subject: [PATCH 15/23] Remove assert from test_modeling_whisper.py --- tests/models/whisper/test_modeling_whisper.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index fc99a36d86b05d..751dea46e40ac3 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1682,8 +1682,7 @@ def test_labels_sequence_max_length_correct(self): model = model_class(config) input_features = input_dict["input_features"] - labels_length = 448 - assert labels_length <= config.max_target_positions + labels_length = config.max_target_positions labels = torch.ones(1, labels_length, dtype=torch.int64) model(input_features=input_features, labels=labels) @@ -1695,8 +1694,7 @@ def test_labels_sequence_max_length_error(self): model = model_class(config) input_features = input_dict["input_features"] - labels_length = 449 - assert labels_length > 
config.max_target_positions + labels_length = config.max_target_positions + 1 labels = torch.ones(1, labels_length, dtype=torch.int64) with self.assertRaises(ValueError): @@ -1709,14 +1707,12 @@ def test_labels_sequence_max_length_error_after_changing_config(self): model = model_class(config) input_features = input_dict["input_features"] - labels_length = 449 - assert labels_length > config.max_target_positions + labels_length = config.max_target_positions + 1 labels = torch.ones(1, labels_length, dtype=torch.int64) - new_max_length = 500 - assert new_max_length > config.max_target_positions - model.config.max_length = 500 - model.generation_config.max_length = 500 + new_max_length = config.max_target_positions + 100 + model.config.max_length = new_max_length + model.generation_config.max_length = new_max_length with self.assertRaises(ValueError): model(input_features=input_features, labels=labels) From be3d4d5c98f2b0f464b622e2518fdad1da2f9c67 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 20:18:08 +0330 Subject: [PATCH 16/23] Add max_target_positions to WhisperModelTester in test_modeling_whisper.py --- tests/models/whisper/test_modeling_whisper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 751dea46e40ac3..26446ae2f70eff 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -218,6 +218,8 @@ def __init__( num_conv_layers=1, suppress_tokens=None, ): + config = self.get_config() + self.parent = parent self.batch_size = batch_size self.seq_length = seq_length @@ -234,7 +236,7 @@ def __init__( self.num_mel_bins = num_mel_bins self.max_position_embeddings = max_position_embeddings self.max_source_positions = max_source_positions - self.max_target_positions = 448 + self.max_target_positions = config.max_target_positions self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id From 6cf4d8d3ecd4985eeb93f7d7fad7dbc5b5bf9fa7 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 20:21:47 +0330 Subject: [PATCH 17/23] Update test_modeling_whisper.py with respect to ci/circleci: check_code_quality --- tests/models/whisper/test_modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 26446ae2f70eff..9ec06c94decfcc 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -219,7 +219,7 @@ def __init__( suppress_tokens=None, ): config = self.get_config() - + self.parent = parent self.batch_size = batch_size self.seq_length = seq_length From 930a933806f1d8d98c687c76c088fa35bba6dce1 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Wed, 4 Sep 2024 20:26:27 +0330 Subject: [PATCH 18/23] Update test_modeling_whisper.py with respect to ci/circleci: tests_generate --- tests/models/whisper/test_modeling_whisper.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 9ec06c94decfcc..04708b7ea6a904 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -218,8 +218,6 @@ def __init__( num_conv_layers=1, suppress_tokens=None, ): - config = self.get_config() - self.parent = parent 
self.batch_size = batch_size self.seq_length = seq_length @@ -236,7 +234,6 @@ def __init__( self.num_mel_bins = num_mel_bins self.max_position_embeddings = max_position_embeddings self.max_source_positions = max_source_positions - self.max_target_positions = config.max_target_positions self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id @@ -244,6 +241,9 @@ def __init__( self.num_conv_layers = num_conv_layers self.suppress_tokens = suppress_tokens + config = self.get_config() + self.max_target_positions = config.max_target_positions + def prepare_config_and_inputs(self): input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size) @@ -271,7 +271,6 @@ def get_config(self): attention_dropout=self.attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, max_source_positions=self.max_source_positions, - max_target_positions=self.max_target_positions, eos_token_id=self.eos_token_id, bos_token_id=self.bos_token_id, pad_token_id=self.pad_token_id, From fb1cd5e522f4cda982d17c1ade527dea28f46a0b Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Thu, 5 Sep 2024 00:02:59 +0330 Subject: [PATCH 19/23] Update test_modeling_whisper.py --- tests/models/whisper/test_modeling_whisper.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 04708b7ea6a904..41be1fa41d56c3 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -210,6 +210,7 @@ def __init__( attention_probs_dropout_prob=0.1, max_position_embeddings=20, max_source_positions=30, + max_target_positions=40, bos_token_id=98, eos_token_id=98, pad_token_id=0, @@ -234,6 +235,7 @@ def __init__( self.num_mel_bins = num_mel_bins self.max_position_embeddings = max_position_embeddings self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id @@ -241,9 +243,6 @@ def __init__( self.num_conv_layers = num_conv_layers self.suppress_tokens = suppress_tokens - config = self.get_config() - self.max_target_positions = config.max_target_positions - def prepare_config_and_inputs(self): input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size) @@ -271,6 +270,7 @@ def get_config(self): attention_dropout=self.attention_probs_dropout_prob, max_position_embeddings=self.max_position_embeddings, max_source_positions=self.max_source_positions, + max_target_positions=self.max_target_positions, eos_token_id=self.eos_token_id, bos_token_id=self.bos_token_id, pad_token_id=self.pad_token_id, @@ -1714,6 +1714,7 @@ def test_labels_sequence_max_length_error_after_changing_config(self): new_max_length = config.max_target_positions + 100 model.config.max_length = new_max_length model.generation_config.max_length = new_max_length + config.max_target_positions = new_max_length with self.assertRaises(ValueError): model(input_features=input_features, labels=labels) From 88cd92fef6529f3051a0d02afaa351309824c1b8 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Thu, 5 Sep 2024 00:18:55 +0330 Subject: [PATCH 20/23] Change test_labels_sequence_max_length_error_after_changing_config in test_modeling_whisper.py --- tests/models/whisper/test_modeling_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 41be1fa41d56c3..8b7db04c598bbb 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1703,6 +1703,7 @@ def test_labels_sequence_max_length_error(self): def test_labels_sequence_max_length_error_after_changing_config(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.max_target_positions = 500 for model_class in self.all_generative_model_classes: model = model_class(config) @@ -1714,7 +1715,6 @@ def test_labels_sequence_max_length_error_after_changing_config(self): new_max_length = config.max_target_positions + 100 model.config.max_length = new_max_length model.generation_config.max_length = new_max_length - config.max_target_positions = new_max_length with self.assertRaises(ValueError): model(input_features=input_features, labels=labels) From 4b677d547f1d77a517125caa3e178882ff4fe306 Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Thu, 5 Sep 2024 00:24:19 +0330 Subject: [PATCH 21/23] Change self.config.max_target_positions to self.max_target_positions modeling_whisper.py --- src/transformers/models/whisper/modeling_whisper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f1c3d57b32a036..b82b978e5e6d95 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1671,6 +1671,7 @@ def __init__(self, config: WhisperConfig): super().__init__(config) self.model = WhisperModel(config) self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False) + self.max_target_positions = config.max_target_positions # Initialize weights and apply final processing self.post_init() @@ -1751,9 +1752,9 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if labels.shape[1] > self.config.max_target_positions: + if labels.shape[1] > self.max_target_positions: raise ValueError( - f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.config.max_target_positions} tokens." + f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens." 
) if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( From bab1098453582eb6328f00c493b9027c9a4bde1d Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Thu, 5 Sep 2024 00:32:51 +0330 Subject: [PATCH 22/23] Add new tests in test_modeling_whisper.py --- tests/models/whisper/test_modeling_whisper.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 8b7db04c598bbb..d5d8577592428d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1680,30 +1680,43 @@ def test_labels_sequence_max_length_correct(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_generative_model_classes: + input_features = input_dict["input_features"] + + labels_length = config.max_target_positions + labels = torch.ones(1, labels_length, dtype=torch.int64) + model = model_class(config) + model(input_features=input_features, labels=labels) + + def test_labels_sequence_max_length_correct_after_changing_config(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_generative_model_classes: input_features = input_dict["input_features"] + config.max_target_positions += 100 + labels_length = config.max_target_positions labels = torch.ones(1, labels_length, dtype=torch.int64) + model = model_class(config) model(input_features=input_features, labels=labels) def test_labels_sequence_max_length_error(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_generative_model_classes: - model = model_class(config) input_features = input_dict["input_features"] labels_length = config.max_target_positions + 1 labels = torch.ones(1, labels_length, dtype=torch.int64) + model = model_class(config) with self.assertRaises(ValueError): model(input_features=input_features, labels=labels) def test_labels_sequence_max_length_error_after_changing_config(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.max_target_positions = 500 for model_class in self.all_generative_model_classes: model = model_class(config) From de273c03a8e2d1eba5662334c3c358befc75326f Mon Sep 17 00:00:00 2001 From: Amir Mohammad Fakhimi Date: Thu, 5 Sep 2024 00:40:55 +0330 Subject: [PATCH 23/23] Update test_modeling_whisper.py --- tests/models/whisper/test_modeling_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index d5d8577592428d..de5395a4822a3b 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1728,6 +1728,7 @@ def test_labels_sequence_max_length_error_after_changing_config(self): new_max_length = config.max_target_positions + 100 model.config.max_length = new_max_length model.generation_config.max_length = new_max_length + config.max_target_positions = new_max_length with self.assertRaises(ValueError): model(input_features=input_features, labels=labels)
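
---

The behavior introduced by this patch series can be exercised outside the test suite as well. Below is a minimal sketch (not part of the patches above) of how the new check in `WhisperForConditionalGeneration.forward` is expected to behave: labels no longer than `config.max_target_positions` pass through, while longer labels raise a `ValueError` up front instead of failing later inside the decoder's positional embeddings. The randomly initialized model, tensor shapes, and printed message are illustrative assumptions, not output reproduced from CI.

```python
# Minimal sketch (illustrative, not part of the patches above) exercising the
# labels-length validation added to WhisperForConditionalGeneration.forward.
# A randomly initialized model with the default config is assumed here.
import torch
from transformers import WhisperConfig, WhisperForConditionalGeneration

config = WhisperConfig()  # default config: max_target_positions == 448
model = WhisperForConditionalGeneration(config)

# Log-mel input features: (batch, num_mel_bins, frames); 3000 frames ~= 30 s of audio.
input_features = torch.zeros(1, config.num_mel_bins, 3000)

# Labels exactly at the limit: the forward pass should run without error.
labels_ok = torch.ones(1, config.max_target_positions, dtype=torch.long)
model(input_features=input_features, labels=labels_ok)

# Labels one token over the limit: the new check raises a ValueError early.
labels_too_long = torch.ones(1, config.max_target_positions + 1, dtype=torch.long)
try:
    model(input_features=input_features, labels=labels_too_long)
except ValueError as err:
    # Expected (per the final patch), e.g.:
    # "Labels' sequence length 449 cannot exceed the maximum allowed length of 448 tokens."
    print(err)
```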