Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zucchini-nlp committed Aug 30, 2024
1 parent df3c512 commit 2b47ea9
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tests/models/gemma2/test_modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def test_generate_with_static_cache(self):
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

# overwrite because HybridCache has fixed length for key/values
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
):
Expand All @@ -151,7 +152,7 @@ def _check_attentions_for_generate(

for idx, iter_attentions in enumerate(attentions):
tgt_len = min_length + idx if not use_cache else 1
src_len = min_length + idx if not use_cache else max_length # HybridCache has fixed length for key/values
src_len = min_length + idx if not use_cache else max_length

expected_shape = (
batch_size * num_beam_groups,
Expand All @@ -164,6 +165,7 @@ def _check_attentions_for_generate(
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
)

# overwrite because HybridCache has fixed length for key/values
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
self.assertIsInstance(past_key_values, HybridCache)

Expand Down

0 comments on commit 2b47ea9

Please sign in to comment.