[Attention][DCP] Support DCP with query length > 1 (MTP) with FA3 #25049
Conversation
Code Review
This pull request aims to enable multi-token prediction (MTP) with decode context parallelism (DCP) for FlashAttention-3. The changes involve removing a restriction on query length for DCP and passing cp_world_size and cp_rank to the attention kernel. While the changes in vllm/v1/worker/gpu_model_runner.py are correct, there is a critical issue in vllm/v1/attention/backends/mla/flashattn_mla.py. The newly used attributes self.dcp_world_size and self.dcp_rank are not properly initialized due to an issue in the MLACommonImpl base class, which will cause a TypeError at runtime. This must be addressed for the feature to function correctly.
Thanks for this contribution! Just wanted to leave a reminder to update the FlashAttention
    # assert once the custom mask support is added to FA3.
    if self.dcp_world_size > 1:
        assert self.reorder_batch_threshold == 1, \
            "DCP not support reorder_batch_threshold > 1 now."
Since only flash_attn_mla supports the custom mask, we can't just remove this assert right now, can we?
Makes sense. I'll add a whitelist here for FA3 MLA.
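A rough sketch of that whitelist idea (the backend names and the helper below are illustrative, not the code added in this PR):

```python
# Only backends known to handle DCP with query length > 1 skip the old
# restriction; everything else keeps the reorder_batch_threshold == 1 check.
DCP_QUERY_LEN_GT1_BACKENDS = {"FLASHATTN_MLA"}  # assumed whitelist contents


def validate_dcp_reorder(backend_name: str, dcp_world_size: int,
                         reorder_batch_threshold: int) -> None:
    if dcp_world_size > 1 and backend_name not in DCP_QUERY_LEN_GT1_BACKENDS:
        assert reorder_batch_threshold == 1, (
            f"DCP does not support reorder_batch_threshold > 1 for "
            f"backend {backend_name}.")
```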
This pull request has merge conflicts that must be resolved before it can be merged.
vllm/v1/attention/backends/utils.py (outdated)
    # Needed by CrossAttentionBuilder
    encoder_seq_lens: Optional[np.ndarray] = None

    cp_seq_lens: Optional[torch.Tensor] = None
Sounds good, let me keep the dcp prefix.
Updated.
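For reference, the renamed field would look roughly like this (the dataclass and the exact field name are assumptions for illustration, not the final definition in utils.py):

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch


@dataclass
class AttentionMetadataSketch:
    # Needed by CrossAttentionBuilder (as in the quoted hunk above).
    encoder_seq_lens: Optional[np.ndarray] = None
    # Per-request KV lengths local to this DCP rank; None when DCP is off.
    dcp_seq_lens: Optional[torch.Tensor] = None
```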
LucasWilkinson left a comment
Overall looks good to me; left one nit.
Purpose
Combined with vllm-project/flash-attention#93, this PR enables MTP (multi-token prediction) with DCP (decode context parallelism). It also allows prefill and decode requests to be mixed in a batch.
See vllm-project/flash-attention#93 for the implementation and solution details. On the vLLM side we only need to pass the CP world size and CP rank down to the kernel.
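To illustrate why the kernel needs cp_world_size and cp_rank once the query length exceeds 1, here is the per-rank causal-mask arithmetic under an assumed round-robin KV layout (token t stored on rank t % cp_world_size); the layout and helper are assumptions for the example, not a description of vLLM's actual sharding:

```python
def local_visible_kv(total_kv_len: int, query_pos: int,
                     cp_world_size: int, cp_rank: int) -> int:
    """KV tokens held by `cp_rank` that a query at absolute position
    `query_pos` (0-based) may attend to under a causal mask, assuming a
    round-robin layout (illustrative only)."""
    visible = min(query_pos + 1, total_kv_len)  # causal prefix length
    full_rounds, remainder = divmod(visible, cp_world_size)
    return full_rounds + (1 if cp_rank < remainder else 0)


# With query length > 1 (MTP), consecutive query positions have different
# causal boundaries, so each rank needs a per-query mask, not a single length.
for pos in (10, 11, 12):
    print(pos, [local_visible_kv(13, pos, 4, r) for r in range(4)])
```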
Test Plan
Test Result
Benchmark
With MTP and TP8, DCP4
With MTP and TP8, DCP8
With MTP and TP8
With TP8, DCP8
With TP8, DCP4
With TP8
LM Eval
With TP8, DCP8 + MTP:
local-completions (model=deepseek-ai/DeepSeek-R1-0528,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=32), gen_kwargs: (None), limit: 100.0, num_fewshot: None, batch_size: 1