Skip to content

Commit

Permalink
Fix max_length criteria when using inputs_embeds (#28994)
Browse files Browse the repository at this point in the history
* fix max_length for inputs_embeds

* make style

* Update src/transformers/generation/utils.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Static Cache: load models with MQA or GQA (#28975)

* fix

* fix tests

* fix tests

* Update src/transformers/generation/utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* more fixes

* make style

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
3 people authored and Ita Zaporozhets committed May 14, 2024
1 parent fcd7cd2 commit 7db7bf6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
11 changes: 11 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def _maybe_initialize_input_ids_for_generation(
if isinstance(value, torch.Tensor):
batch_size = value.shape[0]
break

if "inputs_embeds" in model_kwargs:
return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

def _prepare_attention_mask_for_generation(
Expand Down Expand Up @@ -1421,6 +1424,14 @@ def generate(
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length

# otherwise the total length [inputs-embeds-len + new-tokens-len] will go beyond indicated `max_length``
elif (
model_input_name == "inputs_embeds"
and inputs_tensor.shape[:-1] != input_ids.shape
and not self.config.is_encoder_decoder
):
generation_config.max_length -= inputs_tensor.shape[1]

# if we don't pass `past_key_values` and a cache_implementation is specified
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get(
"past_key_values", False
Expand Down
16 changes: 15 additions & 1 deletion tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,7 +1963,7 @@ def test_generate_from_inputs_embeds_decoder_only(self):
)
self.assertListEqual(
outputs_from_embeds[:, inputs_embeds.shape[1] :].tolist(),
outputs_from_embeds_wo_ids[:, 1:].tolist(),
outputs_from_embeds_wo_ids.tolist(),
)

def test_generate_continue_from_past_key_values(self):
Expand Down Expand Up @@ -2730,6 +2730,20 @@ def test_max_length_warning_if_different(self):
**model_kwargs,
)

def test_max_length_if_input_embeds(self):
# PT-only test: TF doesn't have StoppingCriteria
article = "Today a dragon flew over Paris."
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
inputs_embeds = model.get_input_embeddings()(input_ids)

max_length = 20
input_len = input_ids.shape[-1]
out_gen = model.generate(input_ids=input_ids, max_length=max_length)
out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])

def test_custom_stopping_criteria_overload_error(self):
# PT-only test: TF doesn't have StoppingCriteria
article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
Expand Down

0 comments on commit 7db7bf6

Please sign in to comment.