[Speculative decoding 3/9] Worker which speculates, scores, and applies rejection sampling #3103
Conversation
cool!
Ready for review. cc @LiuXiaoxuanPKU @ymwangg @robertgshaw2-neuralmagic @Yard1
Thanks for the great work! Just some minor questions & comments.
i for i, (_, proposal_len) in enumerate(
    zip(seq_group_metadata_list, proposal_lens_list))
if proposal_len == 0
]
Nit: we can merge the two for loops above.
Seems there are two concerns here:
- performance of two loops
- readability of two loops
I'll defer performance optimization until later; I'll put these into a helper function to make it easier to read.
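For illustration, a minimal sketch of the kind of helper that could do both passes at once; the function name and signature are hypothetical, not the PR's actual code:

```python
from typing import List, Tuple

def split_by_proposal_len(
        proposal_lens: List[int]) -> Tuple[List[int], List[int]]:
    """Split batch indices into speculative and non-speculative groups
    in a single pass over the proposal lengths."""
    spec_indices: List[int] = []
    non_spec_indices: List[int] = []
    for i, proposal_len in enumerate(proposal_lens):
        (spec_indices if proposal_len > 0 else non_spec_indices).append(i)
    return spec_indices, non_spec_indices

# Sequences with proposal_len == 0 fall back to normal decoding.
spec, non_spec = split_by_proposal_len([3, 0, 3, 0])
assert spec == [0, 2] and non_spec == [1, 3]
```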
@@ -0,0 +1,347 @@
from typing import List, Tuple, Optional, Dict
Currently, there is some complexity from handling spec and non-spec sequences through separate code paths. In the future, we will remove this complexity by introducing variable proposal lengths and the FlashInfer kernel. Maybe we can add some comments about this?
Good point, will add some comment!
    accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
    k: int,
) -> List[SamplerOutput]:
    """Given the accepted token ids, create a list of SamplerOutput.
Thanks for adding comments for almost all the functions! Really appreciate it!
😄
seq_data = next(iter(seq_group_metadata.seq_data.values()))
seq_len = seq_data.get_len()

if seq_len + max_proposal_len < self._max_model_len:
Maybe add a comment here saying:
(1) we want to account for the different max model lengths of the draft and target models, and
(2) for now, proposal_lens can only be max_proposal_len or 0; it cannot be a length between 0 and max_proposal_len.
Good points, will add!
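To make point (2) concrete, a hedged sketch of the all-or-nothing proposal-length rule described above; the function name is hypothetical and not the code in this PR:

```python
def get_proposal_len(seq_len: int, max_proposal_len: int,
                     max_model_len: int) -> int:
    """Return the number of tokens to speculate for one sequence.

    For now the proposal length is all-or-nothing: either the full
    max_proposal_len (when the speculated tokens still fit within the
    model's max length) or 0 (skip speculation for this sequence).
    Intermediate lengths are not produced.
    """
    if seq_len + max_proposal_len < max_model_len:
        return max_proposal_len
    return 0

# A sequence near the model length limit gets no speculation at all.
assert get_proposal_len(seq_len=10, max_proposal_len=5, max_model_len=100) == 5
assert get_proposal_len(seq_len=98, max_proposal_len=5, max_model_len=100) == 0
```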
vllm/spec_decode/batch_expansion.py
Outdated
if non_spec_indices:
    all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
    all_probs[non_spec_indices, 1:, :] = non_spec_target_probs
A bit confused by the 1 here, why starting from 1?
This is a bug! Should be `:1`. Saved me some headache during correctness testing, thank you 😄
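For clarity, a small self-contained sketch of the corrected slice; shapes and tensor values here are illustrative placeholders, not the PR's actual variables:

```python
import torch

batch_size, k, vocab_size = 4, 3, 8
non_spec_indices = [1, 3]

all_tokens = torch.full((batch_size, k + 1), -1, dtype=torch.long)
all_probs = torch.zeros(batch_size, k + 1, vocab_size)

# Non-speculative sequences contribute exactly one decoded token each, so
# only position 0 (slice :1) should be written -- not positions 1: onward.
non_spec_target_token_ids = torch.tensor([5, 2])
non_spec_target_probs = torch.rand(len(non_spec_indices), 1, vocab_size)

all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
all_probs[non_spec_indices, :1, :] = non_spec_target_probs
```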
vllm/spec_decode/batch_expansion.py
Outdated
all_tokens = torch.ones(
    original_bs, k + 1, device=self._device, dtype=torch.long) * -1
instead of torch.ones * -1, do torch.full with -1 as the fill value
added, thanks
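A minimal sketch of the suggested change, with placeholder values standing in for original_bs, k, and self._device:

```python
import torch

original_bs, k = 4, 3
device = "cpu"  # stand-in for self._device

# Equivalent to torch.ones(...) * -1, but allocates and fills in one step
# and avoids the extra elementwise multiply.
all_tokens = torch.full((original_bs, k + 1), -1,
                        device=device, dtype=torch.long)
```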
Hi, I was wondering why batch expansion is designed to support only top-1 cases and not tree-based batches. Any specific reasons for that?
This PR implements a vLLM Worker which invokes a draft worker to obtain proposals, invokes a target worker to obtain probabilities of each proposal, and then applies rejection sampling to accept/reject each speculated token. It is a part of the speculative decoding contribution by Anyscale to vLLM, see #2188 for more info.
High-level design
The high-level design is as follows:
Currently, only the "draft model" approach to speculative decoding is implemented with top-1 proposals from the draft model, and lossless rejection sampling. In the future, other proposal approaches may be added, such as Medusa/Eagle (requiring top-k proposals/tree attention scoring), Lookahead, RAG, etc. The key contribution of this PR is a light framework for proposing, scoring, and verifying speculative tokens using non-contiguous KV memory.
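As background, the acceptance rule behind lossless rejection sampling can be sketched as follows. This is a simplified, single-sequence illustration and not the PR's batched rejection sampler; it also omits the bonus token sampled when all k proposals are accepted:

```python
import torch

def rejection_sample_one_seq(
        draft_token_ids: torch.Tensor,  # [k]
        draft_probs: torch.Tensor,      # [k, vocab_size]
        target_probs: torch.Tensor,     # [k, vocab_size]
) -> torch.Tensor:
    """Accept/reject one sequence's k proposed tokens losslessly.

    Token x_i is accepted with probability min(1, p_target(x_i) / p_draft(x_i)).
    At the first rejection, a replacement token is drawn from the normalized
    residual distribution max(p_target - p_draft, 0) and the remaining
    proposals are discarded.
    """
    accepted = []
    for i, tok in enumerate(draft_token_ids.tolist()):
        p_target = target_probs[i, tok]
        p_draft = draft_probs[i, tok]  # > 0, since the draft sampled this token
        accept_prob = min(1.0, (p_target / p_draft).item())
        if torch.rand(1).item() < accept_prob:
            accepted.append(tok)
        else:
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
            replacement = torch.multinomial(residual / residual.sum(), 1).item()
            accepted.append(replacement)
            break
    return torch.tensor(accepted, dtype=torch.long)
```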
Notes for reviewers
What is "batch expansion"?
This PR does not use MQA for scoring proposal tokens. Instead, it uses the single-query PagedAttention kernel (aka, normal vLLM decode attention) to perform scoring of the proposal tokens. This was done because at the time of implementation, we did not yet have performant MQA kernels for non-contiguous KV memory. We now have an abundance of these (notably, FlashAttention and FlashInfer, along with Triton implementations, e.g. in #2607). Batch expansion should be replaced by these to obtain some efficiency gain in verification time.
More details on batch expansion and the optimization opportunity can be found here.
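As a rough mental model (not the actual implementation in vllm/spec_decode/batch_expansion.py), batch expansion turns each sequence with k proposed tokens into k+1 single-query decode entries, one per proposal prefix, so the existing single-query attention kernel can score every proposal position in one target-model forward pass:

```python
from typing import List

def expand_batch(token_ids: List[List[int]],
                 proposals: List[List[int]]) -> List[List[int]]:
    """Toy batch expansion: one entry per proposal prefix per sequence."""
    expanded = []
    for seq, prop in zip(token_ids, proposals):
        for i in range(len(prop) + 1):
            # The target model scores the next-token distribution after each
            # prefix; the expanded batch is (k + 1)x the original batch size,
            # which is the inefficiency MQA-style kernels would remove.
            expanded.append(seq + prop[:i])
    return expanded

# One sequence with k=2 proposed tokens becomes 3 single-query entries.
print(expand_batch([[10, 11]], [[7, 8]]))
# [[10, 11], [10, 11, 7], [10, 11, 7, 8]]
```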