[Speculative decoding] Initial spec decode docs (vllm-project#5400)
cadedaniel authored and jimpang committed Jun 27, 2024
1 parent ca61729 commit 7e35f61
Showing 2 changed files with 76 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -90,6 +90,7 @@ Documentation
models/engine_args
models/lora
models/vlm
models/spec_decode
models/performance

.. toctree::
75 changes: 75 additions & 0 deletions docs/source/models/spec_decode.rst
@@ -0,0 +1,75 @@
.. _spec_decode:

Speculative decoding in vLLM
============================

.. warning::
    Speculative decoding in vLLM is not yet optimized and does not usually yield
    inter-token latency reductions for all prompt datasets or sampling parameters.
    The work to optimize it is ongoing and can be followed in
    `this issue <https://github.com/vllm-project/vllm/issues/4630>`_.

This document shows how to use `Speculative Decoding <https://x.com/karpathy/status/1697318534555336961>`_ with vLLM.
Speculative decoding is a technique that reduces inter-token latency in memory-bound LLM inference: a cheap draft
mechanism proposes several tokens, and the target model verifies them in a single forward pass.
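
For intuition, the core loop can be sketched in a few lines of plain Python. The sketch below is purely
illustrative and is not vLLM's implementation: ``draft_next_token`` and ``target_next_tokens`` are hypothetical
stand-ins for the draft and target models, and greedy acceptance is a simplification of the rejection-sampling
acceptance used in practice.

.. code-block:: python

    # Illustrative sketch only -- not vLLM's implementation.
    # draft_next_token(ctx) -> int and target_next_tokens(prefix, proposal) -> list[int]
    # are hypothetical stand-ins for the draft and target models.

    def speculative_step(prefix, draft_next_token, target_next_tokens, k=5):
        # 1. The cheap draft model proposes k tokens autoregressively.
        proposal, ctx = [], list(prefix)
        for _ in range(k):
            tok = draft_next_token(ctx)
            proposal.append(tok)
            ctx.append(tok)

        # 2. The target model scores prefix + proposal in one forward pass,
        #    yielding its own next token at each of the k + 1 positions.
        target_tokens = target_next_tokens(prefix, proposal)

        # 3. Keep the longest prefix of the proposal the target agrees with
        #    (greedy acceptance here; the real scheme uses rejection sampling).
        accepted = []
        for drafted, verified in zip(proposal, target_tokens):
            if drafted != verified:
                break
            accepted.append(drafted)

        # The target model always contributes at least one token of its own.
        accepted.append(target_tokens[len(accepted)])
        return list(prefix) + accepted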

Speculating with a draft model
------------------------------

The following code configures vLLM to use speculative decoding with a draft model, speculating 5 tokens at a time.

.. code-block:: python

    from vllm import LLM, SamplingParams

    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    llm = LLM(
        model="facebook/opt-6.7b",
        tensor_parallel_size=1,
        # The small draft model proposes 5 tokens per step; the target model verifies them.
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Speculating by matching n-grams in the prompt
---------------------------------------------

The following code configures vLLM to use speculative decoding where proposals are generated by
matching n-grams in the prompt. For more information, read `this thread <https://x.com/joao_gante/status/1747322413006643259>`_.

.. code-block:: python

    from vllm import LLM, SamplingParams

    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    llm = LLM(
        model="facebook/opt-6.7b",
        tensor_parallel_size=1,
        # "[ngram]" selects prompt-lookup proposals instead of a separate draft model.
        speculative_model="[ngram]",
        num_speculative_tokens=5,
        ngram_prompt_lookup_max=4,
        use_v2_block_manager=True,
    )
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Resources for vLLM contributors
-------------------------------
* `A Hacker's Guide to Speculative Decoding in vLLM <https://www.youtube.com/watch?v=9wNAgpX6z_4>`_
* `What is Lookahead Scheduling in vLLM? <https://docs.google.com/document/d/1Z9TvqzzBPnh5WHcRwjvK2UEeFeq5zMZb5mFE8jR0HCs/edit#heading=h.1fjfb0donq5a>`_
* `Information on batch expansion <https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8>`_
* `Dynamic speculative decoding <https://github.com/vllm-project/vllm/issues/4565>`_
