[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill #10132
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
cc: @tdoublep, who has worked on the MLPSpeculator. If you have time, I'd appreciate your review of this PR. Thanks!
Thanks for the PR!
I am wondering if we can split this PR into two: 1) one for enabling MLPSpeculator/Medusa, and 2) one for enabling prompt logprobs.
The logic for enabling prompt log probabilities appears to be non-trivial. I'm wondering if this feature (chunked_prefill + sd + prompt_logprobs) is actively being requested. If not, can we consider postponing it for now, given the significant complexity involved in implementing it?
cc: @LiuXiaoxuanPKU
@sroy745 I should be able to take a look later this week.
Thanks for reviewing this!
No problem on my side; let's wait for a second opinion before splitting the PR.
AFAIK at least folks at IBM have shown immediate interest in this, so let's wait for more input on this matter too.
Thanks. I was not aware of this feature request. Sounds good to include it given the feature request; I will continue with my review.
Thanks for the PR! Left some comments.
Since this PR makes changes to the batch_expansion and mqa_scorer, I am wondering if we can run the sd benchmark with and without this PR and ensure that there is no impact on vanilla sd performance.
vllm/spec_decode/batch_expansion.py (Outdated)
# Add all terminal chunks sizes as well as decodes with no
# speculation to get out tokens and skip over prompt ones.
seq_meta = contracted_seq_group_metadata_list
nospec_sizes = torch.tensor([
How are we handling non-terminal chunks here? Don't we need to ignore non-terminal chunks among the prefill sequences? If so, how are we ensuring that?
We don't ignore non-terminal chunks here; we actually have to pick their corresponding "output" token (which is always -1) so that it can be post-processed. These -1s are simply discarded later, but we're just complying with the current state of post-processing.
Basically it's only needed to have a matching number of input requests and outputs (tokens/probs).
I rephrased the comment to make it hopefully a bit clearer.
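For illustration, here's a minimal sketch of that alignment idea (names are hypothetical, not the actual batch_expansion code): non-terminal prefill chunks contribute a -1 placeholder so there is exactly one output per input sequence group.

```python
# Hypothetical sketch: emit one output entry per input sequence group.
# Non-terminal prefill chunks have no sampled token, so a -1 placeholder
# is emitted for them and simply discarded later during post-processing.
from dataclasses import dataclass
from typing import List


@dataclass
class SeqGroupMeta:
    is_prompt: bool          # True for prefill chunks
    is_terminal_chunk: bool  # True if this chunk completes the prompt
    sampled_token: int = -1  # real token id only for decodes / terminal chunks


def collect_output_tokens(seq_groups: List[SeqGroupMeta]) -> List[int]:
    """Return exactly one token id per input sequence group."""
    out: List[int] = []
    for sg in seq_groups:
        if sg.is_prompt and not sg.is_terminal_chunk:
            out.append(-1)  # placeholder, dropped by post-processing
        else:
            out.append(sg.sampled_token)
    return out


if __name__ == "__main__":
    groups = [
        SeqGroupMeta(is_prompt=True, is_terminal_chunk=False),
        SeqGroupMeta(is_prompt=True, is_terminal_chunk=True, sampled_token=42),
        SeqGroupMeta(is_prompt=False, is_terminal_chunk=False, sampled_token=7),
    ]
    assert collect_output_tokens(groups) == [-1, 42, 7]
```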
# Split loop into prefill|decode for readability.
start_loc, i = 0, 0
while i < len(target_seq_group_metadata_list
I am wondering if we can split this into two separate methods: one for handling the prefills and the other for the decodes?
Easily; I just wanted to highlight the fact that we're still only looping once.
On second thought, I think the only option to keep things clean and avoid repetition is to split the prefills|decodes list first and then process each in its own function; but then again, there aren't that many lines here anyway.
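A rough sketch of that split-first structure (purely illustrative, with made-up names; it assumes prefills are ordered before decodes in the batch, as the scheduler does today):

```python
# Illustrative sketch: partition the batch once, then handle each part
# in its own helper instead of branching inside a single loop.
from typing import Dict, List, Tuple


def split_prefills_decodes(
        seq_groups: List[Dict]) -> Tuple[List[Dict], List[Dict]]:
    prefills = [sg for sg in seq_groups if sg["is_prompt"]]
    decodes = [sg for sg in seq_groups if not sg["is_prompt"]]
    return prefills, decodes


def process_prefills(prefills: List[Dict]) -> List[int]:
    # e.g. collect prompt logprobs and -1 placeholders for non-terminal chunks
    return [
        sg["token"] if sg["is_terminal_chunk"] else -1 for sg in prefills
    ]


def process_decodes(decodes: List[Dict]) -> List[int]:
    # e.g. collect speculative / accepted tokens
    return [sg["token"] for sg in decodes]


def process_batch(seq_groups: List[Dict]) -> List[int]:
    prefills, decodes = split_prefills_decodes(seq_groups)
    return process_prefills(prefills) + process_decodes(decodes)
```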
# scheduling on baseline too, we get slightly different logprobs, ending
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
I guess the scheduling changes can change the batching, which in turn can lead to a different output? FAQ #3 here refers to a similar issue: https://docs.vllm.ai/en/latest/usage/faq.html.
Yeah, that was my guess too.
QQ: if we are using greedy decoding, shouldn't only one token have probability 1 and all other tokens have probability 0?
We can sample greedily, but the output is still a regular distribution; regardless, in this test we also compare the ranking and probabilities of the tokens that were not sampled (not top-1).
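Small illustration of that point in plain PyTorch (not the test code itself): greedy decoding only affects which token is picked, while the full log-softmax distribution is still there and its ranking can be compared between runs.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
logprobs = torch.log_softmax(logits, dim=-1)  # full distribution still exists

greedy_token = torch.argmax(logprobs).item()  # greedy pick, index 0 here

# Ranks/logprobs of the tokens that were *not* sampled can still be compared
# between two runs (e.g. with and without chunked prefill).
topk = torch.topk(logprobs, k=3)
print(greedy_token, topk.indices.tolist(), topk.values.tolist())
```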
Hmm, I've seen issues like that quite frequently with fp16, but it normally goes away in fp32 (which I think this test is running in). There is still no guarantee you will get exactly the same logprobs, though.
Sure, good idea. Let me address the review changes, then I can post some numbers on that.
seq_group_meta.token_chunk_size)
prompt_token_ids = prompt_token_ids[start:end]
prompt_logprobs = [
create_logprobs_output(
QQ: if the user does not need logprobs, why do we still create fake logprobs here?
TBH I am not sure either; I suppose it's just to comply with the post-processing code. I would love to remove it, but we'd need a separate PR because it's already there. Maybe @tjohnson31415 knows about this.
I've added tests for Medusa (so that the PR content actually matches its title) but disabled chunked prefill compatibility with EAGLE, as we still have some issues to address there.
Thanks for adding support for this! I have a few questions.
I may be missing something, but it seems like most of the changes here relate to enabling prompt logprobs with chunked prefill + spec decode, and the changes related to MLPSpec/Medusa (e.g., the hidden states stuff) are a relatively small piece?
@@ -418,15 +441,19 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
# Use smaller output len for fast test.
32,
])
# test with chunk size >= `speculative_disable_by_batch_size`
Why would this be a case that we need to test? Isn't `prefill_chunk_size` measured in tokens, and `speculative_disable_by_batch_size` measured in number of sequences?
Good point! I will change the comment. I was referring to the arg to `maybe_enable_chunked_prefill`, which sets both the number of tokens as well as the number of sequences. In practice you end up "converging" to `batch_size = max_num_seqs = prefill_chunk_size` as you get all decodes with size 1.
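For context, a sketch of what such a test helper could look like under that assumption (this is not necessarily the exact helper in the vLLM test suite):

```python
def maybe_enable_chunked_prefill(chunk_size: int, llm_kwargs: dict) -> None:
    """Assumed behavior: a positive chunk_size enables chunked prefill and
    caps both the per-step token budget and the number of sequences."""
    if chunk_size > 0:
        llm_kwargs.update(
            enable_chunked_prefill=True,
            max_num_batched_tokens=chunk_size,
            max_num_seqs=chunk_size,
        )


llm_kwargs: dict = {}
maybe_enable_chunked_prefill(4, llm_kwargs)
# llm_kwargs now caps both tokens and sequences at 4, so once the batch is
# all decodes (1 token each) it converges to batch_size == prefill_chunk_size.
```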
Could you update the comment here, since it's confusing?
Also, why is max_num_seqs = prefill_chunk_size?
No reason; this is just an arbitrary test value that I kept from existing tests.
# Scoring model may also return logprobs for prompt tokens
# for each request, when chunked prefill is enabled.
Can we also generate logprobs using spec decode if chunked prefill is not enabled?
Yes, previously available features shouldn't be affected by this change. We were able to get away with fewer lines of code because we either had all prompts (so no speculation, see https://github.com/vllm-project/vllm/blob/main/vllm/spec_decode/spec_decode_worker.py#L654) or all decodes.
In particular, the addition here was necessary because the speculator now runs prefills too, so it needed a way to report the prompt_logprobs back to the worker.
start = 1 if seq_data._num_computed_tokens == 0 \
    else seq_data._num_computed_tokens
Why do we skip the first location when `seq_data._num_computed_tokens == 0` but not otherwise?
We have no probability for the first token of the prompt: the model only gives p(x_i | x_{i-1}, ...), so there is no prior p(x_0).
Same logic as https://github.com/vllm-project/vllm/blob/main/vllm/spec_decode/spec_decode_worker.py#L591.
Thanks for the review!
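To make the off-by-one concrete, a generic sketch (not the vLLM implementation): the logits at position i condition on tokens 0..i and therefore score token i+1, so nothing ever scores token 0.

```python
import torch

torch.manual_seed(0)
vocab_size = 10
prompt = torch.tensor([3, 1, 4, 1, 5])          # token ids of the prompt
logits = torch.randn(len(prompt), vocab_size)   # one logits row per position
logprobs = torch.log_softmax(logits, dim=-1)

# Row i scores token i+1; token 0 has no predecessor, hence `start = 1`.
scored = logprobs[:-1].gather(1, prompt[1:, None]).squeeze(1)
assert scored.numel() == len(prompt) - 1        # one logprob per token except x_0
```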
I am afraid so. I was hoping support for prompt logprobs could be added with less effort, but it ended up having to be quite invasive.
LGTM. Left one comment about a TP > 1 test. PTAL.
I am wondering if we could do one round of testing around the following:
- Compare the TPOT with and without this PR for vanilla sd runs (with and without prompt logprobs) and make sure there is no regression.
- Do a sanity run for an MLPSpeculator with target model tp >= 1.
- Do a sanity run for chunked_prefill + sd with a regular draft model and target model tp >= 1 (with and without logprobs).
I've added the benchmark results @sroy745, let me know what you think.
Thanks for sharing the results. LGTM. There are some spec_decoding test failures. PTAL.
Thanks for the heroic efforts on this @NickLucche, and the detailed reviews @sroy745 @tdoublep.
@LiuXiaoxuanPKU is giving this a final look over.
Thanks for the work here. I'm good with it now; just some very minor things.
Follow-up to #9291, an attempt at fixing prompt_logprobs and enabling hidden-state-based speculators (MLPSpeculator/Medusa).
The main issue with prompt_logprobs is that it changes the output of a mixed prefill-decode batch to have #prompt_tokens + #decode_tokens entries instead of just #sampling_entries (terminal chunks only), and the current code was not accounting for that.
My approach currently relies on splitting prefill and decode processing to account for that; I'm really open to anything more elegant here.
Regarding hidden states, we have to disregard those coming from non-terminal chunks (the logits_processor already discards those; we simply have to adjust the code to reflect it) in order to store the last latent we actually care about.
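As a rough illustration of the hidden-state handling described above (a sketch under assumed tensor layouts, not the actual worker code): only the entries belonging to decodes or terminal prompt chunks are kept, mirroring what the logits processor already does.

```python
import torch


def keep_relevant_hidden_states(hidden_states: torch.Tensor,
                                keep_mask: torch.Tensor) -> torch.Tensor:
    """hidden_states: [num_entries, hidden_dim] (assumed layout).
    keep_mask: bool mask, True for decode steps and terminal prompt chunks.
    Non-terminal prefill chunks are dropped so that only the last latent
    we actually care about is stored for the MLPSpeculator/Medusa heads."""
    return hidden_states[keep_mask]


hidden = torch.randn(4, 8)
mask = torch.tensor([False, True, True, False])  # entries 1 and 2 are kept
assert keep_relevant_hidden_states(hidden, mask).shape == (2, 8)
```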
Benchmarks
Reporting results of benchmarks run on 4xA100-80GB with the following configuration (MultistepSpec regression check):
Benchmark client command:
python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --dataset-name sharegpt \
    --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --model meta-llama/Meta-Llama-3.1-70B-Instruct \
    --tokenizer meta-llama/Meta-Llama-3.1-70B-Instruct \
    --num-prompts 20 \
    --endpoint /v1/completions \
    --save-result --request-rate 2/4/6/8
Median TPOT is reported here: slightly worse on median TPOT but slightly better on throughput (consistent with the MQAScorer too). Overall I'd say performance is similar.
[Benchmark plot: median TPOT comparison]
Detail of the request rate = 8 run follows.
For the sake of completeness I also ran the same TP=4 benchmark with --enforce-eager=True to force MQAScorer (comparison).
All results are here: https://drive.google.com/file/d/1WOndRnE9STbr7TNNm-jBmHTyARKhkqIc/view?usp=sharing.