Adding prefix constrained beam search #7784
Conversation
The failed tests are not from this pull request but from RAG. Can someone review this, please?
Hey @nicola-decao - thanks a lot for your PR here! The failing tests are actually due to this pull request; see part of the error message here: …

Could you fix the errors by slightly adapting the …? In general I'm fine with this PR :-)
@patrickvonplaten now it should be ready for the merge :)
Awesome! There is probably going to be a merge conflict with the big … We changed the design for these kinds of "logits processing" methods, so we'll probably have to change the PR here a bit (and should also add a test). But let's put the PR on hold for a day and then I can help you merge it!
@nicola-decao, also what would be a typical use case of this function? E.g. could you give a quick example of a function one might pass here? Also @sshleifer could you check if this is useful for Fairseq's Blenderbot? And @yjernite this might be interesting to you as well :-)
@patrickvonplaten we do have a few use cases in mind internally already :)
@nicola-decao did you have something along those lines in mind?
@patrickvonplaten an example would be using a seq2seq model to predict a Wikipedia title as in Autoregressive Entity Retrieval (https://arxiv.org/abs/2010.00904) (I am the first author and I want to release my models 😊, which is why I did this PR). In this case the possible outputs are all of the ~6M Wikipedia titles, so one can create a prefix tree and constrain the generation to only these 6M strings. Here is an example of what I have:

```python
# custom class that creates a prefix tree where only "London" and "Rome" are possible outputs
trie = PrefixTree(["London", "Rome"])

# from a batch of tokens (torch.Tensor) returns a list of lists of allowed tokens (batches/beams)
# `trie.get` returns the possible next tokens (leaves if any) given a prefix
def prefix_allowed_tokens_fn(batch_tokens):
    return [
        [trie.get(tokens.tolist()) for tokens in beam_tokens]
        for beam_tokens in batch_tokens
    ]

# encoding inputs
input_args = {
    k: v.to(model.device)
    for k, v in tokenizer.batch_encode_plus(
        ["[START_ENT] London [END_ENT] is the capital of the UK."],
        return_tensors="pt",
    ).items()
}

# generating and decoding
tokenizer.batch_decode(
    model.generate(
        **input_args,
        min_length=0,
        num_beams=2,
        num_return_sequences=2,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    ),
    skip_special_tokens=True,
)
# >> ['London', 'Rome']
```
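The `PrefixTree` class above is custom code that lives outside this PR. A minimal sketch of such a trie over token-id sequences (assuming the title strings have already been tokenized into lists of ids; the name and methods here mirror the example, not any library API) could look like:

```python
class PrefixTree:
    """Minimal trie over token-id sequences.

    `get(prefix)` returns the token ids allowed immediately after `prefix`,
    or an empty list if `prefix` is not the start of any stored sequence.
    """

    def __init__(self, sequences):
        self.root = {}
        for seq in sequences:
            node = self.root
            for tok in seq:
                # descend, creating child nodes as needed
                node = node.setdefault(tok, {})

    def get(self, prefix):
        node = self.root
        for tok in prefix:
            if tok not in node:
                return []
            node = node[tok]
        return list(node.keys())
```

With two tokenized titles `[5, 6, 7]` and `[5, 8]`, `get([5])` returns both continuations `6` and `8`, while `get([5, 6])` only returns `7`.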
@patrickvonplaten I guess one way to do it is to implement a …
@nicola-decao - I think you should be able to directly touch the log-probs with a `LogitsProcessor`. Or do you need to apply it to `logits_warper` as well for `beam_search`? ... not sure yet!

It would be great if you could add a … I'm sorry that the big generate refactor means that we have to change this PR now. Do you want to give it a shot with the new …?
@patrickvonplaten I can do it :) I'll make another PR today.
@patrickvonplaten here is the new PR: #8529
What does this PR do?
This pull request adds a new decoding strategy that constrains which token can be generated next via a user-supplied callable. It mirrors facebookresearch/fairseq#2646, the equivalent feature in fairseq.
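To illustrate how such a callable hooks into decoding (a hedged sketch with hypothetical helper names, not the PR's actual `generate` internals): at each step the model's next-token scores are restricted to the tokens the callable allows for the current prefix, and search then proceeds as usual.

```python
import math

def constrained_greedy_decode(score_fn, prefix_allowed_tokens_fn, eos_id, max_len=10):
    """Greedy decoding under a prefix constraint (illustrative sketch).

    `score_fn(prefix)` returns a list of per-token scores for the next
    position; `prefix_allowed_tokens_fn(prefix)` returns the token ids
    allowed after `prefix`. Both are hypothetical stand-ins.
    """
    prefix = []
    for _ in range(max_len):
        scores = score_fn(prefix)
        allowed = set(prefix_allowed_tokens_fn(prefix))
        best, best_score = None, -math.inf
        for tok, s in enumerate(scores):
            # only consider tokens the constraint permits at this step
            if tok in allowed and s > best_score:
                best, best_score = tok, s
        if best is None or best == eos_id:
            break
        prefix.append(best)
    return prefix
```

Beam search works the same way, except the masking is applied to every beam's scores before the top-k selection.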
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@patrickvonplaten , @TevenLeScao