Generate: assisted decoding with sample #22862

Closed · wants to merge 2 commits

docs/source/en/generation_strategies.mdx (30 additions & 6 deletions)

````diff
@@ -333,15 +333,16 @@
 [`generate`] method, which gives you even further control over the [`generate`] method's behavior.
 For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).
 
-### Assisted Generation
+### Assisted Decoding

-Assisted generation is a modification of the decoding strategies above that uses an assistant model with the same
-tokenizer (ideally a much smaller model) to speed up the decoding process. Currently only assisted greedy search is
-supported, and doesn't support batched inputs.
+Assisted decoding is a modification of the decoding strategies above that uses an assistant model with the same
+tokenizer (ideally a much smaller model) to greedily generate a few candidate tokens. The main model then validates
+the candidate tokens in a single forward pass, which speeds up the decoding process. Currently, only greedy search
+and sampling are supported with assisted decoding, and it does not support batched inputs.

-<!-- TODO: add link to the blog post about assisted generation when it exists -->
+<!-- TODO: add link to the blog post about assisted decoding when it exists -->

-To enable assisted generation, set the `assistant_model` argument with a model.
+To enable assisted decoding, set the `assistant_model` argument to a model.

 ```python
 >>> from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -359,3 +360,26 @@
 >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
````

````diff
+When using assisted decoding with sampling methods, the `assisted_keep_proba` argument balances speed against
+pure sampling behavior. Greedily decoded candidate tokens whose predicted probability under the main model is
+below this threshold are discarded, which forces an early sampling step. If `assisted_keep_proba` is set to 1.0,
+assisted decoding degenerates into slow multinomial sampling. Conversely, if `assisted_keep_proba` is set to 0.0,
+the outcome approximates accelerated greedy decoding, with a few sampling steps in between.
````

> Collaborator: I do not like this name, and I'm not completely understanding what this argument does from the doc, so I can't suggest a new one 😅 `assisted_threshold` maybe?

````diff
+```python
+>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+>>> prompt = "Alice and Bob"
+>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
+>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
+
+>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+>>> inputs = tokenizer(prompt, return_tensors="pt")
+
+>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
+>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
+>>> outputs = model.generate(**inputs, assistant_model=assistant_model, do_sample=True, assisted_keep_proba=0.1)
+>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+['Alice and Bob are in a room watching television. They are both on the same channel. Alice']
+```
````
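
As a rough mental model of what `assisted_keep_proba` does in sampling mode, consider the sketch below. It is a hypothetical simplification rather than the code in this PR: `main_probs` holds the main model's softmax probabilities at each candidate position, and `proposed` holds the assistant's greedy candidate tokens.

```python
import torch


def filter_candidates(main_probs, proposed, assisted_keep_proba=0.3):
    """main_probs: (num_candidates, vocab_size); proposed: (num_candidates,)."""
    accepted = []
    for probs, token in zip(main_probs, proposed):
        if probs[token] < assisted_keep_proba:
            # The main model finds this candidate too unlikely: discard it and
            # force a sampling step from the main model's distribution. Later
            # candidates were conditioned on the rejected token, so stop here.
            accepted.append(int(torch.multinomial(probs, num_samples=1)))
            break
        # Confident enough: keep the greedily decoded candidate as-is.
        accepted.append(int(token))
    return accepted
```

With `assisted_keep_proba=1.0` virtually every candidate fails the check, so each token ends up sampled from the main model (plain multinomial sampling, no speedup); with `0.0` every candidate passes and the result is effectively accelerated greedy decoding, as the documentation above describes.
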
src/transformers/generation/configuration_utils.py (11 additions & 1 deletion)

````diff
@@ -54,8 +54,10 @@ class GenerationConfig(PushToHubMixin):
           `num_beams>1` and `num_beam_groups>1`
         - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
           `constraints!=None` or `force_words_ids!=None`
+        - *assisted decoding* by calling [`~generation.GenerationMixin.assisted_decoding`], if
+          `assistant_model` is passed to `.generate()`
 
-    You do not need to call any of the above methods directly. Pass custom parameter values to 'generate'. To learn
+    You do not need to call any of the above methods directly. Pass custom parameter values to '.generate()'. To learn
     more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
 
     Arg:
@@ -179,6 +181,11 @@ class GenerationConfig(PushToHubMixin):
             A list of pairs of integers which indicates a mapping from generation indices to token indices that will be
             forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token
             of index 123.
+        assisted_keep_proba (`float`, *optional*, defaults to 0.3):
+            Used with assisted decoding. When `do_sample` is true, this controls the threshold at which the model will
+            resample candidate tokens. When the model's predicted probability for a candidate token is below this
+            threshold, the candidate token is invalidated and a sampling step is performed. Decreasing this value
+            brings the outcome closer to greedy search, but decoding will be faster.
 
         > Parameters that define the output variables of `generate`
 
````

> Collaborator, on the `assisted_keep_proba` docstring: Default value should be mentioned here

> Collaborator, on "the candidate token is invalidated and a sampling step.": and a sampling step is performed?
````diff
@@ -258,6 +265,7 @@ def __init__(self, **kwargs):
         self.suppress_tokens = kwargs.pop("suppress_tokens", None)
         self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
         self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None)
+        self.assisted_keep_proba = kwargs.pop("assisted_keep_proba", 0.3)
 
         # Parameters that define the output variables of `generate`
         self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
@@ -319,6 +327,8 @@ def validate(self):
"""
if self.early_stopping not in {True, False, "never"}:
raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
if self.assisted_keep_proba < 0.0 or self.assisted_keep_proba > 1.0:
raise ValueError(f"`assisted_keep_proba` must be between 0.0 and 1.0, but is {self.assisted_keep_proba}.")

def save_pretrained(
self,
Expand Down
Loading
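
With this branch checked out, the new check in `validate()` can be exercised directly on a `GenerationConfig`; a quick sanity check might look like this (hypothetical usage, mirroring the diff above):

```python
from transformers import GenerationConfig

# In range: the value is stored on the config and accepted by validate().
config = GenerationConfig(assisted_keep_proba=0.1)
config.validate()

# Out of range: validate() raises a ValueError per the new check.
bad_config = GenerationConfig(assisted_keep_proba=1.5)
try:
    bad_config.validate()
except ValueError as err:
    print(err)  # `assisted_keep_proba` must be between 0.0 and 1.0, but is 1.5
```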