
[Feature Contribution] Disjunctive Positive Constraint Decoding (adding force_tokens to model.generate()) #14081

Closed
cwkeam opened this issue Oct 20, 2021 · 14 comments · Fixed by #15416


@cwkeam
Contributor

cwkeam commented Oct 20, 2021

🚀 Feature request

"Disjunctive Positive Constraint Decoding" Algorithm proposed by a recent paper: Guided Generation of Cause and Effect

Currently the model.generate() method limits the user to just the highest-probability outputs. With this method, the user could force diverse outputs by requiring the model to include different tokens across multiple generations.

This method is called "Disjunctive Positive Constraint Decoding", and it forces model.generate() to produce the highest-probability sequences subject to the constraint that they include a set of provided tokens.

This "disjunctive" method is powerful in that it can handle different forms of the same forced word. For instance, if we ask the model to autoregressively complete "Babies cry because" and want to force the generation to include the word "lonely", it can induce the model to generate sequences like "Babies cry because they are lonely" as well as "Babies cry because of their loneliness".

I think that this could be implemented as:

model.generate(force_tokens=[["lonely", "loneliness", ...], ["happy", "happiness", ...]])

where force_tokens is a 2D list in which each inner list contains the different forms of one desired word.

Otherwise it could be:

model.generate(force_tokens=["lonely", "happy"])

but in this case the transformers library would need a built-in lemmatization engine, which I think is better left for the practitioner to handle (i.e. go with the first method instead).

Motivation

Diversity in the outputs of LM sequence generation has long been an active problem. While diversity-inducing methods usually involve a dual implementation with a VAE or a modified training scheme, this feature would let the practitioner induce diversity in a very controllable way.

A large pretrained language model probably has the capacity to generate all kinds of different expressions given an input, but the generation usually gets limited to the highest-probability outputs. One obvious solution is to use sampling instead, but that brings back the problem of controllability. This level of control is extremely useful for models that aim to learn syntactic transformations that need to preserve certain entities, or for QA verbalizers where we already know what the answer should be.

Instead of generating many sequences and filtering for the desired ones, this would let us force the model to generate the output we want, which a large LM can probably do well; even if it can't find a good way to do so, we can simply filter out the low-probability outputs based on a threshold.

Your contribution

I'm happy to submit a PR for the full implementation if there aren't any objections to this feature.

But I do believe I should get some feedback on this idea before proceeding with an implementation, since it's not clear what the best way to introduce this functionality to the library would be.

@cwkeam cwkeam changed the title Disjunctive Positive Constraint Decoding (adding force_tokens to model.generate()) [Feature Contribution] Disjunctive Positive Constraint Decoding (adding force_tokens to model.generate()) Oct 20, 2021
@LysandreJik
Member

cc @patrickvonplaten @Narsil

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@Narsil
Contributor

Narsil commented Dec 2, 2021

This went under the radar, sorry about this.

I'll let Patrick discuss actual inclusion within transformers, but FYI, we're enabling a generic logits_processor, which should let you arbitrarily reassign logits during generation (#12219).

If you could frame your implementation as a LogitsProcessor object, inclusion would be super simple (and even if it doesn't get merged, usage should be quite smooth).
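For reference, a transformers LogitsProcessor is a callable that receives the current `input_ids` and the next-token `scores` and returns modified scores. The sketch below mimics that `(input_ids, scores) -> scores` contract with plain Python lists instead of torch tensors so it runs without the library; the class name `AllowedTokensProcessor` is hypothetical. It masks every token outside an allowed set, which shows the kind of one-step reshaping a processor is good at.

```python
import math
from typing import List, Sequence, Set


class AllowedTokensProcessor:
    """Hypothetical processor that mimics the LogitsProcessor call contract.

    Each row of `scores` holds the next-token logits of one sequence in the
    batch; every token ID outside `allowed_ids` is set to -inf.
    """

    def __init__(self, allowed_ids: Set[int]):
        self.allowed_ids = set(allowed_ids)

    def __call__(
        self, input_ids: Sequence[Sequence[int]], scores: List[List[float]]
    ) -> List[List[float]]:
        # input_ids is ignored: this rule does not depend on the history.
        return [
            [s if tok in self.allowed_ids else -math.inf for tok, s in enumerate(row)]
            for row in scores
        ]


# A batch of one sequence over a toy 5-token vocabulary.
processor = AllowedTokensProcessor({1, 3})
print(processor([[0, 2]], [[0.5, 0.1, 2.0, -1.0, 0.3]]))
# [[-inf, 0.1, -inf, -1.0, -inf]]
```

Greedy decoding over the masked row would now pick token 1. Because a processor like this only reshapes a single step's distribution, it can allow or forbid tokens, but on its own it cannot guarantee that a given word appears somewhere later in the output.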

@patrickvonplaten
Contributor

Very much agree with @Narsil: it would be great if you could add a simple LogitsProcessor class for your decoding method. That way we can surely merge it into master.

@cwkeam
Contributor Author

cwkeam commented Dec 9, 2021

Thanks for reviewing this thread, @patrickvonplaten @Narsil.

Though it would be ideal to solve this with a simple custom LogitsProcessor, this problem seems to require at least a dedicated beam search function (LexicallyConstrainedBeamSearch / DisjunctivelyConstrainedBeamSearch).

Upon further research I realized that similar features already exist in Fairseq and Sockeye.

The Fairseq implementation is introduced in this README, and the code is here (LexicallyConstrainedBeamSearch).
Sockeye's implementation is here.

These implementations are mainly based on the following papers:
Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation

Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting

I feel that the existence of implementations in other major text generation libraries hints at the importance of such a feature, and similar ideas about "whitelisting" certain tokens (as opposed to blacklisting them with bad_words_ids) have been discussed before in the Hugging Face forums.

I think this is all the more worthwhile because the "Disjunctive Positive Constraint Decoding" approach I proposed goes one step beyond all of the above implementations: it can handle lemmatized constraints.

For example, users of Sockeye or Fairseq can only force the generation to include the exact word "rain", which ends up preventing the model from generating "raining" instead. This disjunctive approach, on the other hand, can say "generate one of {rain, raining}", which makes the feature more intuitive to use.
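To make "generate one of {rain, raining}" concrete at the token level, here is a minimal sketch assuming each word form has already been tokenized into a list of token IDs (the IDs below are made up). The hypothetical `DisjunctiveState` class tracks, for a single hypothesis, which alternatives are partially matched; the constraint counts as fulfilled as soon as any one alternative has been generated in full. A partial match that breaks is simply dropped, and a new match may start at any position.

```python
from typing import List, Set, Tuple


class DisjunctiveState:
    """Hypothetical tracker for one disjunctive constraint within one hypothesis.

    `alternatives` lists the (non-empty) token ID sequences of each allowed
    word form; the constraint is met once ANY of them is generated in full.
    """

    def __init__(self, alternatives: List[List[int]]):
        self.alternatives = alternatives
        self.partial: Set[Tuple[int, int]] = set()  # (alternative index, tokens matched)
        self.fulfilled = False

    def advance(self, token: int) -> None:
        if self.fulfilled:
            return
        # Extend the matches in progress, and also try to start a new match here.
        candidates = self.partial | {(i, 0) for i in range(len(self.alternatives))}
        next_partial: Set[Tuple[int, int]] = set()
        for alt_idx, matched in candidates:
            alt = self.alternatives[alt_idx]
            if alt[matched] != token:
                continue  # this partial match is broken and gets dropped
            if matched + 1 == len(alt):
                self.fulfilled = True  # one complete alternative is enough
                return
            next_partial.add((alt_idx, matched + 1))
        self.partial = next_partial


# Made-up token IDs: "lonely" -> [300, 301], "loneliness" -> [300, 302, 303].
state = DisjunctiveState([[300, 301], [300, 302, 303]])
for tok in [5, 300, 302, 303]:
    state.advance(tok)
print(state.fulfilled)  # True
```

In a beam search, each hypothesis would carry its own copy of this state, and the search would consult `fulfilled` when deciding which hypotheses may finish. That cross-step bookkeeping is exactly what a single-step LogitsProcessor has no place to store.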

As one can imagine, implementing even the simpler lexically constrained approach found in Fairseq and Sockeye requires a dedicated beam search function, and it's much more complex than boosting or reducing logits at each step with a custom LogitsProcessor.

I'm wondering whether this makes the implementation too large and complex to merge into master. I'm personally more than willing to write the full implementation, and I'm fairly confident about it, since the existing implementations in other libraries have already done half the work.

@patrickvonplaten
Contributor

patrickvonplaten commented Jan 3, 2022

Hey @cwkeam,

I'm taking a look into this now (finally have some free time). Having looked at the papers, I think it actually makes sense to incorporate this generation method. @LysandreJik @sgugger @patil-suraj @yjernite @thomwolf what do you think?

@patrickvonplaten
Contributor

@cwkeam,

One quick question. There seems to be a blog post explaining how to do constrained decoding by just using a LogitsProcessor class (code, blog post). Could you explain why this is not sufficient for this case and what improvements we would see by implementing a whole new generation method?

@cwkeam
Contributor Author

cwkeam commented Jan 4, 2022

Hi @patrickvonplaten, thanks for getting back!

From what I've looked into so far, these are the reasons I can think of right now. It mainly comes down to the fact that a LogitsProcessor only covers one step of the generation process and outputs scores over the vocabulary for just the next step, whereas this constrained decoding method requires ranking, grouping, and filtering stateful beams (tracking which constraints each beam has satisfied).

1. Flexible Placement of Constraints

At least judging from the blog, the implementations provided are very simple and don't demonstrate the capacity needed for this decoding method. The code in the blog post is along the lines of:

among the next probable tokens, choose only those that have even characters.

if current_token.startsWith("a"): next token has to start with "b" 
if current_token.startsWith("b"): next token has to start with "c" 
else: next token has to start with "a"
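The a/b/c rule above can be written as a runnable toy, assuming a tiny word-level vocabulary of strings rather than real subword IDs (the function name `chain_rule_processor` is illustrative, not taken from the blog post). It decides only what is allowed at the very next step, based on the last token; nothing in it can express "this word must appear somewhere later".

```python
import math
from typing import List


def chain_rule_processor(vocab: List[str], last_token: str, scores: List[float]) -> List[float]:
    """Illustrative one-step rule: a-words are followed by b-words, b by c, else a."""
    if last_token.startswith("a"):
        required = "b"
    elif last_token.startswith("b"):
        required = "c"
    else:
        required = "a"
    # Keep the score of words with the required first letter, mask the rest.
    return [s if word.startswith(required) else -math.inf for word, s in zip(vocab, scores)]


vocab = ["apple", "banana", "cherry", "avocado"]
scores = [1.0, 0.2, 3.0, 0.5]
masked = chain_rule_processor(vocab, "apple", scores)
print(vocab[max(range(len(vocab)), key=masked.__getitem__)])  # banana
```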

The immediate problem is that this can't control where in the sequence the constraint gets satisfied.

If you want a certain word to be included in the output, and the scope of a LogitsProcessor is just one step of the generation process, I don't see how one could force that word at an appropriate position in some future step.

2. Supporting Multiple Constraints

Not only should this method be able to force those tokens at the most appropriate position in the sequence (which is unknown at an arbitrary step), it should also choose the most appropriate constraint to satisfy at each point (equally unknown at an arbitrary step).

If there are n constraints {c_1, c_2, ..., c_n} and T tokens {t_1, t_2, ..., t_T} in a gold output, we're essentially looking for the most probable position i at which each constraint c_j is satisfied, along with the most probable tokens for the remaining positions. I personally can't see how this can be solved by only changing the vocabulary scores for step t+1 at step t of the generation (the scope of a LogitsProcessor).
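One way to write this search down, in our own notation rather than the papers': let $x$ be the input and $y = (y_1, \dots, y_T)$ the output, and let each constraint $c_j$ be a set of alternative token sequences (a single sequence for ordinary lexical constraints, several for the disjunctive case, e.g. {rain, raining}). The existential over positions $i$ is what makes the problem global; beam search can only approximate this argmax.

```latex
\hat{y} = \arg\max_{y} \log P(y \mid x)
\quad \text{subject to} \quad
\forall j \in \{1, \dots, n\} \;\; \exists\, s \in c_j,\; \exists\, i :\;
(y_i, \dots, y_{i+|s|-1}) = s
```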

Essentially, we need a more global view of the generation at each step.

3. Solution With Constraint States and Ranking & Filtering Beams Throughout The Generation

From Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting

"The implementation of positively constrained decoding comprises two key pieces: tracking which of the supplied constraints each hypothesis has already generated, and ensuring progress through the constraints by dynamically allocating the beam to hypotheses that have generated different numbers of them.
...
The use of tries for recording constraint state, and thereby offsetting certain corner cases.
Each time a constraint is completed, the number is decremented, and nodes of the trie can be trimmed when they lead only to paths ending in zero counts.
...
In summary, we represent all the constraints as a compact trie. Each hypothesis in the decoder beam has its version of the trie. The set of active states in each hypothesis’ trie tracks all suffixes of the target words that match against the constraint trie. When a constraint is generated, its counter is decremented and zero paths are pruned."
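The trie idea in the quote can be sketched as follows, assuming constraints are given as lists of token IDs (made-up IDs below). `TrieNode` and `ConstraintTrie` are hypothetical names, not the paper's or Sockeye's code. Each node ending a constraint holds a counter that is decremented when the constraint is generated; the pruning of zero-count paths described in the quote is omitted for brevity, and a real decoder would keep one copy of this state per hypothesis (e.g. via `copy.deepcopy`).

```python
from typing import Dict, List


class TrieNode:
    def __init__(self) -> None:
        self.children: Dict[int, "TrieNode"] = {}
        self.count = 0  # how many more times the constraint ending here is required


class ConstraintTrie:
    """Hypothetical per-hypothesis constraint state built as a token-ID trie."""

    def __init__(self, constraints: List[List[int]]) -> None:
        self.root = TrieNode()
        for seq in constraints:
            node = self.root
            for tok in seq:
                node = node.children.setdefault(tok, TrieNode())
            node.count += 1
        # Trie states whose path matches a suffix of the output generated so far.
        self.active: List[TrieNode] = []

    def advance(self, token: int) -> None:
        next_active = []
        # Follow `token` from every active state, and from the root so that a
        # new constraint match can start at this position.
        for node in self.active + [self.root]:
            child = node.children.get(token)
            if child is None:
                continue
            if child.count > 0:
                child.count -= 1  # a constraint was just completed
            if child.children:
                next_active.append(child)  # longer constraints may still continue
        self.active = next_active

    def remaining(self) -> int:
        # Total number of constraint completions still required.
        total, stack = 0, [self.root]
        while stack:
            node = stack.pop()
            total += node.count
            stack.extend(node.children.values())
        return total


trie = ConstraintTrie([[7, 8], [7, 9, 10], [4]])
for tok in [7, 8, 4, 1]:
    trie.advance(tok)
print(trie.remaining())  # 1 (only [7, 9, 10] is still missing)
```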

From the existing fairseq implementation (which doesn't have the additional "disjunctive" constraint feature that I'm proposing):

# STEP 3: Compute the "bank" for each candidate. This is the
# number of constraints it's generated. We need this so that
# we can do round-robin allocation of the beam across these
# banks. If C is the number of constraints, we select the best
# item in bank C, then the best in bank C-1, etc, followed by
# the 2nd-best in bank C, the 2nd-best in bank C-1, etc, and so
# on, until the maximum beam size. We accomplish this by
# creating a sort key and striping across the banks.

# STEP 5: Remove duplicates. The topk calls (overall and
# per-row) plus the per-row generation of constraints will
# produce duplicates. Here we remove them.

# STEP 6: Assign IDs round-robin across banks, sort, and
# truncate. Now that the candidates are sorted by (bank,
# score) and uniqed, we dynamically allocate the {beam_size}
# beam by striping across the candidates. These stripes will
# be used as sort keys to do round-robin selection. This is
# accomplished in a single pass with offsets. Sorting by
# highest-banks (furthest-along hypotheses) first ensures
# progress through the constraints.
#
# e.g., BANKS: 3 3 3 2 2 2 2 1 1 1 0 0
# OLD STRIPES: 0 1 2 0 1 2 3 0 1 2 0 1
# NEW STRIPES: 0 1+4 2+8 0+1 1+5 2+9 3+11 0+2 1+6 2+10 0+3 1+7
#            = 0 5 10 1 6 11 13 2 7 12 3 8
#
# Sorting by this then gives the following banks:
#
#             3 2 1 0 3 2 1 0 3 2 1 2
#
# We'll take the top {beam_size} of these.
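The round-robin allocation described in these comments can be reproduced in a few lines of plain Python. This is a simplified re-implementation for illustration, not fairseq's code: `allocate_beam` is a hypothetical helper, duplicate removal (STEP 5) is left out, and instead of computing numeric stripe offsets in one tensor pass it sorts directly by (rank within bank, bank descending), which gives the same interleaving as the worked example above.

```python
from itertools import groupby
from typing import List, Tuple

Candidate = Tuple[int, float]  # (bank = number of constraints met, score)


def allocate_beam(candidates: List[Candidate], beam_size: int) -> List[Candidate]:
    """Hypothetical helper: round-robin selection across constraint banks."""
    # Highest bank first; within a bank, highest score first.
    ordered = sorted(candidates, key=lambda c: (-c[0], -c[1]))
    striped: List[Tuple[int, Candidate]] = []
    for _, group in groupby(ordered, key=lambda c: c[0]):
        # A candidate's "stripe" is its rank inside its own bank.
        striped.extend((rank, cand) for rank, cand in enumerate(group))
    # Best of every bank (highest bank first), then the 2nd-best of every bank, ...
    striped.sort(key=lambda rc: (rc[0], -rc[1][0]))
    return [cand for _, cand in striped[:beam_size]]


banks = [3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 0, 0]  # the BANKS row from the comment above
candidates = [(bank, -float(i)) for i, bank in enumerate(banks)]  # scores fall with i
print([bank for bank, _ in allocate_beam(candidates, len(banks))])
# [3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 2]
```

Putting the highest bank first within each stripe is what "ensures progress through the constraints": hypotheses that have met more constraints always get the earliest slots, while lower banks still keep some representation in the beam.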

Let me know what you think!

I'm still very much open to a solution using LogitsProcessor (or a modification of it, if needed), if one is possible, since that would clearly be a much easier way to integrate this functionality into the package.

@patrickvonplaten
Contributor

patrickvonplaten commented Jan 5, 2022

Hey @cwkeam,

Thanks a lot for your very detailed answer! We've discussed this internally and actually think it makes sense to add a new generation method for this.

Ideally this generation method could cover the generation methods proposed in the following papers:

From an implementation point of view I think we should add a new "sub"-generation method here:

It would then join the 5 existing "sub"-generation methods, which are:

  • greedy_search
  • sample
  • beam_sample
  • beam_search
  • group_beam_search

We might also have to add a new BeamScorer class here: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_beam_search.py

I think we could call it either guided_beam_search or constrained_beam_search, and add a couple of appropriate input arguments to generate(...) to trigger the "sub"-generation function.

@patrickvonplaten
Contributor

Would you be interested in giving this PR a try? I'm more than happy to help you whenever you're stuck!
Otherwise I would probably be able to find some time to give it a try in 1-2 months myself.

@cwkeam
Contributor Author

cwkeam commented Jan 5, 2022

Hi @patrickvonplaten

Thanks for the consideration!
I'd be more than happy to make the full implementation. I'll try to push an initial PR within the next week or so.

@patrickvonplaten
Contributor

This sounds great! Please ping me whenever you need any help.

@cwkeam
Contributor Author

cwkeam commented Jan 30, 2022

@patrickvonplaten Hi, I'm almost done with a full, robust implementation, including tests. I had retracted the previous PR because I felt I had submitted it a bit prematurely. Would you reopen this issue? I'm working on the documentation now.

@minmummax

If I want to restrict my generation output to a range of characters, how should I do that?
For example, [A, B, C].
