Generate: Add assisted generation (huggingface#22211)
* working mvp

* remove breakpoint

* fix commit

* standardize outputs

* tmp commit

* tests almost ready

* tmp commit

* skip a few models

* Add streaming; Docs and examples

* document limitations

* PR commits

* Amy PR comments
gante authored and novice03 committed Jun 23, 2023
1 parent e1b81fb commit a86f698
Showing 6 changed files with 623 additions and 26 deletions.
27 changes: 27 additions & 0 deletions docs/source/en/generation_strategies.mdx
@@ -332,3 +332,30 @@ The groups are selected to ensure they are distinct enough compared to the other
This guide illustrates the main parameters that enable various decoding strategies. More advanced parameters exist for the
[`generate`] method, which gives you even further control over its behavior.
For the complete list of the available parameters, refer to the [API documentation](./main_classes/text_generation.mdx).

### Assisted Generation

Assisted generation is a modification of the decoding strategies above that uses an assistant model (ideally a much
smaller model with the same tokenizer) to speed up the decoding process: the assistant cheaply drafts a few candidate
tokens, and the main model confirms or rejects them in a single forward pass. Currently, only assisted greedy search is
supported, and batched inputs are not supported.
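
The sketch below illustrates the idea in plain PyTorch, assuming greedy drafting with a fixed draft length
(`num_candidates` is an illustrative parameter name; the actual implementation is more involved, e.g. it reuses
key/value caches and handles stopping conditions):

```python
import torch


@torch.no_grad()
def assisted_greedy_search(model, assistant, input_ids, max_new_tokens=20, num_candidates=5):
    """Illustrative sketch only -- not the actual implementation. EOS handling is omitted."""
    target_length = input_ids.shape[1] + max_new_tokens
    while input_ids.shape[1] < target_length:
        cur_len = input_ids.shape[1]
        # 1. The assistant cheaply drafts a few candidate tokens with greedy search.
        candidates = assistant.generate(input_ids, max_new_tokens=num_candidates, do_sample=False)
        drafted = candidates[:, cur_len:]
        # 2. The main model scores prompt + draft in a single forward pass. Its greedy
        # pick at each drafted position is what plain greedy search would have produced.
        logits = model(candidates).logits
        selected = logits[:, cur_len - 1 : -1].argmax(dim=-1)
        # 3. Accept drafted tokens up to the first disagreement with the main model.
        n_matches = int((selected == drafted).long().cumprod(dim=-1).sum())
        input_ids = candidates[:, : cur_len + n_matches]
        # The main model always contributes one extra token: its correction at the
        # first mismatch, or a bonus token when the whole draft was accepted.
        if n_matches < drafted.shape[1]:
            next_token = selected[:, n_matches : n_matches + 1]
        else:
            next_token = logits[:, -1:].argmax(dim=-1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids[:, :target_length]
```

Because the main model scores the whole draft at once, every accepted draft token costs roughly one assistant forward
pass instead of one main-model forward pass, which is where the speedup comes from.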

<!-- TODO: add link to the blog post about assisted generation when it exists -->

To enable assisted generation, pass a model to the `assistant_model` argument.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> prompt = "Alice and Bob"
>>> checkpoint = "EleutherAI/pythia-1.4b-deduped"
>>> assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
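
The commit log above also mentions streaming support. As a hedged sketch, assuming assisted generation accepts the same
`streamer` argument as the other decoding strategies, tokens can be printed as they are accepted:

```python
>>> from transformers import TextStreamer

>>> # Reusing `tokenizer`, `model`, `inputs`, and `assistant_model` from the example above.
>>> streamer = TextStreamer(tokenizer, skip_prompt=True)
>>> _ = model.generate(**inputs, assistant_model=assistant_model, streamer=streamer)
```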