
[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and prompt_logprobs with ChunkedPrefill #10132

Merged
20 commits merged into vllm-project:main on Jan 27, 2025

Conversation

@NickLucche NickLucche (Contributor) commented Nov 7, 2024

Follow-up to #9291, an attempt at fixing prompt_logprobs and enabling hidden-state-based speculators (MLP/Medusa).

The main issue with prompt_logprobs is that it changes the output of the mixed prefill-decode batch to have #prompt_tokens + #decode_tokens entries instead of just #sampling_entries (terminal chunks only), and the current code was not accounting for that.
My approach currently relies on splitting the processing of prefills and decodes to account for that; I'm really open to anything more elegant here.

Regarding hidden states, we have to disregard those coming from non-terminal chunks (logits_processor already discards them; we simply have to adjust the code to reflect that) so that we store the last latent we actually care about.
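To make the hidden-state point concrete, here is a minimal, self-contained sketch; all names (SeqEntry, collect_speculator_hidden_states) are illustrative stand-ins, not the actual vLLM classes. The idea is that in a mixed chunked-prefill batch only terminal prompt chunks and decode entries produce logits, so the draft model should keep hidden states for exactly those entries.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SeqEntry:
    is_prompt: bool
    is_terminal_chunk: bool              # True for the last chunk of a prompt
    hidden_state: Optional[List[float]]  # stand-in for a hidden-state tensor row


def collect_speculator_hidden_states(batch: List[SeqEntry]) -> List[List[float]]:
    """Keep only the hidden states the draft model actually conditions on."""
    kept = []
    for entry in batch:
        if entry.is_prompt and not entry.is_terminal_chunk:
            # Non-terminal chunk: no logits are produced for it, so its
            # hidden state must not be stored as "the last latent".
            continue
        kept.append(entry.hidden_state)
    return kept


# Two prompt chunks (only the second is terminal) plus one decode entry.
batch = [SeqEntry(True, False, [0.1]), SeqEntry(True, True, [0.2]),
         SeqEntry(False, False, [0.3])]
assert collect_speculator_hidden_states(batch) == [[0.2], [0.3]]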

Benchmarks

Reporting results of benchmarks run on 4xA100-80GB with the following configuration (MultistepSpec regression check):

python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-70B-Instruct --max-model-len 32768 -tp 4 --speculative-model meta-llama/Meta-Llama-3.1-8B-Instruct  --num-speculative-tokens 4 --speculative-draft-tensor-parallel-size 1 --enable_chunked_prefill True --max_num_batched_tokens 512 --max_num_seqs 32  
Benchmark client command:

python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --dataset-name sharegpt \
    --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --model meta-llama/Meta-Llama-3.1-70B-Instruct \
    --tokenizer meta-llama/Meta-Llama-3.1-70B-Instruct \
    --num-prompts 20 \
    --endpoint /v1/completions \
    --save-result --request-rate 2/4/6/8

Median TPOT is reported here: slightly worse on median TPOT but slightly better throughput (consistent with the MQAScorer too). Overall I'd say performance is similar.
[Chart: median TPOT across request rates 2/4/6/8, PR vs. pre-PR]

Details for request rate = 8 follow:

# PR-10132
============ Serving Benchmark Result ============
Successful requests:                     17        
Benchmark duration (s):                  18.81     
Total input tokens:                      2491      
Total generated tokens:                  4055      
Request throughput (req/s):              0.90      
Output token throughput (tok/s):         215.59    
Total Token throughput (tok/s):          348.02    
---------------Time to First Token----------------
Mean TTFT (ms):                          332.06    
Median TTFT (ms):                        279.83    
P99 TTFT (ms):                           720.11    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          37.02     
Median TPOT (ms):                        37.88     
P99 TPOT (ms):                           58.45     
---------------Inter-token Latency----------------
Mean ITL (ms):                           120.74    
Median ITL (ms):                         120.01    
P99 ITL (ms):                            226.19    
==================================================

# PRE-PR 
============ Serving Benchmark Result ============
Successful requests:                     17        
Benchmark duration (s):                  18.80     
Total input tokens:                      2491      
Total generated tokens:                  3904      
Request throughput (req/s):              0.90      
Output token throughput (tok/s):         207.71    
Total Token throughput (tok/s):          340.24    
---------------Time to First Token----------------
Mean TTFT (ms):                          320.24    
Median TTFT (ms):                        263.25    
P99 TTFT (ms):                           703.45    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          37.51     
Median TPOT (ms):                        37.53     
P99 TPOT (ms):                           58.51     
---------------Inter-token Latency----------------
Mean ITL (ms):                           120.03    
Median ITL (ms):                         119.22    
P99 ITL (ms):                           226.25
==================================================

For the sake of completeness I also ran the same TP=4 benchmark on:

  • (newly added) MLPSpeculator
  • (newly added) MLPSpeculator+logprobs=2
  • --enforce-eager=True to force MQAScorer (comparison)

All results here https://drive.google.com/file/d/1WOndRnE9STbr7TNNm-jBmHTyARKhkqIc/view?usp=sharing.


github-actions bot commented Nov 7, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀


mergify bot commented Nov 7, 2024

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 7, 2024
@NickLucche NickLucche marked this pull request as ready for review November 26, 2024 14:57
@NickLucche NickLucche force-pushed the fix-prompt-logprobs-mlpspec branch from 45b4e73 to b423f32 Compare November 26, 2024 18:37
@mergify mergify bot removed the needs-rebase label Nov 26, 2024

mergify bot commented Nov 27, 2024

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 27, 2024
@NickLucche NickLucche force-pushed the fix-prompt-logprobs-mlpspec branch from b423f32 to 88bba37 Compare November 27, 2024 09:31
@mergify mergify bot removed the needs-rebase label Nov 27, 2024
@sroy745 sroy745 (Collaborator) commented Dec 9, 2024

cc: @tdoublep, who has worked on the MLPSpeculator. If you have time, I'd appreciate your review of this PR. Thanks!

@sroy745 sroy745 (Collaborator) left a comment

Thanks for the PR!

I am wondering if we can split this PR into two: 1) one for enabling MLPSpeculator/Medusa and 2) one for enabling the prompt logprobs.

The logic for enabling prompt log probabilities appears to be non-trivial. I'm wondering if this feature (chunked_prefill + sd + prompt_logprobs) is actively being requested. If not, can we consider postponing it for now, given the significant complexity involved in implementing it?

cc: @LiuXiaoxuanPKU

(3 resolved review threads on vllm/spec_decode/spec_decode_worker.py)
@tdoublep tdoublep (Member) commented Dec 9, 2024

@sroy745 I should be able to take a look later this week.

@NickLucche NickLucche (Contributor, Author) commented

Thanks for reviewing this!

> I am wondering if we can split this PR into two: 1) one for enabling MLPSpeculator/Medusa and 2) one for enabling the prompt logprobs.

No problem on my side; let's wait for a second opinion before tearing the PR apart.

> I'm wondering if this feature (chunked_prefill + sd + prompt_logprobs) is actively being requested.

AFAIK at least folks at IBM have shown immediate interest in this; let's wait for more input on this matter too.

@sroy745 sroy745 (Collaborator) commented Dec 11, 2024

> Thanks for reviewing this!
>
> > I am wondering if we can split this PR into two: 1) one for enabling MLPSpeculator/Medusa and 2) one for enabling the prompt logprobs.
>
> No problem on my side; let's wait for a second opinion before tearing the PR apart.
>
> > I'm wondering if this feature (chunked_prefill + sd + prompt_logprobs) is actively being requested.
>
> AFAIK at least folks at IBM have shown immediate interest in this; let's wait for more input on this matter too.

Thanks. I was not aware of this feature request. Sounds good to include it given the interest. I will continue with my review.

@sroy745 sroy745 (Collaborator) left a comment

Thanks for the PR! Left some comments.

Since this PR makes changes to batch_expansion and mqa_scorer, I am wondering if we can run the SD benchmark with and without this PR to ensure there is no impact on vanilla SD performance?

# Add all terminal chunk sizes as well as decodes with no
# speculation to get out tokens and skip over prompt ones.
seq_meta = contracted_seq_group_metadata_list
nospec_sizes = torch.tensor([
Collaborator:

How are we handling non-terminal chunks here? Don't we need to ignore non-terminal chunks amongst the prefill sequences? If so, how are we ensuring that?

@NickLucche NickLucche (Contributor, Author) Dec 16, 2024:

We don't ignore non-terminal chunks here; we actually have to pick their corresponding "output" token (which is always -1) so that it can be post-processed. These -1s are simply discarded later; we're just complying with the current state of post-processing.

Basically it's only needed so that the number of input requests matches the number of outputs (tokens/probs).

I rephrased the comment to hopefully make it a bit clearer.
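A rough sketch of that bookkeeping, with made-up names rather than the PR's actual code: non-terminal prompt chunks contribute a -1 placeholder token so that the number of outputs matches the number of input requests, and post-processing drops the placeholders later.

def gather_nospec_tokens(seq_metas):
    """seq_metas: list of (is_prompt, is_terminal_chunk, sampled_token) tuples."""
    out = []
    for is_prompt, is_terminal, token in seq_metas:
        if is_prompt and not is_terminal:
            out.append(-1)  # non-terminal prompt chunk: placeholder only
        else:
            out.append(token)
    return out


# One non-terminal chunk, one terminal chunk, one decode -> three outputs.
assert gather_nospec_tokens(
    [(True, False, None), (True, True, 42), (False, False, 7)]) == [-1, 42, 7]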

(2 resolved review threads on vllm/spec_decode/batch_expansion.py)

# Split loop into prefill|decode for readability.
start_loc, i = 0, 0
while i < len(target_seq_group_metadata_list
Collaborator:

I am wondering if we can split this into two separate methods: one for handling the prefills and the other for the decodes?

NickLucche (Contributor, Author):

Easily; I just wanted to highlight the fact that we're still only looping once.

NickLucche (Contributor, Author):

On second thought, I think the only option to keep things clean and avoid repetition is to split the list into prefills and decodes first and then process each in its own function; but then again, there aren't that many lines here anyway.
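For reference, the split-first alternative mentioned above could look roughly like this; split_prefills_decodes is a hypothetical helper, not the code that was merged.

from types import SimpleNamespace


def split_prefills_decodes(seq_group_metadata_list):
    """Partition a mixed batch into prefill and decode entries in one pass."""
    prefills, decodes = [], []
    for meta in seq_group_metadata_list:
        (prefills if meta.is_prompt else decodes).append(meta)
    return prefills, decodes


batch = [SimpleNamespace(is_prompt=True), SimpleNamespace(is_prompt=False)]
prefills, decodes = split_prefills_decodes(batch)
assert len(prefills) == 1 and len(decodes) == 1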

(Resolved review threads: 3 on tests/spec_decode/e2e/test_logprobs.py (outdated), 1 on tests/spec_decode/e2e/test_multistep_correctness.py (outdated), 1 on vllm/config.py)
# scheduling on baseline too, we get slightly different logprobs, ending
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
Collaborator:

I guess the scheduling changes can change the batching, which in turn can lead to a different output? FAQ #3 here refers to a similar issue: https://docs.vllm.ai/en/latest/usage/faq.html

NickLucche (Contributor, Author):

yeah that was my guess too

Collaborator:

QQ: if we are using greedy decoding, shouldn't only one token have probability 1 and all other tokens have probability 0?

NickLucche (Contributor, Author):

We can sample greedily, but the output is still a regular distribution; regardless, in this test we also compare the rank and probability of the tokens that were not sampled (i.e. not top-1).
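The kind of check being described could be sketched as follows; this is illustrative only and the real test helpers differ. Both the rank and the value of each returned logprob are compared between the two runs.

import math


def assert_logprobs_match(lp_a, lp_b, tol=1e-3):
    """lp_a, lp_b: {token_id: (rank, logprob)} for one sampled position."""
    assert lp_a.keys() == lp_b.keys(), "different top-k token sets"
    for tok, (rank_a, val_a) in lp_a.items():
        rank_b, val_b = lp_b[tok]
        assert rank_a == rank_b, f"rank mismatch for token {tok}"
        assert math.isclose(val_a, val_b, abs_tol=tol), f"logprob mismatch for {tok}"


# Greedy sampling still yields a full distribution over the top-k tokens.
assert_logprobs_match({11: (1, -0.02), 97: (2, -3.90)},
                      {11: (1, -0.021), 97: (2, -3.899)})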

Member:

Hmm, I've seen issues like that quite frequently with fp16, but they normally go away in fp32 (which I think this test is running in). There is still no guarantee you will get exactly the same logprobs though.

@NickLucche NickLucche (Contributor, Author) commented

Sure, good idea; let me address the review changes, then I can post some numbers on that.

seq_group_meta.token_chunk_size)
prompt_token_ids = prompt_token_ids[start:end]
prompt_logprobs = [
create_logprobs_output(
Collaborator:

QQ: if the user does not need logprobs, why do we still create fake logprobs here?

NickLucche (Contributor, Author):

TBH I am not sure either; I suppose it's just to comply with the post-processing code. I would love to remove it, but we'd need a separate PR because it's already there. Maybe @tjohnson31415 knows about this.

@NickLucche NickLucche (Contributor, Author) commented

I've added tests for Medusa (so that the PR content actually matches its title) but disabled CP compat with EAGLE, as we still have some issues to address there.
I'd rather have that in a separate contribution since this one has already grown to an unpleasant size.

NickLucche and others added 5 commits December 18, 2024 12:49
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche NickLucche force-pushed the fix-prompt-logprobs-mlpspec branch from d3131b6 to bbcc807 Compare December 18, 2024 12:49
@tdoublep tdoublep (Member) left a comment

Thanks for adding support for this! I have a few questions.

I may be missing something, but it seems like most of the changes here are related to enabling prompt logprobs with chunked prefill + spec decode, and the changes related to MLPSpec/Medusa (e.g., the hidden-states handling) are a relatively small piece?


@@ -418,15 +441,19 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
# Use smaller output len for fast test.
32,
])
# test with chunk size >= `speculative_disable_by_batch_size`
Member:

Why would this be a case that we need to test? Isn't prefill_chunk_size measured in tokens, and the speculative_disable_by_batch_size measured in number of sequences?

NickLucche (Contributor, Author):

Good point! I will change the comment. I was referring to the arg to maybe_enable_chunked_prefill, which sets both the number of tokens and the number of sequences. In practice you end up "converging" to batch_size = max_num_seqs = prefill_chunk_size as you get all decodes of size 1 (a rough sketch follows after this thread).

Collaborator:

Could you update the comment here since it's confusing?
Also, why is max_num_seqs = prefill_chunk_size?

NickLucche (Contributor, Author):

No particular reason; this is just an arbitrary test value that I kept from the existing tests.
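For context, here is a hedged sketch of what the maybe_enable_chunked_prefill test helper is assumed to do; the real helper in vLLM's tests may differ. A single chunk-size value drives both the token budget and the sequence budget, which is why the batch converges to batch_size == max_num_seqs == prefill_chunk_size once everything is decoding.

# Assumption-labelled sketch, not the actual vLLM test helper.
def maybe_enable_chunked_prefill_sketch(prefill_chunk_size: int, llm_kwargs: dict):
    if prefill_chunk_size > 0:
        llm_kwargs.update(
            enable_chunked_prefill=True,
            max_num_batched_tokens=prefill_chunk_size,  # token budget
            max_num_seqs=prefill_chunk_size,            # sequence budget
        )
    else:
        llm_kwargs["enable_chunked_prefill"] = False


kwargs = {}
maybe_enable_chunked_prefill_sketch(4, kwargs)
assert kwargs["max_num_seqs"] == kwargs["max_num_batched_tokens"] == 4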

Comment on lines +57 to +58
# Scoring model may also return logprobs for prompt tokens
# for each request, when chunked prefill is enabled.
Member:

Can we also generate logprobs using spec decode if chunked prefill is not enabled?

NickLucche (Contributor, Author):

Yes, previously available features shouldn't be affected by this change. We were able to get away with fewer LOC because we either had all prompts (so no spec, see https://github.com/vllm-project/vllm/blob/main/vllm/spec_decode/spec_decode_worker.py#L654) or all decodes.

NickLucche (Contributor, Author):

In particular, the addition here was necessary because the Speculator now runs prefills too, so it needed a way to report the prompt_logprobs back to the worker.
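Purely to illustrate the data-flow point above (the class and field names below are hypothetical, not the real vLLM structures): the scorer/speculator output needs an extra slot for per-request prompt logprobs so the worker can forward them to post-processing.

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ScorerOutputSketch:
    """Hypothetical container; the real vLLM structures are different."""
    token_ids: List[int]
    logprobs: List[float]
    # Populated only when chunked prefill puts prompt tokens in the batch.
    prompt_logprobs: Optional[List[Dict[int, float]]] = None


out = ScorerOutputSketch(token_ids=[42], logprobs=[-0.1],
                         prompt_logprobs=[{7: -1.2, 9: -2.3}])
assert out.prompt_logprobs is not None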

Comment on lines +576 to +577
start = 1 if seq_data._num_computed_tokens == 0 \
else seq_data._num_computed_tokens
Member:

Why do we skip the first location when seq_data._num_computed_tokens == 0 but not otherwise?

NickLucche (Contributor, Author):

We have no prob for the first token of the prompt: logprobs are p(x_i | x_{i-1}, ...), so there is no p(x_0).
Same logic as https://github.com/vllm-project/vllm/blob/main/vllm/spec_decode/spec_decode_worker.py#L591.
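A minimal sketch of that indexing rule (illustrative only, not the merged code): the first chunk starts at index 1 because token 0 has no conditional logprob, while later chunks start at num_computed_tokens.

def prompt_logprob_slice(num_computed_tokens: int, token_chunk_size: int):
    """Return the [start, end) slice of prompt positions that have logprobs."""
    start = 1 if num_computed_tokens == 0 else num_computed_tokens
    end = num_computed_tokens + token_chunk_size
    return start, end


assert prompt_logprob_slice(0, 4) == (1, 4)  # first chunk: skip token 0
assert prompt_logprob_slice(4, 4) == (4, 8)  # later chunks: full range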

@NickLucche NickLucche (Contributor, Author) commented

Thanks for the review!

> it seems like most of the changes here are related to enabling prompt logprobs

I am afraid so; I was hoping support for prompt_logprobs could be added with less effort, but it ended up being quite invasive.

@sroy745 sroy745 (Collaborator) left a comment

LGTM. Left one comment about a TP > 1 test. PTAL

I am wondering if we could do one round of testing around the following:

  1. Compare the TPOT with and without this pr for vanilla sd runs (with and without prompt logprobs) and make sure there is no regression.

  2. Do a sanity run for an MLP Speculator with target model tp>=1

  3. Do a sanity run for chunked_prefill + sd with a regular draft model and target model tp >= 1 (with and without logprobs)

(1 resolved review thread on tests/spec_decode/e2e/test_logprobs.py)
@joerunde joerunde added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 13, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche NickLucche (Contributor, Author) commented

I've added the benchmark results @sroy745, let me know what you think

@sroy745 sroy745 (Collaborator) commented Jan 15, 2025

Thanks for sharing the results. LGTM. There are some spec_decoding test failures, PTAL.

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@njhill njhill (Member) commented Jan 23, 2025

Thanks for the heroic efforts on this @NickLucche, and the detailed reviews @sroy745 @tdoublep.

@njhill njhill (Member) commented Jan 23, 2025

@LiuXiaoxuanPKU is giving this a final look over.

@LiuXiaoxuanPKU LiuXiaoxuanPKU (Collaborator) left a comment

Thanks for the work here. I'm good with it now, some very minor things.

(1 resolved review thread on tests/spec_decode/e2e/conftest.py (outdated))

NickLucche and others added 2 commits January 27, 2025 08:53
@njhill njhill merged commit 6116ca8 into vllm-project:main Jan 27, 2025
49 checks passed
tjtanaa pushed a commit to EmbeddedLLM/vllm that referenced this pull request Jan 28, 2025
…robs` with ChunkedPrefill (vllm-project#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
rasmith pushed a commit to rasmith/vllm that referenced this pull request Jan 30, 2025
…robs` with ChunkedPrefill (vllm-project#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
Isotr0py pushed a commit to Isotr0py/vllm that referenced this pull request Feb 2, 2025
…robs` with ChunkedPrefill (vllm-project#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
NickLucche added a commit to NickLucche/vllm that referenced this pull request Feb 7, 2025
…robs` with ChunkedPrefill (vllm-project#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
ShangmingCai pushed a commit to ShangmingCai/vllm that referenced this pull request Feb 10, 2025
…robs` with ChunkedPrefill (vllm-project#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>
GWS0428 pushed a commit to GWS0428/VARserve that referenced this pull request Feb 12, 2025
…robs` with ChunkedPrefill (vllm-project#10132)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: wallashss <wallashss@ibm.com>
Co-authored-by: wallashss <wallashss@ibm.com>