Fix `beam_scores` shape when token scores shape changes after `logits_processor` #25980
Conversation
cc @gante
Hi @BakerBunker 👋 You wrote "When token scores shape changes after `logits_processor`". Would you be able to share an example?
Sure. I trained a model with different sizes of input and output embeddings, because the model's output vocabulary is much smaller than its input vocabulary. Since the input and output embeddings account for a large share of the model's parameters, this saves a lot of GPU memory during training. However, during generation I need to align the output scores with the full input vocabulary, which I do with a custom logits processor:

```python
class TokenAlignProcessor(LogitsProcessor):
    def __call__(self, input_ids, scores):
        # Pad the smaller output-vocab scores up to the full tokenizer
        # vocabulary, filling the non-overlapping slots with -inf.
        new_score = torch.empty(
            scores.shape[0], len(tokenizer), device=DEVICE
        ).fill_(-torch.inf)
        new_score[:, -OVERLAP_TOKEN_NUMS:] = scores
        return new_score
```
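To illustrate the effect, here is a small NumPy sketch of the shape change this processor introduces (the vocabulary sizes are made up for the example):

```python
import numpy as np

# Hypothetical sizes: the output head covers only the last
# OVERLAP_TOKEN_NUMS entries of the full tokenizer vocabulary.
FULL_VOCAB = 10
OVERLAP_TOKEN_NUMS = 6

def token_align(scores):
    # Mirror TokenAlignProcessor.__call__ with NumPy: pad the smaller
    # score matrix up to the full vocabulary, filling the rest with -inf.
    new_score = np.full((scores.shape[0], FULL_VOCAB), -np.inf)
    new_score[:, -OVERLAP_TOKEN_NUMS:] = scores
    return new_score

scores = np.zeros((2, OVERLAP_TOKEN_NUMS))  # (batch, output vocab)
aligned = token_align(scores)
print(aligned.shape)  # (2, 10): the processor widened the vocab axis
```

The padded positions stay at `-inf`, so they can never be sampled; only the overlapping tail of the full vocabulary carries real scores.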
@BakerBunker I see, thank you for the explanation! 🤗
Normally we don't accept changes for custom-code use cases but, since this one is a no-op under normal operation (`next_token_scores_processed.shape == next_token_scores.shape`), I'm pro accepting it to also enable your custom case :)

To make CI go green, you need to run `make fixup` in your local `transformers` root folder and commit the changes.

Thank you for the contribution, @BakerBunker 💛
What does this PR do?
When the token scores shape changes after `logits_processor`, `next_token_scores_processed` has a different shape than `beam_scores[:, None].expand_as(next_token_scores)`; this PR fixes that mismatch.

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
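The shape mismatch the PR fixes can be reproduced in isolation with a small NumPy sketch (beam count and vocabulary sizes are hypothetical):

```python
import numpy as np

# A logits processor (like TokenAlignProcessor above) widens the
# vocab axis from small_vocab to full_vocab.
num_beams, small_vocab, full_vocab = 4, 6, 10

beam_scores = np.zeros(num_beams)
next_token_scores = np.zeros((num_beams, small_vocab))
next_token_scores_processed = np.full((num_beams, full_vocab), -np.inf)

# Before the fix: beam_scores was expanded to next_token_scores' shape
# (num_beams, small_vocab), which no longer matches the processed
# scores and raises a broadcasting error on addition.
# After the fix: expand to the *processed* scores' shape instead.
summed = next_token_scores_processed + np.broadcast_to(
    beam_scores[:, None], next_token_scores_processed.shape
)
print(summed.shape)  # (4, 10)
```

Expanding against `next_token_scores_processed` is equivalent to the old behavior whenever the processor preserves the shape, which is why the change is safe for normal operation.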
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.