Commit
Merge branch 'main' into spqr-quantizer
MekkCyber authored Feb 13, 2025
2 parents 8da4a66 + 636ee57 commit afff70e
Showing 2 changed files with 5 additions and 6 deletions.
1 change: 0 additions & 1 deletion src/transformers/generation/utils.py
@@ -1470,7 +1470,6 @@ def _prepare_generated_length(
         elif (
             model_input_name == "inputs_embeds"
             and input_ids_length != inputs_tensor.shape[1]
-            and input_ids_length != 0
             and not self.config.is_encoder_decoder
         ):
             generation_config.max_length -= inputs_tensor.shape[1]
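For context, the guard `and input_ids_length != 0` previously skipped the `max_length` adjustment when generation started purely from `inputs_embeds` (no `input_ids`, so `input_ids_length == 0`). Below is a minimal sketch of the adjusted branch; the lengths are hypothetical, not values taken from this commit.

# Hypothetical values for illustration only.
model_input_name = "inputs_embeds"
input_ids_length = 0        # generation starts purely from inputs_embeds
inputs_embeds_length = 8    # stands in for inputs_tensor.shape[1]
is_encoder_decoder = False
max_length = 30             # user-requested total length

# With `and input_ids_length != 0` removed, the adjustment also fires when
# input_ids_length == 0, so max_length budgets for the prompt embeds as well.
if (
    model_input_name == "inputs_embeds"
    and input_ids_length != inputs_embeds_length
    and not is_encoder_decoder
):
    max_length -= inputs_embeds_length

print(max_length)  # 22 -> up to 22 new tokens; 8 prompt embeds + 22 tokens == 30 total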
10 changes: 5 additions & 5 deletions tests/generation/test_utils.py
@@ -1786,12 +1786,12 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
         model.config.use_cache = True
         model.config.is_decoder = True
         batch_size = input_ids.shape[0]
-        max_cache_len = 30
+        max_length = 30
 
         # here we force to not stop at eos and go until max-length
         model.generation_config.eos_token_id = model.config.get_text_config().eos_token_id = -1
         generation_kwargs = {
-            "max_length": max_cache_len,
+            "max_length": max_length,
             "cache_implementation": "static",
             "return_dict_in_generate": True,  # Required to return `past_key_values`
         }
@@ -1810,11 +1810,11 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
         num_hidden_layers = text_config.num_hidden_layers
 
         inputs_embeds = model.get_input_embeddings()(input_ids)
-        max_cache_len += inputs_embeds.shape[1] - 1  # the last generated token has no cache
         outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
 
-        # we should get `max_length` in shape, not `max_length - embeds_length`
-        cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
+        # we should get `max_length - 1` in shape, not `max_length - embeds_length`.
+        # -1 because the last generated token isn't yet in the cache.
+        cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
         self.assertTrue(isinstance(outputs.past_key_values, StaticCache))
         self.assertTrue(len(outputs.past_key_values.key_cache) == num_hidden_layers)
         self.assertTrue(outputs.past_key_values.key_cache[0].shape == cache_shape)
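As a quick arithmetic check of the new assertion, here is a sketch with hypothetical dimensions; in the test they come from the model's config. Before the utils.py fix, the test compensated with `max_cache_len += inputs_embeds.shape[1] - 1`, so the expected sequence dimension grew with the prompt-embeds length; now it is simply `max_length - 1`.

# Hypothetical dimensions for illustration only; the test reads them from the model config.
batch_size = 2
num_key_value_heads = 4
head_dim = 64
max_length = 30

# Per the test's comment, the last generated token isn't yet in the cache when
# generation stops, so the expected sequence dimension is max_length - 1.
cache_shape = (batch_size, num_key_value_heads, max_length - 1, head_dim)
print(cache_shape)  # (2, 4, 29, 64)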
