[Documentation][Spec Decode] Add documentation about lossless guarantees in Speculative Decoding in vLLM #7962
@@ -161,6 +161,46 @@ A variety of speculative models of this type are available on HF hub:
* `granite-7b-instruct-accelerator <https://huggingface.co/ibm-granite/granite-7b-instruct-accelerator>`_
* `granite-20b-code-instruct-accelerator <https://huggingface.co/ibm-granite/granite-20b-code-instruct-accelerator>`_
Lossless guarantees of Speculative Decoding
-------------------------------------------
In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of
speculative decoding, breaking down the guarantees into three key areas:
1. **Theoretical Losslessness**
   - Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might
     cause slight variations in output distributions, as discussed in
     `Accelerating Large Language Model Decoding with Speculative Sampling <https://arxiv.org/pdf/2302.01318>`_
     (a sketch of the underlying acceptance rule appears after this list).
2. **Algorithmic Losslessness**
   - vLLM’s implementation of speculative decoding is algorithmically validated to be lossless when the
     temperature parameter (`temp`) is set to 0. Key tests include:

     - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target
       distribution. `View Test Code <https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252>`_

     - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
       without it (a sketch of this comparison appears after this list). This verifies that vLLM's speculative decoding framework,
       when integrated with the vLLM forward pass and the vLLM rejection sampler, provides a lossless guarantee.
       Almost all of the tests in `this directory <https://github.com/vllm-project/vllm/tree/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e>`_
       verify this property using `this assertion implementation <https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291>`_.
3. **vLLM Logprob Stability**
   - vLLM does not currently guarantee stable log probabilities (logprobs) across different batch sizes, which might
     cause small variations in output probabilities. This may stem from non-deterministic behavior in batched operations
     or from numerical instability in Torch operations, as explained in the
     `Numerical Accuracy section <https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations>`_.

Review comment: @sroy745 this isn't spec decoding specific, it applies generally when concurrent requests are batched differently. I guess would be good to have a dedicated section explaining that too...

Reply: I added a section for this in serving/faq.rst (I could not find any other generic place to add it. As you mentioned it is not specific to spec decode so thought of adding it to serving faqs). I added a link to it in this subsection. I am not sure if this is what you meant. PTAL and let me know.
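To make the theoretical-losslessness point above concrete, here is a minimal NumPy sketch (illustrative only, not vLLM's implementation) of the modified rejection sampling rule described in the linked paper: a token drawn from the draft distribution is accepted with probability min(1, p/q) under the target distribution, and on rejection a replacement is drawn from the normalized residual max(0, p - q). In exact arithmetic the returned token is distributed exactly according to the target distribution; under floating-point arithmetic, tiny deviations are possible.

.. code-block:: python

    import numpy as np

    def speculative_accept(draft_token: int, p: np.ndarray, q: np.ndarray,
                           rng: np.random.Generator) -> int:
        """Accept or replace one draft token (illustrative sketch).

        p: target-model probabilities over the vocabulary (sums to 1)
        q: draft-model probabilities over the vocabulary (sums to 1)
        """
        # Accept the draft token with probability min(1, p/q).
        if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
            return draft_token
        # Otherwise resample from the normalized residual max(0, p - q),
        # which restores the exact target distribution overall.
        residual = np.maximum(p - q, 0.0)
        return int(rng.choice(len(p), p=residual / residual.sum()))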
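The greedy-sampling-equality property can also be checked by hand. The sketch below is illustrative rather than the test suite's actual harness: it assumes the ``speculative_model``, ``num_speculative_tokens``, and ``use_v2_block_manager`` engine arguments (names may differ across vLLM versions) and uses example model names.

.. code-block:: python

    from vllm import LLM, SamplingParams

    prompts = ["The future of AI is"]
    greedy = SamplingParams(temperature=0, max_tokens=32)

    # Target model alone.
    baseline = LLM(model="facebook/opt-6.7b")
    baseline_text = baseline.generate(prompts, greedy)[0].outputs[0].text

    # Same target model, with a small draft model proposing 5 tokens per step.
    # Note: in practice run the two engines in separate processes so the first
    # engine's GPU memory is released before the second one starts.
    spec = LLM(
        model="facebook/opt-6.7b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        use_v2_block_manager=True,
    )
    spec_text = spec.generate(prompts, greedy)[0].outputs[0].text

    # With greedy sampling, the two completions are expected to match.
    assert baseline_text == spec_text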
**Conclusion**

While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding
can occur due to the following factors:
- **Floating-Point Precision**: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution.
- **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially
  due to non-deterministic behavior in batched operations or numerical instability.
For stable generation across different runs, using per-request seeds is recommended (as sketched below), although it may affect latency. For more information,
refer to `Bugfix #6034 <https://github.com/vllm-project/vllm/issues/6034>`_.
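A minimal sketch of the per-request-seed recommendation, assuming an illustrative model name and the ``seed`` field of ``SamplingParams``; the same seeded request is expected to produce the same completion across runs:

.. code-block:: python

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-6.7b")

    # A fixed per-request seed makes the random draws for this request
    # reproducible across runs.
    params = SamplingParams(temperature=0.8, seed=42, max_tokens=32)

    first = llm.generate(["The future of AI is"], params)[0].outputs[0].text
    second = llm.generate(["The future of AI is"], params)[0].outputs[0].text
    print(first == second)  # expected: True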
Resources for vLLM contributors
-------------------------------
Review comment: the rejection sampler convergence tests also handle the case where temperature is nonzero, and/or other sampling parameters are applied.

Reply: Done. Removed mention of temperature = 0 in the comment.