Adding prefix constrained beam search #7784

Closed · wants to merge 9 commits

Conversation

nicola-decao
Contributor

What does this PR do?

This pull request adds a new decoding strategy that constrains which tokens can be generated next based on a user-provided callable function. It mirrors facebookresearch/fairseq#2646 for fairseq.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guidelines, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten , @TevenLeScao

@nicola-decao
Contributor Author

The failing tests are not from this pull request but from RAG. Can someone review this, please?
@patrickvonplaten , @TevenLeScao

@patrickvonplaten patrickvonplaten self-requested a review October 17, 2020 19:02
@patrickvonplaten
Contributor

Hey @nicola-decao - thanks a lot for your PR here! The failing tests are actually due to this pull request. See a part of the error message here:

            )
E           TypeError: _generate_beam_search() missing 1 required positional argument: 'prefix_allowed_tokens_fn'

src/transformers/modeling_rag.py:1400: TypeError
_______________________ RagDPRT5Test.test_model_generate _______________________
[gw6] linux -- Python 3.7.9 /usr/local/bin/python

self = <tests.test_modeling_rag.RagDPRT5Test testMethod=test_model_generate>

    def test_model_generate(self):
        inputs_dict = self.config_and_inputs
>       self.check_model_generate(**inputs_dict)

Could you fix the errors by slightly adapting the generate method in RAG to make it pass with your example?

In general I'm fine with this PR :-)

@nicola-decao
Contributor Author

@patrickvonplaten now it should be ready to merge :)

@patrickvonplaten
Contributor

Awesome! There is probably going to be a merge conflict with the big generate() refactor PR that will be merged today: #6949.

We changed the design for these kinds of "logits processing" methods, so we'll probably have to change this PR a bit (and should also add a test). But let's put the PR on hold for a day and then I can help you merge it!

@patrickvonplaten
Contributor

@nicola-decao, also what would be a typical use case of this function? E.g. could you give a quick example of a function that one would pass here?

Also @sshleifer could you check if this is useful for Fairseq's Blenderbot? And @yjernite this might be interesting to you as well :-)

@yjernite
Member

yjernite commented Nov 4, 2020

@patrickvonplaten we do have a few use cases in mind internally already :)

  • prefix-triggered multi-task à la T5: in many cases having the prefix in the output sequence makes more sense than in the input (see the sketch below)
  • seq2seq model evaluation: some metrics (e.g. ROUGE-20 in the ELI5 paper) measure the model's ability to "continue" a generation, which can correlate better with human judgments of quality than full-generation ROUGE
  • seq2seq diagnostics: being able to measure the effect of the input vs. the local context

@nicola-decao did you have something along those lines in mind?
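
For illustration, the first use case above could be handled by a callback that pins the first generated tokens to a fixed task prefix and then allows the full vocabulary. A minimal sketch, assuming the batch-of-beams callback signature proposed in this PR, with `tokenizer` and `vocab_size` in scope (decoder start token bookkeeping omitted for clarity):

# hypothetical sketch: force every beam to begin with a fixed task prefix
# (e.g. "summarize:"), then allow any token afterwards
prefix_ids = tokenizer.encode("summarize:", add_special_tokens=False)
all_tokens = list(range(vocab_size))

def force_task_prefix_fn(batch_tokens):
    return [
        [
            # while inside the prefix, only the next prefix token is allowed
            [prefix_ids[len(tokens)]] if len(tokens) < len(prefix_ids)
            else all_tokens
            for tokens in beam_tokens
        ]
        for beam_tokens in batch_tokens
    ]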

@nicola-decao
Contributor Author

nicola-decao commented Nov 5, 2020

@nicola-decao, also what would be a typical use case of this function? E.g. could you give a quick example of a function that one would pass here?

Also @sshleifer could you check if this is useful for Fairseq's Blenderbot? And @yjernite this might be interesting to you as well :-)

@patrickvonplaten an example would be using a seq2seq model to predict a Wikipedia title, as in Autoregressive Entity Retrieval (https://arxiv.org/abs/2010.00904) (I am the first author and I want to release my models 😊 - that is why I did this PR). In this case the possible outputs are all 6M Wikipedia titles, so one can build a prefix tree and constrain generation to only these 6M strings. Here is an example of what I have:

# assumes `model` and `tokenizer` are an already-loaded seq2seq model/tokenizer pair
# custom class that builds a prefix tree (trie) in which "London" and "Rome"
# are the only possible outputs
trie = PrefixTree(["London", "Rome"])

# given a batch of token tensors (batch × beams), returns the allowed next
# tokens per beam; `trie.get` returns the possible next tokens (leaves, if any)
# for a given prefix
def prefix_allowed_tokens_fn(batch_tokens):
    return [
        [
            trie.get(tokens.tolist())
            for tokens in beam_tokens
        ]
        for beam_tokens in batch_tokens
    ]

# encoding inputs
input_args = {
    k: v.to(model.device) for k, v in tokenizer.batch_encode_plus(
        ["[START_ENT] London [END_ENT] is the capital of the UK."],
        return_tensors="pt"
    ).items()
}

# generating and decoding
tokenizer.batch_decode(
    model.generate(
        **input_args,
        min_length=0,
        num_beams=2,
        num_return_sequences=2,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    ),
    skip_special_tokens=True
)
>> ['London', 'Rome']
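
For reference, PrefixTree above is a custom class rather than part of transformers. A minimal sketch of what it could look like, assuming `tokenizer` is in scope and ignoring BOS/EOS bookkeeping:

# hypothetical minimal trie over token-id sequences; not the actual class
# from the paper's release
class PrefixTree:
    def __init__(self, candidates):
        self.root = {}
        for text in candidates:
            node = self.root
            # index each candidate string by its token ids
            for token_id in tokenizer.encode(text, add_special_tokens=False):
                node = node.setdefault(token_id, {})

    def get(self, prefix_ids):
        # walk the trie along the prefix; the children of the final node
        # are the allowed next tokens (empty list if the prefix is invalid)
        node = self.root
        for token_id in prefix_ids:
            if token_id not in node:
                return []
            node = node[token_id]
        return list(node.keys())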

@nicola-decao
Contributor Author

Awesome! There is probably going to be a merge conflict with the big generate() refactor PR that will be merged today: #6949.

We changed the design for these kinds of "logits processing" methods, so we'll probably have to change this PR a bit (and should also add a test). But let's put the PR on hold for a day and then I can help you merge it!

@patrickvonplaten I guess one way to do it is to implement a logits_processor. But what I want is to touch the log-probabilities directly instead of the logits (I observed that this works better in practice). Maybe we can add this logic to the general method too?

@patrickvonplaten
Contributor

@nicola-decao - I think you should be able to directly touch the log-probs with a LogitsProcessor, since they are applied after the log_softmax here:

next_token_scores = logits_processor(input_ids, next_token_scores)
and from looking at this version of the PR it seems to work with a LogitsProcessor

Or do you need to apply it to log_probs + beam_scores, i.e. after this line:

next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
If so, that would be more difficult and we would have to see how to deal with it ... maybe introduce a logits_warper for beam_search as well ... not sure yet!

It would be great if you could add a LogitsProcessor - I kinda did the whole refactor to keep the functions clean 😅.

I'm sorry that the big generate refactor means that we have to change this PR now. Do you want to give it a shot with the new generate() design? Otherwise I'm happy to help :-)
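
For context, such a constraint fits the new design as a processor that masks disallowed tokens with -inf after the log_softmax. A minimal sketch, assuming a simplified per-beam callback fn(batch_id, input_ids) -> allowed token ids rather than the batched callback used earlier in this PR; the class name and details are illustrative, not the merged implementation:

import torch

# illustrative sketch of a prefix-constrained logits processor; details may
# differ from what was eventually merged
class PrefixConstrainedProcessor:
    def __init__(self, prefix_allowed_tokens_fn, num_beams):
        # maps (batch_id, input_ids) -> list of allowed next token ids
        self._fn = prefix_allowed_tokens_fn
        self._num_beams = num_beams

    def __call__(self, input_ids, scores):
        # start fully masked, then open up the allowed tokens per beam
        mask = torch.full_like(scores, float("-inf"))
        for beam_idx, sent in enumerate(input_ids):
            batch_id = beam_idx // self._num_beams
            mask[beam_idx, self._fn(batch_id, sent)] = 0.0
        return scores + mask  # disallowed tokens get -inf score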

@nicola-decao
Contributor Author

@patrickvonplaten I can do it :) I'll make another PR today.

@nicola-decao
Contributor Author

@patrickvonplaten here is the new PR: #8529
