[Bugfix][Performance] Better MTP support when using flashmla #24045
Conversation
Code Review
This pull request addresses two bugs related to speculative decoding with DeepSeek models and adds support for the flashmla backend. The fixes in vllm/v1/spec_decode/eagle.py and vllm/v1/attention/backends/mla/common.py appear correct. However, the new logic for handling variable-length batches in vllm/v1/attention/backends/mla/flashmla.py contains a critical bug in how it detects uniform sequence lengths, which could lead to incorrect model outputs. I have provided a detailed comment and a suggested fix for this issue.
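For reference, a minimal sketch of what a robust uniformity check could look like; this is illustrative only (the flagged flashmla.py code is not quoted in this thread, and the helper name is hypothetical):

import torch

def decode_query_lens_are_uniform(query_lens: torch.Tensor) -> bool:
    # A decode batch is "uniform" only if every request contributes the
    # same number of query tokens, which is what a fixed-length decode
    # kernel such as FlashMLA expects.
    return bool((query_lens == query_lens[0]).all())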
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically, and you can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Hi @lhsjohn, a few concerns:
Please see #22684, which includes these changes as well as support for FlashInfer-MLA and performance optimizations to allow the padded approach to run with no synchronization points between the verification and drafting phases. This is necessary to enable overlapped execution in the future (see #23569 and #22262 for context).
vllm/v1/spec_decode/eagle.py
Outdated
no need for this change
Thank you for your suggestion! This commit wasn't there when I made my change; I adjusted it here while rebasing to resolve the conflict.
This problem has been resolved by the rebase.
Thank you for the advice; I will take a look at the two points you mentioned.
This pull request has merge conflicts that must be resolved before it can be merged.
I've implemented the Smart Decode Classification approach as referenced in #21984. During local testing, I observed a slight performance regression of about 3% compared to our current padding-based solution.
…decodes_and_prefills: when using the flashmla backend with MTP = 1, set require_uniform = True in split_decodes_and_prefills to support the flashmla kernel in the decode phase. Signed-off-by: lhsjohn <huashuoli@tencent.com>
Hello, thanks for your message on the PR. I'm lhsjohn, the PR author. I've submitted another version based on your suggestions; if you have time, could you please take a look? Key points:
self.speculative_config = vllm_config.speculative_config
# Set reorder_batch_threshold based on speculative config
if (self.speculative_config is not None and
I think this might have negative consequences for backends which do not have kernel support for spec-friendly decodes. If so, we might want to have a per-backend flag to modulate when we apply this. Something like:
reorder_batch_threshold: ClassVar[int] = 1
supports_spec_decodes: ClassVar[bool] = False
...
self.speculative_config = vllm_config.speculative_config
# Set reorder_batch_threshold based on speculative config
if (self.supports_spec_decodes and
        self.speculative_config is not None and
        self.speculative_config.num_speculative_tokens is not None):
    self.reorder_batch_threshold = (  # type: ignore[misc]
        1 + self.speculative_config.num_speculative_tokens)
else:
    self.reorder_batch_threshold = 1  # type: ignore[misc]
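As a minimal usage sketch of that idea (the class names below are illustrative stand-ins, not the real vLLM class hierarchy), a backend whose decode kernel can handle spec-friendly decodes would simply override the flag:

from typing import ClassVar

class MetadataBuilderBase:
    # Defaults from the suggestion above: backends without kernel support
    # for spec-friendly decodes keep a reorder threshold of 1.
    reorder_batch_threshold: ClassVar[int] = 1
    supports_spec_decodes: ClassVar[bool] = False

class FlashMLALikeBuilder(MetadataBuilderBase):
    # A backend whose decode kernel can handle more than one query token
    # per request (e.g. FlashMLA with MTP) opts in, so the __init__ logic
    # above raises its threshold to 1 + num_speculative_tokens.
    supports_spec_decodes: ClassVar[bool] = True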
assert isinstance(q, torch.Tensor)
...
batch_size = attn_metadata.decode.seq_lens.shape[0]
Could you refactor this into a utility function? It will likely need to be called in each backend that supports this feature (FlashInfer-MLA at least), so it will be nice to be able to reuse the logic.
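A sketch of what such a shared helper might look like (hypothetical name and signature; it only assumes the decode portion of the batch is uniform, which is what the FlashMLA decode kernel requires):

import torch

def fold_decode_query(q: torch.Tensor, num_decode_reqs: int) -> torch.Tensor:
    # Fold a flat [num_decode_tokens, num_heads, head_dim] query into
    # [num_decode_reqs, q_len_per_req, num_heads, head_dim] so a decode
    # kernel that expects a uniform per-request query length can consume it.
    num_tokens, num_heads, head_dim = q.shape
    assert num_tokens % num_decode_reqs == 0, \
        "decode batch must be uniform to fold the query"
    q_len = num_tokens // num_decode_reqs
    return q.view(num_decode_reqs, q_len, num_heads, head_dim)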
| """ | ||
|
|
||
| if require_uniform: | ||
| return split_decodes_and_prefills_uniform(common_attn_metadata, |
Instead of a separate function, couldn't we just do something like:

if require_uniform:
    decode_threshold = min(decode_threshold, min(query_lens))

argmax should return the first instance of is_prefill, so it should be safe; we just need to drop

assert torch.all(query_lens[first_prefill:] > decode_threshold)

but we have to drop that for #24845 anyway.
Oh, sorry, I see you want to handle the [2, 2, 2, 1, 5] case. I think this is quite unlikely, but we can handle it pretty simply by doing something like:

# all prefills fast out
if query_lens[0] > decode_threshold:
    return 0, num_reqs, 0, num_tokens

if require_uniform:
    is_prefill = query_lens != query_lens[0]
else:
    is_prefill = query_lens > decode_threshold
(and still dropping assert torch.all(query_lens[first_prefill:] > decode_threshold))
@LucasWilkinson I think the current implementation is probably correct. In the case of [1, 2, 2, 1, 10] with decode_threshold = 2, we want to return [1] for decodes and not [2, 2] or [1, 1]. The decodes sequence must be a prefix of the requests since we only return num_decodes and that is used to determine how far from the front we should slice.
To handle this more thoroughly you would have to modify the batch reordering code. This PR doesn't, and only does a best-effort pass to read uniform decodes from the front, falling back to prefills if there's a mismatch. I think that is fine for now.
Edit: updated to make the example a better counterexample.
Oh, no worries, I'm not doubting the correctness of the current implementation; sorry for the confusion! I was just suggesting that we can modify the existing implementation and do:

# all prefills fast out
if query_lens[0] > decode_threshold:
    return 0, num_reqs, 0, num_tokens

if require_uniform:
    is_prefill = query_lens != query_lens[0]
else:
    is_prefill = query_lens > decode_threshold

instead of the current

is_prefill = query_lens > decode_threshold

(and remove assert torch.all(query_lens[first_prefill:] > decode_threshold)). Then we wouldn't need the separate function and could achieve the same effect with a lot less code (and it would be vectorized).
@LucasWilkinson It's not clear to me why this is doable. You're talking about a modification to split_decodes_and_prefills, right? In that case, I think it could receive an input like [2, 1, 2, 1, 2, 1], in which case you would need to split into decode [2] and prefills [1, 2, 1, 2, 1]. You could not do seq_lens == 2 and split into [2, 2, 2] and [1, 1, 1], since these are not contiguous in the input request array.
Oh, it's because is_prefill is fed into argmax to find the split point, which returns the index of the first prefill and ignores any subsequent decodes.
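To make the argmax behaviour concrete, here is a small self-contained sketch of the suggested vectorized split (a sketch of the idea discussed above, not the code merged into vLLM), run against the examples from this thread:

import torch

def num_uniform_decodes(query_lens: torch.Tensor, decode_threshold: int) -> int:
    # All-prefill fast path: the very first request already exceeds the
    # threshold, so nothing at the front can be treated as a decode.
    if query_lens[0] > decode_threshold:
        return 0
    # Uniform mode: only the leading run of requests matching the first
    # query length counts as decodes, so they can feed a fixed-length kernel.
    is_prefill = query_lens != query_lens[0]
    if not bool(is_prefill.any()):
        return len(query_lens)
    # argmax returns the index of the first True, i.e. the first prefill;
    # later "decode-looking" requests are deliberately left with the prefills.
    return int(is_prefill.int().argmax())

print(num_uniform_decodes(torch.tensor([2, 2, 2, 1, 5]), 2))    # 3 -> decodes [2, 2, 2]
print(num_uniform_decodes(torch.tensor([1, 2, 2, 1, 10]), 2))   # 1 -> decodes [1]
print(num_uniform_decodes(torch.tensor([2, 1, 2, 1, 2, 1]), 2)) # 1 -> decodes [2]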
I have added some utilities in #25183 that will support this PR and others to enable MTP/spec support in a common interface. Included are many of the refactors I requested in my review here, so you do not have to duplicate the effort.
Closing as #26541 has accomplished this.
Purpose
This PR resolves two critical issues when using DeepSeek models with speculative decoding (multi-token prediction) enabled:
Test Plan
Test Environment:
Hardware: 8× H20 141GB
Models: DeepSeek-R1
vLLM: v0.10.1
Test script
Benchmark test
python3 ./bench_serving.py --backend vllm --dataset-name random --model deepseek-r1 --tokenizer ./tokenizer --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --random-input-len 3500 --random-output-len 1000 --random-range-ratio 1 --request-rate 1 --max-concurrency 8 --num-prompts 128 --base-url http://xxx:8021 --host 0.0.0.0 --port 8000

Test Result