Generate: consistently handle special tokens as tensors #30624
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @zucchini-nlp this one is the same as #29788, which you've already reviewed in its early state. Note that because we've merged other PRs first (e.g. removing the decoding functions from the public API), the diff is much smaller 💛
Overall good. I don't think we should warn but rather error out, and maybe update the serialize/deserialize?
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@ArthurZucker addressed your comments :D (let's see this fast CI going brr)
Looks good, thanks for updating.
```diff
  torch.zeros(input_ids.shape[:2], dtype=torch.int64, layout=input_ids.layout, device=input_ids.device)
- + model._get_decoder_start_token_id()
+ + generation_config.decoder_start_token_id
```
do we have to do the `+` and not just use the decoder start token id?
I think the `+` here broadcasts, producing a tensor where every element is `decoder_start_token_id` (as opposed to concatenation).
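For reference, a minimal standalone sketch of the broadcasting behavior being described here (illustrative values only, not the actual `generate` code):

```python
import torch

input_ids = torch.ones((2, 4), dtype=torch.int64)  # stand-in batch of shape (batch, seq)
decoder_start_token_id = torch.tensor(5)

# Adding a 0-dim tensor to a tensor of zeros broadcasts, yielding a
# (batch, seq) tensor filled with the decoder start token id...
decoder_input_ids = (
    torch.zeros(input_ids.shape[:2], dtype=torch.int64, device=input_ids.device)
    + decoder_start_token_id
)
print(decoder_input_ids)  # tensor([[5, 5, 5, 5], [5, 5, 5, 5]])

# ...whereas concatenation would change the sequence length instead:
# torch.cat([decoder_start_token_id.expand(2, 1), input_ids], dim=-1) has shape (2, 5)
```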
Generate: consistently handle special tokens as tensors (#30624)

* tmp commit
* [test_all] mvp
* missing not
* [test_all] final test fixes
* fix musicgen_melody and rag
* [test_all] empty commit
* PR comments
* Update src/transformers/generation/utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
….41.1 Fixes #31. The handling of special tokens in `transformers` was changed in huggingface/transformers#30624 and huggingface/transformers#30746. This updates the XTTS streaming code accordingly.
What does this PR do?
(reopened from #29788; its requirements have since been merged, which made this PR simpler)
To enable `torch.compile` with `generate`, some special token-related operations have to be rewritten into torch operations. That requires special tokens to be tensors instead of integers or lists of integers (see #29374 for a working prototype).

This PR reworks special token usage in `generate` to consistently treat them as tensors, as opposed to e.g. keeping track of `eos_token_id` in both integer and tensor form.

👉 Review suggestion: start by reading `_prepare_special_tokens` and how it fits in `generate`.
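As a rough illustration of the tensor-normalization idea (a minimal sketch only; the helper name and details below are assumptions, not the actual `_prepare_special_tokens` implementation):

```python
import torch

def as_special_token_tensor(token, device="cpu"):
    # Hypothetical helper: normalize an int, list of ints, or tensor special
    # token (e.g. eos_token_id) into a 1D int64 tensor, so that downstream
    # checks are pure torch ops and therefore torch.compile-friendly.
    if token is None:
        return None
    if isinstance(token, torch.Tensor):
        return token.to(device=device, dtype=torch.int64).reshape(-1)
    if isinstance(token, int):
        token = [token]
    return torch.tensor(token, device=device, dtype=torch.int64)

eos = as_special_token_tensor([2, 32000])
last_tokens = torch.tensor([2, 5])      # last generated token per sequence
is_done = torch.isin(last_tokens, eos)  # tensor([ True, False])
```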
Tests run locally:

- `pytest --doctest-modules src/transformers/generation/logits_process.py -vv`
- `pytest --doctest-modules src/transformers/generation/utils.py -vv`
- `RUN_SLOW=1 py.test tests/generation/ -vv`
- `RUN_SLOW=1 py.test tests/test_cache_utils.py -vv`
- `RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv`
- `RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv` -- same failures as in `main`