Fix max_length criteria when using inputs_embeds #28994
Conversation
Technically fulfils the main request of the GH issue, but I'd like for us to go one step further!
In the test you wrote, we check `self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1] - 1)`. Ideally, the final `-1` shouldn't be there: we initialize `input_ids` with `decoder_start_id`, causing the additional length, and we probably shouldn't. As such, we can add an additional condition in `_prepare_decoder_input_ids_for_generation`: in this specific case, `input_ids` should be empty.
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Oh, I see, added a new fix and checked that creating an empty tensor does not break anything.
Perfect! Thank you for iterating 🤗
Regarding the failing CI: it seems unrelated to this PR, and `main` does not have this failure. Therefore, it will likely be solved by rebasing with `main` and then force-pushing.
Thanks for fixing this! Very nice and clean PR :)
Just some outstanding questions so I can understand what's happening here before approving
tests/generation/test_utils.py
Outdated
@@ -2730,6 +2730,20 @@ def test_max_length_warning_if_different(self):
            **model_kwargs,
        )

    def test_max_length_if_input_embeds(self):
        # PT-only test: TF doesn't have StoppingCriteria
        article = "Hey, are you conscious?"
Can we use a different phrase here? Talking about consciousness with these LLMs isn't ideal
        input_len = input_ids.shape[-1]
        out_gen = model.generate(input_ids=input_ids, max_length=max_length)
        out_gen_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=max_length)
        self.assertEqual(out_gen.shape[-1], input_len + out_gen_embeds.shape[-1])
For my own understanding - why is the returned generation, when passing in `input_ids`, a concatenation of the input and newly generated tokens, but for embeds we only return the newly generated tokens?
The addition of `input_len` here is needed because generation with `inputs_embeds` returns only the newly generated text, while with `input_ids` it returns the whole text, including the prompt. So we are just making sure the lengths of both are equal.
Right, but why is the behaviour different for embeddings and input_ids?
If I understand the question correctly, the lengths here differ because we return the whole text (prompt + new) when the user passes ids. But we cannot recover the prompt text from `inputs_embeds`, so we just return the newly generated part.
As @zucchini-nlp wrote.
There is no mismatch if the user passes both `input_ids` and `inputs_embeds`, as `generate` continues populating `input_ids`. But passing both kinda defeats the point of feeding `inputs_embeds`, which is used mostly for experimental purposes, and thus the shape difference when only `inputs_embeds` is set. Although we can technically recover `input_ids` from `inputs_embeds` (reverse lookup search) in most cases to make the shapes consistent, it's probably not a good use of our engineering time :D
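To make that concrete, here is a minimal usage sketch (the gpt2 checkpoint and the prompt are purely illustrative, not what the PR's test uses): with `input_ids` the returned tensor contains the prompt plus the new tokens, while with only `inputs_embeds` it contains just the newly generated tokens, so the two outputs should differ in length by exactly the prompt length.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("An example prompt", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

out_ids = model.generate(input_ids=input_ids, max_length=20)
out_embeds = model.generate(inputs_embeds=inputs_embeds, max_length=20)

# out_ids includes the prompt, out_embeds holds only the continuation,
# so after this fix both calls consume the same total number of positions
assert out_ids.shape[-1] == input_ids.shape[-1] + out_embeds.shape[-1]
```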
@zucchini-nlp @gante Thanks for the explanation!
@@ -441,6 +441,9 @@ def _maybe_initialize_input_ids_for_generation(
            if isinstance(value, torch.Tensor):
                batch_size = value.shape[0]
                break

        if "inputs_embeds" in model_kwargs:
            return torch.ones((batch_size, 0), dtype=torch.long, device=self.device)
For my own understanding - am I correct in understanding that when using `inputs_embeds` we don't use any initialization, and this is just an empty placeholder?
Yep, when we initialized with size 1 filled with BOS tokens, that ruined `max_length` by one token. We want the final generation to be a continuation of `inputs_embeds` and not start with BOS.
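A tiny sketch of why the empty placeholder is harmless (plain PyTorch, not the library code itself): new token ids concatenate onto a `(batch_size, 0)` tensor exactly as they would onto a real prompt, so the running `input_ids` ends up holding only the generated continuation, with no leading BOS.

```python
import torch

batch_size = 2
# empty placeholder, mirroring the new initialization above
input_ids = torch.ones((batch_size, 0), dtype=torch.long)

# appending the first sampled tokens works the same as with a non-empty prompt
next_tokens = torch.tensor([[5], [7]])
input_ids = torch.cat([input_ids, next_tokens], dim=-1)

print(input_ids.shape)  # torch.Size([2, 1]), no BOS prepended
```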
src/transformers/generation/utils.py
Outdated
@@ -1421,6 +1424,11 @@ def generate(
            )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_length

        # adjust max_length when using `input_embeds` in decoder-only models
Rather than saying what this is doing (we can tell from the code) it would be useful for the comment to explain why we need to do this.
@amyeroberts done for all comments
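For readers of this thread, a hedged sketch of the idea behind that comment (simplified, with illustrative names, not the exact diff): since `input_ids` starts empty when only `inputs_embeds` is passed to a decoder-only model, the max-length stopping criterion would otherwise allow `max_length` new tokens on top of the prompt, so the positions already consumed by the embeddings are subtracted up front.

```python
# simplified sketch, variable names are illustrative
if model_input_name == "inputs_embeds" and not config.is_encoder_decoder:
    # `inputs_embeds` is (batch, prompt_len, hidden); those prompt_len positions
    # are not visible in the (empty) running `input_ids`, so reserve them here
    generation_config.max_length -= inputs_embeds.shape[1]
```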
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looks great - thanks for iterating!
@amyeroberts unrelated CI failures, I believe this can be merged 🤗
@zucchini-nlp Can you try rebasing? Fixes which resolve the currently failing tests should have been merged into main.
@amyeroberts thanks, now it's all green and can be merged
* fix max_length for inputs_embeds
* make style
* Update src/transformers/generation/utils.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* Static Cache: load models with MQA or GQA (#28975)
* fix
* fix tests
* fix tests
* Update src/transformers/generation/utils.py
  Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* more fixes
* make style

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
What does this PR do?
Fixes #28953. StoppingCriteria with max_length behaves differently when provided `input_ids` or `inputs_embeds`; this happens only on decoder-only models. The PR fixes it so that the criteria accounts for the length of `inputs_embeds` when generating.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@gante