
Generate: assisted decoding with sample #22862

Closed
wants to merge 2 commits

Conversation

gante (Member) commented Apr 19, 2023

What does this PR do?

This PR expands the previous assisted generation PR so that it also works with sampling.

Two important notes to review the PR:

  1. I'd suggest starting the review with the docs, so you understand what's going on at a high level. Sampling adds an additional (controllable) heuristic, so the user can trade off between speed and pure sampling behavior (a minimal sketch of the heuristic follows this list).
  2. In terms of implementation, I've decided to overload the assisted generation function with a few extra lines to handle the sample case, to avoid adding a near copy of a 500-line function.
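
To make the heuristic concrete, here is a minimal, self-contained sketch of the idea (illustrative only, not the PR's exact code; names and values are made up): the main model scores the assistant's candidate tokens, and the first candidate whose probability under the main model falls below `assisted_keep_proba` is rejected, falling back to a regular sampling step at that position.

```python
import torch

# Illustrative sketch of the speed/quality knob (not the actual implementation).
# The assistant proposes a run of candidate tokens; the main model assigns each
# candidate a probability, and the first candidate below `assisted_keep_proba`
# is rejected, triggering a normal sampling step from that position.
assisted_keep_proba = 0.3
main_model_probs = torch.tensor([0.92, 0.41, 0.12, 0.75])  # P(candidate_i) under the main model

n_accepted = 0
for p in main_model_probs.tolist():
    if p < assisted_keep_proba:
        break  # this candidate (and everything after it) is discarded
    n_accepted += 1

print(n_accepted)  # 2 -> the third candidate (p=0.12) would be resampled
```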

Below are some results, so you can understand the balancing act. Execution times were obtained on an RTX 3090.

Script:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
import torch
import time

model_id = "EleutherAI/pythia-6.9b-deduped"
assistant_id = "EleutherAI/pythia-160m-deduped"

tokenizer = AutoTokenizer.from_pretrained(model_id)

assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id)
assistant_model = assistant_model.to("cuda")

model_kwargs = {
    "pretrained_model_name_or_path": model_id,
    "device_map": "auto",
    "max_memory": {0: "20GiB", "cpu": "50GiB"},
    "torch_dtype": torch.float16,
}
model = AutoModelForCausalLM.from_pretrained(**model_kwargs)

inputs = tokenizer("Here's how to cook a good ramen:", return_tensors="pt").to("cuda")

streamer = TextStreamer(tokenizer=tokenizer)

print("Greedy with assistance:")
start = time.time()
model.generate(**inputs, assistant_model=assistant_model, streamer=streamer, max_new_tokens=64)
print(f"Elapsed time: {time.time() - start:.2f} seconds")

for p in (0.0, 0.2, 0.4, 0.6, 0.8, 1.0):
    print(f"Sample with assistance (assisted_keep_proba = {p})")
    torch.manual_seed(0)
    start = time.time()
    model.generate(
        **inputs,
        do_sample=True,
        assistant_model=assistant_model,
        assisted_keep_proba=p,
        streamer=streamer,
        max_new_tokens=64
    )
    print(f"Elapsed time: {time.time() - start:.2f} seconds")

print("Original sample")
torch.manual_seed(0)
start = time.time()
model.generate(**inputs, do_sample=True, streamer=streamer, max_new_tokens=64)
print(f"Elapsed time: {time.time() - start:.2f} seconds")
```
Sample results:

| Decoding strategy | Result | Execution time |
|---|---|---|
| Greedy (w/ assistance) | Here's how to cook a good ramen:<br><br>1. Make sure you have a good stock.<br><br>2. Make sure you have a good broth.<br><br>3. Make sure you have a good ramen.<br><br>4. Make sure you have a good ramen.<br><br>5. Make sure you have a good ramen. | 1.44 seconds |
| Sample (w/ assistance, `assisted_keep_proba=0.0`) | Here's how to cook a good ramen:<br><br>1. Get a noodle.<br><br>2. Get a stock.<br><br>3. Get a packet of dried ingredients.<br><br>4. Cook the noodles.<br><br>5. Cook the stock.<br><br>6. Cook the packet of dried ingredients.<br><br>7. Enjoy!<br><br>And | 1.44 seconds |
| Sample (w/ assistance, `assisted_keep_proba=0.2`) | Here's how to cook a good ramen:<br><br>1. Get a noodle vendor.<br><br>The noodle vendor makes the noodles. Japanese restaurants often have the noodle vendor on-site.<br><br>2. Get a pot.<br><br>The pot is used to cook ramen.<br><br>3. Get a pot of boiling water. | 1.59 seconds |
| Sample (w/ assistance, `assisted_keep_proba=0.4`) | Here's how to cook a good ramen:<br><br>Step 1: Collect your ingredients.<br><br>For this recipe you need a big stock pot. That's good.<br><br>And some water.<br><br>Step 2: Peel the eggs.<br><br>Yes, that's it. Four eggs.<br><br>Step 3: Separate the yolks. | 1.71 seconds |
| Sample (w/ assistance, `assisted_keep_proba=0.6`) | Here's how to cook a good ramen:<br><br>Nothing much to take out of the packet. Just a big block of pork fat, some Chinese chilli paste and seasonings.<br><br>Preheat the oven to 210ºC (410ºF/Gas 6).<br><br>Place the pork fat, chilli paste and seasoning into a mixing bowl and | 2.08 seconds |
| Sample (w/ assistance, `assisted_keep_proba=0.8`) | Here's how to cook a good ramen:<br><br>You'll need: A large pot for boiling noodles<br>A small saucepan for cooking the noodles<br>BBQ chicken or roasted fish, or any grilled healthy protein<br>A box of ramen noodles, noodles that come in shapes and sizes<br>Soups or broth, | 2.32 seconds |
| Sample (w/ assistance, `assisted_keep_proba=1.0`) | Here's how to cook a good ramen:<br><br>You take your pre-scalloped noodles, pour boiling water (or your preferred water-to-noodle ratio) over them, and leave them alone for four to five minutes. Once that's done, drain them, season with salt, and heat them up on the stove (microwave won | 2.56 seconds |
| Original sample | Here's how to cook a good ramen:<br><br>You take your pre-scalloped noodles, pour boiling water (or your preferred cooking liquid) over it, and after that you go get your ramen broth, add-ins, and other condiments. You make your seasoning sauce, and heat that up. Mix it all together, and put | 2.05 seconds |

As can be seen above, there is a trade-off between speed and quality. The right balance will certainly be application-specific: factual applications will get the most out of assisted decoding. In my brief experiments, `assisted_keep_proba=0.3` seems like a sensible default.
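
For reference, a call using that value would look like the snippet below. It reuses `model`, `assistant_model`, `tokenizer`, and `inputs` from the benchmark script above; `assisted_keep_proba` is the argument introduced in this PR (its name may still change, see the review discussion below).

```python
# Reuses the objects defined in the benchmark script above.
# `assisted_keep_proba` is the new argument proposed in this PR.
outputs = model.generate(
    **inputs,
    do_sample=True,
    assistant_model=assistant_model,
    assisted_keep_proba=0.3,
    max_new_tokens=64,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```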

HuggingFaceDocBuilderDev commented Apr 19, 2023

The documentation is not available anymore as the PR was closed or merged.

sgugger (Collaborator) left a comment:

Thanks for working on this! I think the doc can be made a little bit better before we merge this.

````
@@ -359,3 +360,26 @@ To enable assisted generation, set the `assistant_model` argument with a model.
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```

When using assisted decoding with sampling methods, the `assisted_keep_proba` argument will balance speed with
````
I do not like this name, and I'm not completely understanding what this argument does from the doc, so I can't suggest a new one 😅 `assisted_threshold`, maybe?

```
assisted_keep_proba (`float`, *optional*):
    Used with assisted decoding. When `do_sample` is true, this controls the threshold at which the model will
    resample candidate tokens. When the model's predicted probability for a candidate token is below this
    threshold, the candidate token is invalidated and a sampling step. Decreasing this value will aproximate
```
and a sampling step is performed?

```python
if do_sample:
    probs = new_logits[:, -candidate_length - 1 :, :].softmax(dim=-1)
    max_probs, max_logits = probs[:, :-1, :].topk(1, dim=-1)
    max_logits[max_probs < assisted_keep_proba] = -1  # invalidate candidate tokens with low proba
```
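
(For readers skimming the thread, the quoted masking line behaves roughly like this on dummy tensors; names, shapes, and values here are made up for illustration.)

```python
import torch

assisted_keep_proba = 0.3
max_probs = torch.tensor([[[0.92], [0.41], [0.12], [0.75]]])  # main-model prob per position
max_logits = torch.tensor([[[101], [57], [8], [23]]])         # token ids (dummy values)

max_logits[max_probs < assisted_keep_proba] = -1  # the 0.12 position is invalidated
print(max_logits.squeeze(-1))  # -> token ids become [[101, 57, -1, 23]]
```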
So yeah, it looks like `assisted_threshold` could work?

amyeroberts (Collaborator) left a comment:

LGTM - thanks for adding this and the detailed script and results! ❤️ 🚀

```
@@ -179,6 +181,11 @@ class GenerationConfig(PushToHubMixin):
            A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
            forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
            of index 123.
        assisted_keep_proba (`float`, *optional*):
```
Default value should be mentioned here

```
            When `do_sample` is true, this controls the threshold at which the model will resample candidate
            tokens. When the model's predicted probability for a candidate token is below this threshold, the
            candidate token is invalidated and a sampling step. Decreasing this value will aproximate the decoding
            process to greedy search, but it will be faster.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
```
`logits_warper` is missing in the docstring here

gante (Member, Author) commented Apr 22, 2023

I'm closing this PR because I found a much, much better way to handle the sample case 🧠

Stay tuned 🚀

@gante gante closed this Apr 22, 2023
@gante gante deleted the assisted_sample branch May 18, 2023 15:25