Move eos_token_id to stopping criteria #29459
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looking good 💪 A few comments to further refine the idea
In the body of `generate`, we set `pad_token_id` to `eos_token_id` when the latter exists and the former is `None`. As such, we can further modify the decoding functions as follows:
- the block

  ```python
  # finished sentences should have their next token be a padding token
  if eos_token_id is not None:
      if pad_token_id is None:
          raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
      next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
  ```
  can become

  ```python
  # finished sentences should have their next token be a padding token
  if pad_token_id is not None:
      next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
  ```

  because the case that raises the exception can never be triggered from `generate` (although we should add an integration test that ensures this remains true, if it doesn't already exist; see the sketch after this list!).
- see my in-code comment :)
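For reference, a minimal integration-test sketch of the first point (the test name and the `hf-internal-testing/tiny-random-gpt2` checkpoint are my assumptions, not part of this PR):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_generate_backfills_pad_token_from_eos():
    """generate() should fall back to eos_token_id as pad_token_id when the
    latter is unset, so the ValueError path above is unreachable from generate()."""
    checkpoint = "hf-internal-testing/tiny-random-gpt2"  # assumed tiny test model
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)

    model.generation_config.pad_token_id = None  # force the fallback path
    assert model.generation_config.eos_token_id is not None

    inputs = tokenizer("Hello", return_tensors="pt")
    # Must not raise: generate() backfills pad_token_id from eos_token_id.
    model.generate(**inputs, max_new_tokens=3)
```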
Ah, the pipeline tests are ignored by default (just like slow tests). You need to add …
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Done for all comments. I checked that users have no possibility of messing up when calling these methods directly, and I re-ran all tests (+slow, +pipeline).
LGTM, thank you for iterating 🙌
Thanks! It would be nice to make sure our assumption is always correct; nice cleanup otherwise!
```python
eos_token_id (`Union[int, List[int]]`):
    The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
```
Would it be nice to also support a list of lists, to have multi-token stop sequences (e.g. if the EOS is `["<", "eos", ">"]`)?
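For illustration, a criterion along these lines is easy to sketch. This is a minimal sketch, not library code: the class name `StopSequenceCriteria` is made up, and it assumes the `StoppingCriteria.__call__(input_ids, scores)` interface that returned a single bool at the time:

```python
import torch
from transformers import StoppingCriteria

class StopSequenceCriteria(StoppingCriteria):
    """Stop once every sequence in the batch ends with `stop_ids` (hypothetical sketch)."""

    def __init__(self, stop_ids):
        self.stop_ids = torch.tensor(stop_ids)

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        n = len(self.stop_ids)
        if input_ids.shape[1] < n:  # not enough tokens generated yet
            return False
        tail = input_ids[:, -n:]  # trailing n tokens of each sequence
        return bool((tail == self.stop_ids.to(input_ids.device)).all())
```

It would then be passed as `stopping_criteria=StoppingCriteriaList([StopSequenceCriteria(ids)])`, where `ids` are the token ids of the stop sequence.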
Hmm, is it really true that existing models have EOS tokens as a sequence? I guess you are referring to custom EOS tokens, when users want to stop at "" but haven't trained the model with "" as a special token. If that's the case, there is a `StopStringsCriteria` PR coming, or users are free to write their own custom criteria. @gante wdyt?
Yeah, the case where we want to stop on a string will be covered by #28932 (and is way more complex)
This PR is exclusively to port the existing EOS logic into its own stopping criteria :)
Right, I forgot it was being covered.
src/transformers/generation/utils.py (outdated)

```python
"`eos_token_id` is deprecated in this function and will be removed in v4.41, use"
" `stopping_criteria=StoppingCriteriaList([EOSTokenCriteria(eos_token_id=eos_token_id)])` instead.",
```
It should rather be set in the `generation_config`, I guess? Why can't we leave this as a kwarg and set `generation_config.eos_token_id`?
That's the general idea: we want to stop accepting `eos_token_id` as an input argument. The warning is for backward compatibility, since we used to give priority to the user-defined `eos_token_id` from the args before checking `generation_config.eos_token_id`.
^ as @zucchini-nlp wrote. This is merely for backward compatibility, if users want to keep calling the decoding methods directly (at their own risk, since we're deprecating their public API).
The decoding methods don't set up new logits processors nor stopping criteria. As such, the correct replacement for the EOS feature is to pass the new `EOSTokenCriteria` :)
The long-term goal is to disable calling the decoding methods directly at all, so we can optimize the codebase.
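Concretely, the replacement for users who keep calling the decoding methods directly would look roughly like the sketch below (assuming the final camel-cased name `EosTokenCriteria` from the "camel case everywhere" commit, and the `transformers.generation` import path):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import EosTokenCriteria, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

# Instead of passing eos_token_id down to the decoding method, wrap it in a criteria list.
criteria = StoppingCriteriaList([EosTokenCriteria(eos_token_id=tokenizer.eos_token_id)])
outputs = model.generate(**inputs, stopping_criteria=criteria, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```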
src/transformers/generation/utils.py (outdated)

```python
    .prod(dim=0)
    .bool()
)
last_assistant_token_is_eos = stopping_criteria[-1](candidate_input_ids, None)
```
This means we assume the last entry in `stopping_criteria` is always `EOSTokenCriteria`. That's not necessarily obvious, and if people pass custom criteria, I am not sure this will always be the case?
right, will fix it
@ArthurZucker good catch, I missed this one too :)
@gante I need you to take a look. I removed some code, given that assisted decoding works only with batch size 1, so we can make some assumptions about when to stop generating.
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
One change in the speculative decoding diff and we're ready to go :)
```python
    .prod(dim=0)
    .bool()
)
is_done_candidate = stopping_criteria(candidate_input_ids, None)
```
🧠 (this is actually much more versatile than the previous version!)
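The versatility comes from calling the whole list, which asks "does any criterion say we are done?" instead of hard-coding that the EOS criterion sits at index -1. A rough restatement of the semantics being relied on (my sketch, assuming the list ORs its members' per-call results):

```python
import torch

def list_call_semantics(criteria, input_ids, scores):
    """Hypothetical restatement of StoppingCriteriaList.__call__:
    OR over all members, so the EOS criterion's position no longer matters."""
    is_done = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
    for criterion in criteria:
        is_done = is_done | criterion(input_ids, scores)
    return is_done
```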
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
@zucchini-nlp needs a rebase with … (@ArthurZucker I'm assuming this is ready to be merged, since all comments were addressed. I'm deciding to merge since this unblocks me on the `torch.compile` front; let us know if you'd like some post-merge changes :) )
@gante Woah, hold on before merging! You still need a core maintainer's approval: even if the comments have been addressed, it's important to make sure the amendments are approved.
@zucchini-nlp this PR needs to be rebased with …
@ArthurZucker ping for approval, if all tasks are complete :)
LGTM! Make sure to run the slow tests for important models before merging, and to test generation with Llama and an `eos_token_id` list!
src/transformers/generation/utils.py (outdated)

```diff
@@ -2364,10 +2370,26 @@ def _greedy_search(
     )
     stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
     pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
     eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
     if eos_token_id is not None:
         warnings.warn(
```
same comment here
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Rebased. No idea why the examples test is failing; it doesn't have anything to do with this PR.
Just a heads-up: this commit breaks transformers generation on Apple Silicon, as `torch.isin` is not implemented for the MPS backend.
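For anyone hitting this, a minimal sketch of a workaround, assuming the failure comes from the `torch.isin` call inside the new EOS criterion (the helper below is my own illustration, not something the library shipped at the time):

```python
import torch

def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    """torch.isin has no MPS kernel on some PyTorch builds; fall back to a
    broadcasted equality check, which MPS does support."""
    if elements.device.type == "mps":
        return elements.unsqueeze(-1).eq(test_elements.flatten()).any(dim=-1)
    return torch.isin(elements, test_elements)
```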
* add eos stopping criteria
* minor fix
* Update tests/generation/test_stopping_criteria.py (Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>)
* check eos is not None and fix tests
* make style and fixup
* Update src/transformers/generation/stopping_criteria.py (Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>)
* Update tests/generation/test_utils.py (Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>)
* Update tests/generation/test_utils.py (Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>)
* Update src/transformers/generation/__init__.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* Update src/transformers/generation/stopping_criteria.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* Update src/transformers/generation/stopping_criteria.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* Update src/transformers/generation/stopping_criteria.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* camel case everywhere
* call stopping criteria list for candidate ids
* make style and fixup
* Empty commit
* Empty commit to pass flaky test
* set max length in PromptLookupCandidateGenerator
* Update src/transformers/generation/utils.py (Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>)
* lets fix this typo in docs
* Update src/transformers/generation/utils.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* Update src/transformers/generation/utils.py (Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>)
* update PR
* empty commit

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
What does this PR do?
This PR is a small step towards `torch.compile` and `generate` compatibility. It moves the `EOS` token check into a stopping criterion, so now we can loop `while stopping_criteria` and get rid of the extra `EOS` checks at the end of each generate method. All the generate tests, including the slow ones, are passing.
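As a toy illustration of the loop shape this enables (everything below is a sketch: the real decoding loops carry far more state, and the per-sequence bool return of the criteria call is an assumption):

```python
import torch

def toy_decode(input_ids, step, stopping_criteria):
    # EOS is folded into `stopping_criteria`, so the loop body needs no separate EOS check.
    while not stopping_criteria(input_ids, None).all():
        next_tokens = step(input_ids)  # hypothetical one-step sampler, shape (batch,)
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    return input_ids

# Tiny demo: an EOS-only criterion (eos id 2) and a sampler that always emits EOS.
criterion = lambda ids, scores: ids[:, -1] == 2
step = lambda ids: torch.full((ids.shape[0],), 2, dtype=torch.long)
print(toy_decode(torch.tensor([[0, 1]]), step, criterion))  # tensor([[0, 1, 2]])
```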
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@gante