From ca541bd4f4d932f486a4116deba833b4ffaebd15 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 21 Oct 2024 10:00:14 +0200 Subject: [PATCH] Generation tests: don't rely on main input name (#34228) * don't rely on main input name * update --- tests/generation/test_utils.py | 69 ++++++++++++------- .../models/reformer/test_modeling_reformer.py | 6 +- .../test_modeling_speech_to_text.py | 8 --- 3 files changed, 47 insertions(+), 36 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6766fa22b9b8a0..996d95eb80ff9b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -410,7 +410,6 @@ def _contrastive_generate( def test_greedy_generate(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict) @@ -418,7 +417,7 @@ def test_greedy_generate(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) @pytest.mark.generate def test_greedy_generate_dict_outputs(self): @@ -444,7 +443,9 @@ def test_greedy_generate_dict_outputs(self): # Retrocompatibility check self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput) else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue( + output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + ) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput) @@ -478,7 +479,9 @@ def test_greedy_generate_dict_outputs_use_cache(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue( + output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + ) self._check_outputs(output_generate, main_input, model.config, use_cache=True) @@ -486,7 +489,6 @@ def test_greedy_generate_dict_outputs_use_cache(self): def test_sample_generate(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1) @@ -494,7 +496,7 @@ def test_sample_generate(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) @pytest.mark.generate def test_sample_generate_dict_output(self): @@ -521,7 +523,9 @@ def test_sample_generate_dict_output(self): # Retrocompatibility check self.assertIsInstance(output_generate, SampleEncoderDecoderOutput) else: - 
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue( + output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + ) self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, SampleDecoderOnlyOutput) @@ -532,7 +536,6 @@ def test_sample_generate_dict_output(self): def test_beam_search_generate(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() @@ -542,7 +545,7 @@ def test_beam_search_generate(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) @pytest.mark.generate def test_beam_search_generate_dict_output(self): @@ -569,7 +572,9 @@ def test_beam_search_generate_dict_output(self): # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue( + output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + ) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) @@ -609,7 +614,9 @@ def test_beam_search_generate_dict_outputs_use_cache(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue( + output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + ) self._check_outputs( output_generate, @@ -647,7 +654,6 @@ def test_model_parallel_beam_search(self): def test_beam_sample_generate(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() beam_kwargs = self._get_beam_kwargs() @@ -660,7 +666,7 @@ def test_beam_sample_generate(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) # for VLMs inputs embeds won't match input ids unless images are encoded and merged with ids properly # no quick fix available, since obtaining image embeddings step is very model-specific @@ -712,7 +718,9 @@ def test_beam_sample_generate_dict_output(self): # Retrocompatibility check self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue( + output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + ) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # 
Retrocompatibility check self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) @@ -746,7 +754,6 @@ def test_generate_without_input_ids(self): def test_group_beam_search_generate(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() # check `generate()` and `group_beam_search()` are equal @@ -759,7 +766,7 @@ def test_group_beam_search_generate(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) # check `group_beam_search` for higher than 1 `num_return_sequences` num_return_sequences = 2 @@ -772,7 +779,7 @@ def test_group_beam_search_generate(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) @pytest.mark.generate def test_group_beam_search_generate_dict_output(self): @@ -799,7 +806,9 @@ def test_group_beam_search_generate_dict_output(self): # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue( + output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1] + ) self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput) # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) @@ -814,7 +823,6 @@ def test_group_beam_search_generate_dict_output(self): def test_constrained_beam_search_generate(self): for model_class in self.all_generative_model_classes: config, inputs_dict = self.prepare_config_and_inputs_for_generate() - main_input = inputs_dict[model_class.main_input_name] model = model_class(config).to(torch_device).eval() @@ -838,7 +846,7 @@ def test_constrained_beam_search_generate(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) @@ -862,7 +870,7 @@ def test_constrained_beam_search_generate(self): if model.config.is_encoder_decoder: self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1) else: - self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1]) + self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]) for generation_output in output_generate: self._check_sequence_inside_sequence(force_tokens, generation_output) @@ -903,7 +911,9 @@ def test_constrained_beam_search_generate_dict_output(self): # Retrocompatibility check self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) else: - self.assertTrue(output_generate.sequences.shape[-1] == 
self.max_new_tokens + main_input.shape[-1])
+                self.assertTrue(
+                    output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+                )
                 self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
                 # Retrocompatibility check
                 self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
 
@@ -923,7 +933,6 @@ def test_contrastive_generate(self):
                 self.skipTest(reason="Won't fix: old model with different cache format")
 
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()
-            main_input = inputs_dict[model_class.main_input_name]
 
             # NOTE: contrastive search only works with cache on at the moment.
             if not hasattr(config, "use_cache"):
@@ -940,7 +949,7 @@ def test_contrastive_generate(self):
             if model.config.is_encoder_decoder:
                 self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
             else:
-                self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
+                self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
 
     @pytest.mark.generate
     def test_contrastive_generate_dict_outputs_use_cache(self):
@@ -975,7 +984,9 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
             if model.config.is_encoder_decoder:
                 self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
             else:
-                self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
+                self.assertTrue(
+                    output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
+                )
 
             self._check_outputs(output_generate, main_input, model.config, use_cache=True)
 
@@ -2035,8 +2046,14 @@ def test_inherits_generation_mixin(self):
             self.assertTrue("GenerationMixin" in str(model_class.__bases__))
 
     def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
+        # we can infer the batch size from the main input, but the seq length depends on the model type and on whether the input is text/audio/image,
+        # so we infer the actual text seq length from the model tester, the same way as it is done in the `test_modeling_common.py` tests
         batch_size = main_input.shape[0]
-        seq_length = main_input.shape[-1]
+
+        seq_length = getattr(self.model_tester, "seq_length", None)
+        seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
+        seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
+
         config = config.text_config if hasattr(config, "text_config") else config
 
         num_sequences_in_output = batch_size * num_return_sequences
diff --git a/tests/models/reformer/test_modeling_reformer.py b/tests/models/reformer/test_modeling_reformer.py
index 774831791fe5aa..25b28477a145ec 100644
--- a/tests/models/reformer/test_modeling_reformer.py
+++ b/tests/models/reformer/test_modeling_reformer.py
@@ -53,6 +53,7 @@ def __init__(
         parent,
         batch_size=13,
         seq_length=32,
+        text_seq_length=None,
         is_training=True,
         is_decoder=True,
         use_input_mask=True,
@@ -128,6 +129,7 @@ def __init__(
         self.attn_layers = attn_layers
         self.pad_token_id = pad_token_id
         self.hash_seed = hash_seed
+        self.text_seq_length = text_seq_length or seq_length
 
         attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length
         num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after
@@ -608,7 +610,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
     test_sequence_classification_problem_types = True
 
     def setUp(self):
-        self.model_tester = ReformerModelTester(self)
+        self.model_tester = ReformerModelTester(self, text_seq_length=16)
        self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
 
     @slow
@@ -689,7 +691,7 @@ def prepare_config_and_inputs_for_generate(self, *args, **kwargs):
         # decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
         # NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
         original_sequence_length = self.model_tester.seq_length
-        self.model_tester.seq_length = 16
+        self.model_tester.seq_length = self.model_tester.text_seq_length
         test_inputs = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
         self.model_tester.seq_length = original_sequence_length
         return test_inputs
diff --git a/tests/models/speech_to_text/test_modeling_speech_to_text.py b/tests/models/speech_to_text/test_modeling_speech_to_text.py
index 50446d4628af8c..253cda7e49cb14 100644
--- a/tests/models/speech_to_text/test_modeling_speech_to_text.py
+++ b/tests/models/speech_to_text/test_modeling_speech_to_text.py
@@ -618,14 +618,6 @@ def test_resize_embeddings_untied(self):
     def test_generate_without_input_ids(self):
         pass
 
-    def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
-        # In this model, the index of `batch_size` and `sequence_length`` in `main_input` is different: they are the
-        # first two dimensions of the tensor.
-        main_input = main_input[:, :, 0]
-        super()._check_outputs(
-            output, main_input, config, use_cache=use_cache, num_return_sequences=num_return_sequences
-        )
-
     def _create_and_check_torchscript(self, config, inputs_dict):
         if not self.test_torchscript:
             self.skipTest(reason="test_torchscript is set to False")
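
For context, a minimal standalone sketch of the motivation (the `TextTester`/`SpeechTester` classes and the `infer_seq_length` helper are hypothetical stand-ins, not code from this patch): for audio models such as Speech2Text, the main input is a 3-D feature tensor, so its last dimension is the feature size rather than the sequence length, whereas the `getattr` fallback chain added to `_check_outputs` recovers the real length from the model tester.

    import torch

    class TextTester:
        # decoder-only text tester: the generic attribute is enough
        seq_length = 7

    class SpeechTester:
        # audio tester: the encoder sees `seq_length` feature frames
        seq_length = 7
        encoder_seq_length = 7

    def infer_seq_length(model_tester):
        # same fallback chain as the patched `_check_outputs`:
        # later, more specific attributes win when the tester defines them
        seq_length = getattr(model_tester, "seq_length", None)
        seq_length = getattr(model_tester, "encoder_seq_length", seq_length)
        seq_length = getattr(model_tester, "text_seq_length", seq_length)
        return seq_length

    # Speech2Text-style `input_features` are (batch, seq_len, feature_dim),
    # so `shape[-1]` is the feature size, not the sequence length
    input_features = torch.randn(13, 7, 80)
    assert input_features.shape[-1] == 80          # feature dim, wrong as a length
    assert infer_seq_length(SpeechTester()) == 7   # the tester knows the real length
    assert infer_seq_length(TextTester()) == 7

The chain is ordered from generic to specific, so a tester that defines `text_seq_length` (as the Reformer tester now does) takes precedence over the generic `seq_length`.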