Skip to content

Conversation

@minosfuture
Copy link
Contributor

@minosfuture minosfuture commented Sep 17, 2025

Purpose

Combined with vllm-project/flash-attention#93, this is to enable MTP (multi-token prediction) with DCP (decode context parallelism). It also allows prefill/decode to be mixed in a batch.

See vllm-project/flash-attention#93 for the implementation and solution details. Here we just need to pass the cp world size and cp rank.

Test Plan

  1. Benchmark output token throughput
  2. Verify LM eval accuracy

Test Result

Benchmark

config output token per second speedup (relative to tp8) speedup (relative to tp8+mtp)
tp8 2332.43 1x 0.77x
tp8,mtp 3038.71 1.30x 1x
tp8,dcp4 2983.08 1.28 0.98x
tp8,dcp8 3502.24 1.50 1.15x
tp8,mtp,dcp4 3722.81 1.60x 1.23x
tp8,mtp,dcp8 3600.58 1.54x 1.18x

Expand for details:

Metric Details #### With MTP and TP8,DCP4
VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 --tensor-parallel-size 8 --speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' -dcp 4

============ Serving Benchmark Result ============
Successful requests:                     4096
Maximum request concurrency:             4096
Benchmark duration (s):                  1126.65
Total input tokens:                      8373936
Total generated tokens:                  4194304
Request throughput (req/s):              3.64
Output token throughput (tok/s):         3722.81
Total Token throughput (tok/s):          11155.42
---------------Time to First Token----------------
Mean TTFT (ms):                          490816.57
Median TTFT (ms):                        434249.00
P99 TTFT (ms):                           1072761.55
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          168.82
Median TPOT (ms):                        161.39
P99 TPOT (ms):                           262.88
---------------Inter-token Latency----------------
Mean ITL (ms):                           334.02
Median ITL (ms):                         208.86
P99 ITL (ms):                            824.69
==================================================
Benchmark completed for tp8_spec_dcp4

With MTP and TP8,DCP8

============ Serving Benchmark Result ============
Successful requests:                     4096
Maximum request concurrency:             4096
Benchmark duration (s):                  1164.90
Total input tokens:                      8373311
Total generated tokens:                  4194304
Request throughput (req/s):              3.52
Output token throughput (tok/s):         3600.58
Total Token throughput (tok/s):          10788.61
---------------Time to First Token----------------
Mean TTFT (ms):                          445652.28
Median TTFT (ms):                        448234.92
P99 TTFT (ms):                           896324.45
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          321.94
Median TPOT (ms):                        319.65
P99 TPOT (ms):                           491.82
---------------Inter-token Latency----------------
Mean ITL (ms):                           636.99
Median ITL (ms):                         480.00
P99 ITL (ms):                            1087.89
==================================================
Benchmark completed for tp8_spec_dcp8

With MTP and TP8

============ Serving Benchmark Result ============
Successful requests:                     4096
Maximum request concurrency:             4096
Benchmark duration (s):                  1380.29
Total input tokens:                      8371302
Total generated tokens:                  4194304
Request throughput (req/s):              2.97
Output token throughput (tok/s):         3038.71
Total Token throughput (tok/s):          9103.59
---------------Time to First Token----------------
Mean TTFT (ms):                          661675.41
Median TTFT (ms):                        677812.47
P99 TTFT (ms):                           1347726.14
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          57.12
Median TPOT (ms):                        52.00
P99 TPOT (ms):                           88.92
---------------Inter-token Latency----------------
Mean ITL (ms):                           113.07
Median ITL (ms):                         68.78
P99 ITL (ms):                            734.05
==================================================

With TP8, DCP8

============ Serving Benchmark Result ============
Successful requests:                     4096
Maximum request concurrency:             4096
Benchmark duration (s):                  1197.60
Total input tokens:                      8371375
Total generated tokens:                  4194304
Request throughput (req/s):              3.42
Output token throughput (tok/s):         3502.24
Total Token throughput (tok/s):          10492.34
---------------Time to First Token----------------
Mean TTFT (ms):                          456091.64
Median TTFT (ms):                        472464.37
P99 TTFT (ms):                           945354.75
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          356.88
Median TPOT (ms):                        368.52
P99 TPOT (ms):                           565.14
---------------Inter-token Latency----------------
Mean ITL (ms):                           356.88
Median ITL (ms):                         253.01
P99 ITL (ms):                            853.34
==================================================
Benchmark completed for tp8_dcp8

With TP8, DCP4

============ Serving Benchmark Result ============
Successful requests:                     4096
Maximum request concurrency:             4096
Benchmark duration (s):                  1406.03
Total input tokens:                      8369181
Total generated tokens:                  4194304
Request throughput (req/s):              2.91
Output token throughput (tok/s):         2983.08
Total Token throughput (tok/s):          8935.43
---------------Time to First Token----------------
Mean TTFT (ms):                          597440.50
Median TTFT (ms):                        545819.80
P99 TTFT (ms):                           1262335.73
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          236.80
Median TPOT (ms):                        224.94
P99 TPOT (ms):                           390.39
---------------Inter-token Latency----------------
Mean ITL (ms):                           236.80
Median ITL (ms):                         162.87
P99 ITL (ms):                            757.87
==================================================
Benchmark completed for tp8_dcp4

With TP8

VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA vllm serve deepseek-ai/DeepSeek-R1-0528 --tensor-parallel-size 8 --speculative-config '{"num_speculative_tokens": 1, "method": "deepseek_mtp"}' 

============ Serving Benchmark Result ============
Successful requests:                     4096
Maximum request concurrency:             4096
Benchmark duration (s):                  1798.26
Total input tokens:                      8373030
Total generated tokens:                  4194304
Request throughput (req/s):              2.28
Output token throughput (tok/s):         2332.43
Total Token throughput (tok/s):          6988.62
---------------Time to First Token----------------
Mean TTFT (ms):                          855856.29
Median TTFT (ms):                        876320.14
P99 TTFT (ms):                           1748166.55
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          82.57
Median TPOT (ms):                        74.08
P99 TPOT (ms):                           134.68
---------------Inter-token Latency----------------
Mean ITL (ms):                           82.57
Median ITL (ms):                         56.14
P99 ITL (ms):                            704.71
==================================================
Benchmark completed for tp8

LM Eval

with tp8dcp8+mtp

local-completions (model=deepseek-ai/DeepSeek-R1-0528,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=32), gen_kwargs: (None), limit: 100.0, num_fewshot: None, batch_size: 1

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.96 ± 0.0197
strict-match 5 exact_match 0.96 ± 0.0197

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request aims to enable multi-token prediction (MTP) with decode context parallelism (DCP) for FlashAttention-3. The changes involve removing a restriction on query length for DCP and passing cp_world_size and cp_rank to the attention kernel. While the changes in vllm/v1/worker/gpu_model_runner.py are correct, there is a critical issue in vllm/v1/attention/backends/mla/flashattn_mla.py. The newly used attributes self.dcp_world_size and self.dcp_rank are not properly initialized due to an issue in the MLACommonImpl base class, which will cause a TypeError at runtime. This must be addressed for the feature to function correctly.

@MatthewBonanni
Copy link
Contributor

Thanks for this contribution! Just wanted to leave a reminder to update the FlashAttention GIT_TAG in cmake/external_projects/vllm_flash_attn.cmake after vllm-project/flash-attention#93 lands

# assert once the custom mask is support is added to FA3.
if self.dcp_world_size > 1:
assert self.reorder_batch_threshold == 1, \
"DCP not support reorder_batch_threshold > 1 now."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since only flash_attn_mla support custom mask, we can't just remove this assert right now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sense. I'll make a whitelist here for FA3 MLA

@mergify
Copy link

mergify bot commented Sep 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @minosfuture.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 25, 2025
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
…eq_len

Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
@mergify mergify bot added the rocm Related to AMD ROCm label Sep 30, 2025
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
@mergify mergify bot added the ci/build label Oct 6, 2025
# Needed by CrossAttentionBuilder
encoder_seq_lens: Optional[np.ndarray] = None

cp_seq_lens: Optional[torch.Tensor] = None
Copy link
Collaborator

@LucasWilkinson LucasWilkinson Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we maybe rename this to dcp_num_local_tokens or something like that? im worried this will clash with #25749 / #26133

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good. lemme keep the dcp prefix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

Copy link
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall looks good to me; left one nit

Signed-off-by: Ming Yang <minos.future@gmail.com>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 8, 2025
@vllm-bot vllm-bot merged commit 3b736e1 into vllm-project:main Oct 9, 2025
80 of 84 checks passed
yang926 pushed a commit to yang926/vllm_1008 that referenced this pull request Oct 9, 2025
…lm-project#25049)

Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: yang926 <yang926@naver.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…lm-project#25049)

Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
…lm-project#25049)

Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…lm-project#25049)

Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…lm-project#25049)

Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…lm-project#25049)

Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm speculative-decoding v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants