Allow setting different decoder_start_token_ids for each item in a batch in the generate function. #28763
@dpernes Hi, if you want to specify different decoder_start_token_ids for each element, you can do it by passing a tensor of shape
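A minimal sketch of the behaviour described here, in plain Python. The helper name is hypothetical (not transformers API); it only illustrates the contrast between a single start id shared by the whole batch and one id per batch element:

```python
def normalize_start_ids(decoder_start_token_id, batch_size):
    """Accept either a single start-token id (applied to every example)
    or one id per example. Illustrative helper, not library code."""
    if isinstance(decoder_start_token_id, int):
        # Scalar: broadcast the same start token to the whole batch.
        return [decoder_start_token_id] * batch_size
    if len(decoder_start_token_id) != batch_size:
        raise ValueError("need one start id per batch element")
    return list(decoder_start_token_id)

# A scalar and a per-example list describe the same thing when all ids match:
normalize_start_ids(2, 3)       # → [2, 2, 2]
normalize_start_ids([4, 5], 2)  # → [4, 5]
```

In practice the per-element ids would be wrapped in a tensor before being passed to generate.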
Great, thank you @zucchini-nlp! This behavior is not documented, though:
You may want to change it to something like:
But why isn't this the same as passing
Thanks, I added a PR extending the docs. Regarding your question, there is a subtle difference between them. The general format is
Hi,
Each batch may require a different decoder_start_token_id during training, because each batch has a specific input language and output language. Sometimes the output language is and other times it is .
Hey @tehranixyz, you do not need to specify
Gotcha!
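The maintainer's point about training can be sketched as follows. At train time the decoder inputs are derived by shifting the labels right and prepending the start token, so the start id never has to be handled by generate. The per-example variant below is an illustration under that assumption, not the library's actual implementation (which takes a single scalar start id):

```python
def shift_tokens_right(labels, start_token_ids):
    """Turn label sequences into decoder input sequences: prepend the
    per-example start token and drop the final label position.
    Illustrative per-example variant, not transformers' own helper."""
    return [[start] + row[:-1] for start, row in zip(start_token_ids, labels)]

# Two examples with different start tokens (ids are made up):
shift_tokens_right([[5, 6, 7], [8, 9, 10]], [1, 2])
# → [[1, 5, 6], [2, 8, 9]]
```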
Feature request
@gante
The generate function has a decoder_start_token_id argument that allows the specification of the decoder start token when generating from an encoder-decoder model (e.g. mT5). Currently, decoder_start_token_id must be an integer, which means that the same start token is used for all elements in the batch. I request that you allow the specification of different start tokens for each element of the batch. For this purpose, decoder_start_token_id must be a tensor with shape (batch_size,).

Motivation
Some multilingual encoder-decoder models use the decoder_start_token_id to indicate the target language. Thus, this change would allow generation into multiple target languages in parallel, as illustrated in the code below.

Your contribution
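The snippet the request refers to is not preserved in this capture. A sketch of the kind of call it describes might look like the following; the language names and token ids are made up for illustration, and the actual generate call is only indicated in comments since it needs a loaded multilingual checkpoint:

```python
# Hypothetical mapping from target language to its start-token id
# (these ids are invented for illustration).
LANG_START_ID = {"de": 250003, "fr": 250008}

target_langs = ["de", "fr"]  # one target language per batch element
decoder_start_token_id = [LANG_START_ID[l] for l in target_langs]  # shape (batch_size,)

# With torch and a multilingual model loaded, the requested usage would be:
# outputs = model.generate(
#     **inputs,
#     decoder_start_token_id=torch.tensor(decoder_start_token_id),
# )
```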