Generate: assisted decoding with sample #22862
Thanks for working on this! I think the doc can be made a little bit better before we merge this.
@@ -359,3 +360,26 @@ To enable assisted generation, set the `assistant_model` argument with a model.
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

When using assisted decoding with sampling methods, the `assisted_keep_proba` argument will balance speed with
I do not like this name, and I don't fully understand what this argument does from the doc, so I can't suggest a new one 😅 `assisted_threshold` maybe?
assisted_keep_proba (`float`, *optional*):
    Used with assisted decoding. When `do_sample` is true, this controls the threshold at which the model will
    resample candidate tokens. When the model's predicted probability for a candidate token is below this
    threshold, the candidate token is invalidated and a sampling step. Decreasing this value will aproximate
and a sampling step is performed?
if do_sample:
    probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
    max_probs, max_logits = probs[:, :-1, :].topk(1, dim=-1)
    max_logits[max_probs < assisted_keep_proba] = -1  # invalidate candidate tokens with low proba
So yeah, it looks like `assisted_threshold` could work?
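To make the thresholding step in the diff above easier to follow, here is a minimal pure-Python sketch of the same idea. It is an illustration only, not the PR's implementation: the helper names `softmax` and `invalidate_low_confidence` are hypothetical, plain lists stand in for tensors, and the `-1` sentinel simply mirrors the value used in the diff.

```python
import math

INVALID = -1  # sentinel mirroring the `-1` written in the diff above


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def invalidate_low_confidence(logits_per_position, keep_proba):
    """For each position, pick the most likely token id, or INVALID if the
    model's probability for that token is below `keep_proba` (hypothetical helper)."""
    selected = []
    for logits in logits_per_position:
        probs = softmax(logits)
        best = max(range(len(probs)), key=probs.__getitem__)
        selected.append(best if probs[best] >= keep_proba else INVALID)
    return selected


# Position 0 is confident (token 0), position 1 is uniform over three tokens.
example_logits = [[10.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(invalidate_low_confidence(example_logits, keep_proba=0.5))
```

With a threshold of `0.0` nothing is ever invalidated, which is consistent with the docstring's remark that lower values bring the process closer to greedy search.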
LGTM - thanks for adding this and the details script and results! ❤️ 🚀
@@ -179,6 +181,11 @@ class GenerationConfig(PushToHubMixin):
    A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
    forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
    of index 123.
assisted_keep_proba (`float`, *optional*):
Default value should be mentioned here
    When `do_sample` is true, this controls the threshold at which the model will resample candidate
    tokens. When the model's predicted probability for a candidate token is below this threshold, the
    candidate token is invalidated and a sampling step. Decreasing this value will aproximate the decoding
    process to greedy search, but it will be faster.
logits_processor (`LogitsProcessorList`, *optional*):
    An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
    used to modify the prediction scores of the language modeling head applied at each generation step.
`logits_warper` in docstring missing here
I'm closing this PR because I found a much much better way to handle the sample case 🧠 Stay tuned 🚀
What does this PR do?
This PR expands the previous assisted generation PR so as to work with sampling.
Two important notes to review the PR:
Below are some results, so you can understand the balancing act. Execution times were obtained on a 3090.
Script
Sample results
1. Make sure you have a good stock.
2. Make sure you have a good broth.
3. Make sure you have a good ramen.
4. Make sure you have a good ramen.
5. Make sure you have a good ramen.
`assisted_keep_proba=0.0`:
1. Get a noodle.
2. Get a stock.
3. Get a packet of dried ingredients.
4. Cook the noodles.
5. Cook the stock.
6. Cook the packet of dried ingredients.
7. Enjoy!
And
`assisted_keep_proba=0.2`:
1. Get a noodle vendor.
The noodle vendor makes the noodles. Japanese restaurants often have the noodle vendor on-site.
2. Get a pot.
The pot is used to cook ramen.
3. Get a pot of boiling water.
`assisted_keep_proba=0.4`:
Step 1: Collect your ingredients.
For this recipe you need a big stock pot. That's good.
And some water.
Step 2: Peel the eggs.
Yes, that's it. Four eggs.
Step 3: Separate the yolks.
`assisted_keep_proba=0.6`:
Nothing much to take out of the packet. Just a big block of pork fat, some Chinese chilli paste and seasonings.
Preheat the oven to 210ºC (410ºF/Gas 6).
Place the pork fat, chilli paste and seasoning into a mixing bowl and
`assisted_keep_proba=0.8`:
You'll need: A large pot for boiling noodles
A small saucepan for cooking the noodles
BBQ chicken or roasted fish, or any grilled healthy protein
A box of ramen noodles, noodles that come in
shapes and sizes
Soups or broth,
`assisted_keep_proba=1.0`:
You take your pre-scalloped noodles, pour boiling water (or your preferred water-to-noodle ratio) over them, and leave them alone for four to five minutes. Once that's done, drain them, season with salt, and heat them up on the stove (microwave won
You take your pre-scalloped noodles, pour boiling water (or your preferred cooking liquid) over it, and after that you go get your ramen broth, add-ins, and other condiments. You make your seasoning sauce, and heat that up. Mix it all together, and put
As can be seen above, there is a trade-off between time and quality. This will certainly be application-specific: factual applications will be able to make the most of assisted decoding. In my brief experiments, `assisted_keep_proba=0.3` seems like a sensible default.
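One way to see where the speed side of this trade-off comes from is to count how many assistant candidates survive a verification step. The sketch below assumes that assisted decoding accepts the longest prefix of candidate tokens that matches the main model's selections, and that an invalidated position (marked `-1`, as in the diff) never matches. The function name `count_accepted` and the token ids are made up for illustration; this is not the PR's code.

```python
def count_accepted(candidate_tokens, selected_tokens):
    """Length of the longest matching prefix between the assistant's candidates
    and the main model's selections (hypothetical helper)."""
    n_matches = 0
    for candidate, selected in zip(candidate_tokens, selected_tokens):
        if candidate != selected:
            break  # first mismatch (or invalidated position) ends the accepted run
        n_matches += 1
    return n_matches


candidates = [5, 7, 9]  # tokens proposed by the assistant model (made-up ids)

# Low threshold: nothing invalidated, all three candidates are accepted at once.
print(count_accepted(candidates, [5, 7, 9]))

# Higher threshold: position 1 was invalidated, so only one candidate is accepted
# and the main model falls back to a sampling step from there.
print(count_accepted(candidates, [5, -1, 9]))
```

Under these assumptions, a higher threshold invalidates more positions, which shortens the accepted run per forward pass and so costs time, while leaving more tokens to be sampled.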