
Conversation

@lhsjohn (Contributor) commented Sep 1, 2025

Purpose

This PR resolves three issues when using DeepSeek models with speculative decoding (multi-token prediction, MTP) enabled:

  1. Fixes the AssertionError: assert torch.all(query_lens[first_prefill:] > decode_threshold) raised during request processing with mtp1 enabled.
  2. Fixes "cannot access local variable 'hidden_states' where it is not associated with a value" with mtp1 enabled; see: [Bug]: when user MTP method will cause UnboundLocalError.
  3. Fixes the issue where, with MTP and FlashMLA enabled, decode requests are mistakenly treated as prefill requests and processed via FlashAttention, causing performance degradation (see the sketch after this list).
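For fixes 1 and 3, the core of the problem is that with mtp1 every decode step carries 1 + num_speculative_tokens query tokens per request, so a decode threshold of 1 misclassifies those steps as prefills. The snippet below is a minimal, self-contained sketch of that classification (plain PyTorch with made-up values, not the actual vLLM code):

import torch

num_speculative_tokens = 1                      # mtp1
decode_threshold = 1 + num_speculative_tokens   # query_len of one MTP decode step

query_lens = torch.tensor([2, 2, 2, 7, 12])     # 3 MTP decode steps followed by 2 prefills
is_prefill = query_lens > decode_threshold      # [False, False, False, True, True]
first_prefill = int(is_prefill.int().argmax())  # index of the first prefill request
print(first_prefill)                            # -> 3: the first three requests stay on the decode path

With decode_threshold left at 1, all five requests would be flagged as prefills and routed through the prefill (FlashAttention) path, which is the misclassification described in item 3.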

Test Plan

Test Environment:
Hardware: 8× H20 141GB
Models: DeepSeek-R1
vLLM: v0.10.1
Test script:

export VLLM_ATTENTION_BACKEND=FLASHMLA
vllm serve /xxx/DeepSeek-R1 \
           --trust-remote-code \
           --block-size 64 \
           --served-model-name deepseek-r1 \
           --max-model-len 28672 \
           --max-num-seqs 32 \
           --gpu-memory-utilization 0.85 \
           -tp 8 \
           --enable-expert-parallel \
           --enable-eplb \
           --max-num-batched-tokens 28672 \
           --enable-chunked-prefill \
           --no-enable-prefix-caching \
           --port 8021 \
           --speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 1}' \
           --load-format "auto"

Benchmark command:

python3 ./bench_serving.py --backend vllm --dataset-name random --model deepseek-r1 --tokenizer ./tokenizer --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --random-input-len 3500 --random-output-len 1000 --random-range-ratio 1 --request-rate 1 --max-concurrency 8 --num-prompts 128 --base-url http://xxx:8021 --host 0.0.0.0 --port 8000

Test Result

1. bs=8: mtp1, prefill uses FlashAttention
============ Serving Benchmark Result ============
Backend:                                 vllm      
Traffic request rate:                    1.0       
Max request concurrency:                 8         
Successful requests:                     128       
Benchmark duration (s):                  1532.81   
Total input tokens:                      448000    
Total generated tokens:                  128000    
Total generated tokens (retokenized):    71842     
Request throughput (req/s):              0.08      
Input token throughput (tok/s):          292.27    
Output token throughput (tok/s):         83.51     
Total token throughput (tok/s):          375.78    
Concurrency:                             7.94      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   95085.81  
Median E2E Latency (ms):                 95407.25  
---------------Time to First Token----------------
Mean TTFT (ms):                          779.29    
Median TTFT (ms):                        687.72    
P99 TTFT (ms):                           2110.51   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          94.40     
Median TPOT (ms):                        94.60     
P99 TPOT (ms):                           102.68    
---------------Inter-token Latency----------------
Mean ITL (ms):                           168.32    
Median ITL (ms):                         163.09    
P99 ITL (ms):                            513.78    

==================================================

2. bs=8: mtp1, prefill uses FlashMLA
============ Serving Benchmark Result ============
Backend:                                 vllm      
Traffic request rate:                    1.0       
Max request concurrency:                 8         
Successful requests:                     128       
Benchmark duration (s):                  551.37    
Total input tokens:                      448000    
Total generated tokens:                  128000    
Total generated tokens (retokenized):    72078     
Request throughput (req/s):              0.23      
Input token throughput (tok/s):          812.52    
Output token throughput (tok/s):         232.15    
Total token throughput (tok/s):          1044.66   
Concurrency:                             7.91      
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   34088.88  
Median E2E Latency (ms):                 33913.39  
---------------Time to First Token----------------
Mean TTFT (ms):                          626.44    
Median TTFT (ms):                        480.36    
P99 TTFT (ms):                           2261.17   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          33.50     
Median TPOT (ms):                        33.41     
P99 TPOT (ms):                           39.51     
---------------Inter-token Latency----------------
Mean ITL (ms):                           59.53     
Median ITL (ms):                         54.37     
P99 ITL (ms):                            420.68    
==================================================

Mean TPOT (Time Per Output Token): reduced from 94.40 ms to 33.50 ms
Mean TTFT (Time To First Token): reduced from 779.29 ms to 626.44 ms

@gemini-code-assist bot left a comment:

Code Review

This pull request addresses two bugs related to speculative decoding with DeepSeek models and adds support for the flashmla backend. The fixes in vllm/v1/spec_decode/eagle.py and vllm/v1/attention/backends/mla/common.py appear correct. However, the new logic for handling variable-length batches in vllm/v1/attention/backends/mla/flashmla.py contains a critical bug in how it detects uniform sequence lengths, which could lead to incorrect model outputs. I have provided a detailed comment and a suggested fix for this issue.

@lhsjohn changed the title from "[Bugfix] DeepSeek MTP assertion error and local variable access error in vllm/v1/spec_decode/eagle.py" to "[Bugfix] DeepSeek MTP assertion error in vllm/v1/attention/backends/utils.py and local variable access error in vllm/v1/spec_decode/eagle.py" on Sep 1, 2025
github-actions bot commented Sep 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small but essential subset of tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@lhsjohn force-pushed the feature/deepseek-mtp-flashmla branch from 9e88df0 to fbd4dee on September 1, 2025 at 13:50
mergify bot added the documentation label (Improvements or additions to documentation) on Sep 1, 2025
@facebook-github-bot commented:

@bwasti has imported this pull request. If you are a Meta employee, you can view this in D81590440.

@benchislett (Collaborator) commented:

Hi @lhsjohn, a few concerns:

  • Padding the inputs directly in each attention op is not ideal. Padding the entire inputs, as described in [Performance]: Padded Speculative Decoding #21984, has numerous advantages over this approach. If you have strong opinions on the matter, I would love to discuss further.
  • Not all MLA backends support a query size dimension for their inputs, especially on Blackwell. Those backends should still be able to use prefill kernels where possible. This means that reorder_batch_threshold needs to be handled more carefully.

Please see #22684, which includes these changes as well as support for FlashInfer-MLA and performance optimizations to allow the padded approach to run with no synchronization points between the verification and drafting phase. This is necessary to enable overlapped execution in the future (see #23569 and #22262 for context).

A contributor commented on the diff:

no need for this change

@lhsjohn (Author) replied:

Thank you for your suggestion! I didn't have this commit when I made my change; I adjusted it here when I rebased to resolve the conflict.

@lhsjohn (Author) replied:

This has now been resolved by the rebase.

@lhsjohn (Author) commented Sep 4, 2025

> (quoting @benchislett's concerns above)

Thank you for your sincere advice. I will take a look at the two points you mentioned.

mergify bot commented Sep 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @lhsjohn.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@lhsjohn (Author) commented Sep 9, 2025

> (quoting @benchislett's concerns above)

I've implemented the Smart Decode Classification approach as referenced in #21984. During local testing, I observed a slight performance regression of about 3% compared to our current padding-based solution.

@lhsjohn force-pushed the feature/deepseek-mtp-flashmla branch from 11dd897 to f02168a on September 9, 2025 at 14:23
@lhsjohn force-pushed the feature/deepseek-mtp-flashmla branch 3 times, most recently from c678a1e to 138b40b, on September 10, 2025 at 03:37
…decodes_and_prefills: when using the flashmla backend with mtp1, set require_uniform=True in split_decodes_and_prefills to support the FlashMLA kernel in the decode phase

@lhsjohn force-pushed the feature/deepseek-mtp-flashmla branch from 39a3d30 to d1dcc97 on September 10, 2025 at 04:02
@lhsjohn changed the title from "[Bugfix] DeepSeek MTP assertion error in vllm/v1/attention/backends/utils.py and local variable access error in vllm/v1/spec_decode/eagle.py" to "[Bugfix] [Performance] Better MTP Support when use flashmla" on Sep 10, 2025
@lhsjohn (Author) commented Sep 10, 2025

> (quoting @benchislett's concerns above)

Hello, I appreciate your message on the PR. I'm lhsjohn, the PR author. I've submitted another version based on your suggestions; if you have time, could you please take a look?

Key points:

  1. Removed the padding logic in flashmla.py forward_decode
  2. Added a parameter to split_decodes_and_prefills to control how decode requests are split into even batches
  3. The backend can control whether uniform batching is required based on its own needs (see the sketch below).
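A rough sketch of what points 2 and 3 amount to (a simplified, self-contained version of the behaviour discussed in this thread, not the actual vLLM implementation): when require_uniform is set, the decode run stops at the first request whose query length differs from the first one, so the decode kernel sees a uniform [batch, q_len] layout.

import torch

def split_decodes_and_prefills_sketch(query_lens: torch.Tensor,
                                      decode_threshold: int,
                                      require_uniform: bool = False):
    # Returns (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens).
    num_reqs = query_lens.numel()
    num_tokens = int(query_lens.sum())

    # Fast path: the first request is already a prefill, so everything is a prefill.
    if int(query_lens[0]) > decode_threshold:
        return 0, num_reqs, 0, num_tokens

    if require_uniform:
        # The decode run ends at the first request whose length differs from the first.
        is_prefill = query_lens != query_lens[0]
    else:
        is_prefill = query_lens > decode_threshold

    if not bool(is_prefill.any()):
        return num_reqs, 0, num_tokens, 0

    first_prefill = int(is_prefill.int().argmax())   # index of the first prefill
    num_decodes = first_prefill
    num_decode_tokens = int(query_lens[:first_prefill].sum())
    return (num_decodes, num_reqs - num_decodes,
            num_decode_tokens, num_tokens - num_decode_tokens)

# Example: an mtp1 batch with one short decode in the middle; uniform mode
# pushes it (and everything after it) onto the prefill path.
print(split_decodes_and_prefills_sketch(torch.tensor([2, 2, 2, 1, 5]), 2, True))
# -> (3, 2, 6, 6)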


self.speculative_config = vllm_config.speculative_config
# Set reorder_batch_threshold based on speculative config
if (self.speculative_config is not None and
A collaborator commented on the diff:

I think this might have negative consequences for backends which do not have kernel support for spec-friendly decodes. If so, we might want to have a per-backend flag to modulate when we apply this. Something like:

    reorder_batch_threshold: ClassVar[int] = 1
    supports_spec_decodes: ClassVar[bool] = False
    ...
        self.speculative_config = vllm_config.speculative_config
        # Set reorder_batch_threshold based on speculative config
        if (self.supports_spec_decodes and
                self.speculative_config is not None and
                self.speculative_config.num_speculative_tokens is not None):
            self.reorder_batch_threshold = (  # type: ignore[misc]
                1 + self.speculative_config.num_speculative_tokens)
        else:
            self.reorder_batch_threshold = 1  # type: ignore[misc]


assert isinstance(q, torch.Tensor)

batch_size = attn_metadata.decode.seq_lens.shape[0]
A collaborator commented on the diff:

Could you refactor this into a utility function? It will likely need to be called in each backend that supports this feature (FlashInfer-MLA at least), so it will be nice to be able to reuse the logic.
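For illustration, such a helper might look roughly like this (hypothetical name and signature, not vLLM's actual API): it reshapes the flattened decode query tensor into the per-request layout that a multi-token decode kernel such as FlashMLA expects, assuming the decode queries are uniform.

import torch

def reshape_query_for_decode(q: torch.Tensor, batch_size: int) -> torch.Tensor:
    # q: [num_decode_tokens, num_heads, head_dim], flattened across requests.
    num_decode_tokens, num_heads, head_dim = q.shape
    assert num_decode_tokens % batch_size == 0, "decode query lengths must be uniform"
    q_len_per_req = num_decode_tokens // batch_size
    # -> [batch_size, q_len_per_req, num_heads, head_dim]
    return q.view(batch_size, q_len_per_req, num_heads, head_dim)

# Example: 4 decode requests with mtp1 (2 query tokens each), 16 heads, head_dim 576.
q = torch.randn(8, 16, 576)
print(reshape_query_for_decode(q, batch_size=4).shape)  # torch.Size([4, 2, 16, 576])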

"""

if require_uniform:
return split_decodes_and_prefills_uniform(common_attn_metadata,
A collaborator commented on the diff:

Instead of a separate function, couldn't we just do something like:

if require_uniform:
      decode_threshold = min(decode_threshold, min(query_lens))

argmax should return the first instance of is_prefill, so it should be safe; we just need to drop:

assert torch.all(query_lens[first_prefill:] > decode_threshold)

but we have to drop that for #24845 anyway

@LucasWilkinson (Collaborator) commented Sep 17, 2025:

Oh sorry, I see you want to handle the [2, 2, 2, 1, 5] case; I think this is quite unlikely, but we can handle it pretty simply by doing something like:

# all prefills fast out
if query_lens[0] > decode_threshold:
       return 0, num_reqs, 0, num_tokens

if require_uniform:
       is_prefill = query_lens != query_lens[0]
else:
       is_prefill = query_lens > decode_threshold

A collaborator added:

(and still dropping assert torch.all(query_lens[first_prefill:] > decode_threshold))

@benchislett (Collaborator) commented Sep 17, 2025:

@LucasWilkinson I think the current implementation is probably correct. In the case of [1, 2, 2, 1, 10] with decode_threshold = 2, we want to return [1] for decodes and not [2, 2] or [1, 1]. The decodes sequence must be a prefix of the requests since we only return num_decodes and that is used to determine how far from the front we should slice.

To handle this more thoroughly you would have to modify the batch reordering code. This PR doesn't, and only does a best-effort pass to read uniform decodes from the front, falling back to prefills if there's a mismatch. I think that is fine for now.

Edit: updated the example to make it a better counterexample.
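A quick check of that counterexample in plain PyTorch (illustrative only, using the uniform-mask variant discussed in this thread): only the leading [1] is taken as the decode prefix, and the later requests, including the second 1-token one, fall to the prefill side.

import torch

query_lens = torch.tensor([1, 2, 2, 1, 10])
is_prefill = query_lens != query_lens[0]        # [False, True, True, False, True]
first_prefill = int(is_prefill.int().argmax())  # argmax returns the first True -> 1
print(query_lens[:first_prefill].tolist())      # decodes:  [1]
print(query_lens[first_prefill:].tolist())      # prefills: [2, 2, 1, 10]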

@LucasWilkinson (Collaborator) commented Sep 17, 2025:

Oh, sorry, I'm not doubting the correctness of the current implementation, sorry for the confusion! I was just suggesting that we can modify the existing implementation and do:

# all prefills fast out
if query_lens[0] > decode_threshold: 
    return 0, num_reqs, 0, num_tokens 
if require_uniform: 
    is_prefill = query_lens != query_lens[0] 
else: 
    is_prefill = query_lens > decode_threshold

instead of the current

is_prefill = query_lens > decode_threshold

(and remove assert torch.all(query_lens[first_prefill:] > decode_threshold))

then we wouldn't need the separate function and could achieve the same effect with a lot less code (and it would be vectorized).

A collaborator commented:

@LucasWilkinson It's not clear to me why this is doable. You're talking about a modification to split_decodes_and_prefills, right? In that case, I think it could receive an input like [2, 1, 2, 1, 2, 1], in which case you would need to split into decodes [2] and prefills [1, 2, 1, 2, 1]. You would not be able to do seq_lens == 2 and split into [2, 2, 2] and [1, 1, 1], since these are not contiguous in the input request array.

@LucasWilkinson (Collaborator) commented Sep 17, 2025:

Oh, it's because is_prefill is fed into argmax to find the split point, which returns the index of the first prefill and ignores any subsequent decodes.
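To make that concrete, here is the [2, 1, 2, 1, 2, 1] case in plain PyTorch (illustrative only): argmax over the boolean mask returns the index of the first prefill, so the later 1-token requests simply end up on the prefill side.

import torch

query_lens = torch.tensor([2, 1, 2, 1, 2, 1])
is_prefill = query_lens != query_lens[0]        # [False, True, False, True, False, True]
first_prefill = int(is_prefill.int().argmax())  # first True -> 1
print(query_lens[:first_prefill].tolist())      # decodes:  [2]
print(query_lens[first_prefill:].tolist())      # prefills: [1, 2, 1, 2, 1]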

@benchislett (Collaborator) commented:

I have added some utilities in #25183 that will support this PR and others to enable MTP/spec support through a common interface. Included are many of the refactors I requested in my review here, so you do not have to duplicate the effort.

@benchislett (Collaborator) commented:

Closing as #26541 has accomplished this.


Labels: deepseek (Related to DeepSeek models), documentation (Improvements or additions to documentation), speculative-decoding, v1
