diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 5efa0a5d120d94..918ed847f83d9e 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -140,6 +140,7 @@ def test_generate_with_static_cache(self): def test_generate_from_inputs_embeds_with_static_cache(self): pass + # overwrite because HybridCache has fixed length for key/values def _check_attentions_for_generate( self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1 ): @@ -151,7 +152,7 @@ def _check_attentions_for_generate( for idx, iter_attentions in enumerate(attentions): tgt_len = min_length + idx if not use_cache else 1 - src_len = min_length + idx if not use_cache else max_length # HybridCache has fixed length for key/values + src_len = min_length + idx if not use_cache else max_length expected_shape = ( batch_size * num_beam_groups, @@ -164,6 +165,7 @@ def _check_attentions_for_generate( [layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions) ) + # overwrite because HybridCache has fixed length for key/values def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1): self.assertIsInstance(past_key_values, HybridCache)