
Conversation


@gjc0824 gjc0824 commented Sep 23, 2025

Purpose

This PR adds Decode Context Parallel (DCP) support for GQA, following PR #23734 and PR #24864. The current implementation is based on the FlashInfer attention backend.

FlashInfer inserts the current query's KV into the cache before computation. Each query then attends to both its own KV and the context KV stored on the local device, and the log-sum-exp (LSE) is used to correct the attention outputs.

  • In the prefill/partial-prefill stage, a custom mask is added to support the interleaved KV cache with FlashInfer (see the sketch after this list). Example:
q_lens = 8, total_lens = 25, group_size = 4, local_rank = 0

stored kv cache
rank0: 0 4 8 12 16 20 24
rank1: 1 5 9 13 17 21
rank2: 2 6 10 14 18 22
rank3: 3 7 11 15 19 23

rank0 custom mask
q\kv    0      4      8     12    16      20      24
17   True,  True,  True,  True,  True,  False, False
18   True,  True,  True,  True,  True,  False, False
19   True,  True,  True,  True,  True,  False, False
20   True,  True,  True,  True,  True,  True,  False
21   True,  True,  True,  True,  True,  True,  False
22   True,  True,  True,  True,  True,  True,  False
23   True,  True,  True,  True,  True,  True,  False
24   True,  True,  True,  True,  True,  True,  True
  • In the decode stage, this PR follows the DCP decode approach from MLA: all-gather Q and LSE, correct the attention output, then perform a reduce-scatter (see the LSE-merge sketch below).
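
For illustration, here is a minimal standalone sketch (not the PR's code) that rebuilds the rank-0 mask above from the interleaved layout; variable names are chosen for this example only.

import torch

# Rebuild the rank-0 custom mask for the example above
# (q_lens=8, total_lens=25, group_size=4, local_rank=0).
dcp_rank, dcp_world_size = 0, 4
q_len, total_len = 8, 25
context_len = total_len - q_len                                       # 17 context tokens

# Global positions of the KV entries stored on this rank (interleaved layout).
local_kv_global = torch.arange(dcp_rank, total_len, dcp_world_size)   # 0, 4, ..., 24
# Global positions of the queries in this prefill chunk.
q_global = torch.arange(context_len, total_len)                       # 17 .. 24

# Causal rule on global positions: query q may attend KV position g iff g <= q.
mask = local_kv_global[None, :] <= q_global[:, None]                  # [8, 7], matches the table

For the decode path, each rank's partial attention output is combined with softmax weights derived from its LSE before the reduce-scatter. A hedged sketch of that merge (the helper name and tensor shapes are assumptions, not the PR's actual code):

import torch

def merge_attn_states(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    """Combine per-rank partial attention outputs using their LSEs.

    outs: [world_size, num_tokens, num_heads, head_dim], attention computed
          over each rank's local KV slice only.
    lses: [world_size, num_tokens, num_heads], matching log-sum-exp values.
    """
    lse_total = torch.logsumexp(lses, dim=0)             # LSE over the union of all KV
    weights = torch.exp(lses - lse_total).unsqueeze(-1)  # per-rank correction factors
    return (weights * outs).sum(dim=0)                   # [num_tokens, num_heads, head_dim]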

Test Plan

Qwen/Qwen3-235B-A22B

export VLLM_ATTENTION_BACKEND='FLASHINFER'
vllm serve Qwen/Qwen3-235B-A22B --gpu-memory-utilization 0.9 --tensor-parallel-size 8 --decode-context-parallel-size 2

Test Result

  • gsm8k eval
dcp=1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8578|±  |0.0068|
|     |       |strict-match    |     5|exact_match|↑  |0.8415|±  |0.0071|

dcp=2
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8613|±  |0.0067|
|     |       |strict-match    |     5|exact_match|↑  |0.8469|±  |0.0070|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces Decode Context Parallel (DCP) support for Grouped-Query Attention (GQA) with the FlashInfer backend, which is a valuable enhancement for distributed inference performance. The changes are comprehensive, covering configuration validation, modifications to the attention backend to support DCP-specific logic like query head gathering and LSE-based output correction, and the implementation of a custom attention mask for prefills. The addition of tests for a GQA model using the new functionality is also a great inclusion. The overall implementation is well-executed. I have a couple of suggestions to enhance code quality by addressing a dynamically assigned attribute and removing duplicated code.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@gjc0824 gjc0824 force-pushed the dcp-gqa-flashinfer branch 9 times, most recently from 540c862 to b9e9b41 on September 24, 2025 03:44
continue
K = ((rightmost - r) // p) + 1
j = torch.arange(K)
t = torch.arange(Q)
Collaborator

nit: we generally avoid single-character variable names; they're OK if there is a supporting comment. Can you please add comments explaining what the mask looks like and how it is constructed?

Author

Thank you for your review. We have added comments with a mask example and an explanation of the algorithm, alongside the vectorization improvements.

torch.int64).tolist()
r = self.dcp_rank
p = self.dcp_world_size
for i in range(num_prefills):
Collaborator

nit: is there a way we can vectorize this loop or replace it with a Triton kernel? Ideally we avoid Python loops, as they can be very slow and create GPU bubbles.

Author

Thank you for your valuable review. We have vectorized the "num_prefills" loop to avoid GPU bubbles. Looking forward to your further review.

Author

if self.dcp_world_size > 1:
    # init custom mask for interleave kv cache
    # |-------total_lens----------|
    # |--context_lens--|--q_lens--|
    # Example: dcp_size=2, dcp_rank=0
    # For a SINGLE prefill seq, q_lens=3, total_lens=5
    # k_lens on RANK0 is (5 - 1 - 0) // 2 + 1 = 3
    # mask.shape = [q_lens, k_lens] = [3,3]
    # mask [[True, True, False],
    #       [True, True, False],
    #       [True, True, True]]
    dcp_rank = self.dcp_rank
    dcp_size = self.dcp_world_size

    q_lens = (qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]).to(
            dtype=torch.int64, device=self.device)
    total_lens = seq_lens_cpu[prefill_start:prefill_start +
                num_prefills].to(dtype=torch.int64,
                device=self.device)
    context_lens = total_lens - q_lens
    # max indices for global sequences
    max_indices = total_lens - 1
    # if max_indices are smaller than dcp_rank,
    # current rank has no kv cache, is invalid,
    # the mask is skipped
    valid = (max_indices >= dcp_rank)
    assert torch.any(valid), "There is no valid sequence"

    # local kv lens on current dcp_rank
    k_lens = torch.div(max_indices - dcp_rank, 
                        dcp_size, 
                        rounding_mode="floor") + 1
    k_lens = torch.where(
        valid,
        k_lens,
        torch.zeros_like(k_lens))
    # vectorize operation
    # obtain the max length of all prefill reqs
    max_q = int(q_lens[valid].max().item())
    max_k = int(k_lens[valid].max().item())
    # generate local q and k indices
    q_indices = torch.arange(max_q, device=self.device)
    k_indices = torch.arange(max_k, device=self.device)
    # valid q and k indices of each reqs
    valid_q = valid[:, None] & \
        (q_indices[None, :] < q_lens[:, None])
    valid_k = valid[:, None] & \
        (k_indices[None, :] < k_lens[:, None])
    # where global q_indices >= global k_indices,
    # the mask is True
    # global q_indices = context_lens + local q_indices
    # global k_indices = local k_indices * dcp_size + dcp_rank
    # ====> local k_indices must be smaller than or equal to k_upper
    # k_upper = (context_lens + local q_indices - dcp_rank) // dcp_size
    k_upper = torch.div(
        context_lens[:, None] + q_indices - dcp_rank,
        dcp_size, rounding_mode="floor")
    k_upper = torch.where(
            valid_q,
            torch.clamp(k_upper, min=-1),
            k_upper.new_full(k_upper.shape, -1))
    mask = (k_indices[None, None, :] <= k_upper[:, :, None]) \
            & (k_upper[:, :, None] >= 0)
    valid_positions = valid_q[:, :, None] & valid_k[:, None, :]
    # flashinfer backend needs flattened format
    custom_mask = torch.masked_select(mask, valid_positions)
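
As a quick standalone sanity check (not part of the PR), the dcp_size=2, dcp_rank=0 example from the comment above reproduces the expected 3x3 mask with the same k_upper arithmetic:

import torch

dcp_rank, dcp_size = 0, 2
q_len, total_len = 3, 5
context_len = total_len - q_len                         # 2
k_len = (total_len - 1 - dcp_rank) // dcp_size + 1      # 3 local KV entries on rank 0

q_idx = torch.arange(q_len)
# Highest local k index each query may attend (the k_upper bound above).
k_upper = (context_len + q_idx - dcp_rank) // dcp_size
mask = torch.arange(k_len)[None, :] <= k_upper[:, None]
# tensor([[ True,  True, False],
#         [ True,  True, False],
#         [ True,  True,  True]])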

@LucasWilkinson
Collaborator

Apologies for the delayed review! Left a couple of nits; overall it's looking pretty good though.

@mergify

mergify bot commented Oct 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gjc0824.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 9, 2025
@mergify mergify bot added the tpu (Related to Google TPUs) and tool-calling labels Oct 10, 2025
@mergify mergify bot added the kv-connector label Oct 10, 2025
@gjc0824 gjc0824 force-pushed the dcp-gqa-flashinfer branch from 5a107b7 to b163b5b on October 10, 2025 02:25
@mergify mergify bot removed the tpu (Related to Google TPUs) label Oct 10, 2025
Signed-off-by: gaojc <1055866782@qq.com>
@gjc0824
Author

gjc0824 commented Oct 14, 2025

Apologies for the delayed review! Left a couple of nits; overall it's looking pretty good though.

Hi @LucasWilkinson. Could you re-review this PR and give the final sign-off? Thanks!

@gjc0824 gjc0824 reopened this Oct 14, 2025
@github-project-automation github-project-automation bot moved this from Done to To Triage in gpt-oss Issues & Enhancements Oct 14, 2025
@mergify

mergify bot commented Oct 14, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @gjc0824.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
@gjc0824 gjc0824 force-pushed the dcp-gqa-flashinfer branch from 8b7e0ed to bff3cda on October 16, 2025 06:59
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Apologies for the delay! Overall looks pretty good so far but I think we should land #26696 first (seems more important and this can build on that), thoughts?


block_table_tensor = common_attn_metadata.block_table_tensor

if self.dcp_world_size > 1:
    seq_lens_np = seq_lens_np // self.dcp_world_size + (
Collaborator

Should we land #26696 first and then update this to use the dcp_local_seq_lens computed in the model runner?
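
For reference, a hedged sketch (not the PR's code; the helper name is made up for this note) of how each rank's local KV length falls out of the interleaved layout:

import torch

def dcp_local_kv_len(seq_lens: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Ranks 0 .. (seq_len % world_size - 1) hold one extra interleaved KV entry.
    return seq_lens // world_size + (rank < seq_lens % world_size).to(seq_lens.dtype)

# Example: seq_len=25, world_size=4 -> rank 0 holds 7 entries (0, 4, ..., 24)
# and ranks 1-3 hold 6 each, matching the layout in the PR description.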

if self.dcp_world_size > 1:
    prefill_query = get_dcp_group().all_gather(
        prefill_query.contiguous(), dim=1
    )
Collaborator

nit: I guess this is fine, but the name "decode context parallel" is falling apart a bit here 😞

    ],
    "bigcode/gpt_bigcode-santacoder": [
        CPTestSettings.detailed(),
        CPTestSettings.detailed(tp_base=2),
Contributor

I think it's better to keep the default backend for CI.

class CPTestOptions(NamedTuple):
    multi_node_only: bool
    load_format: str | None = None
    attn_backend: str = "FLASH_ATTN"
Contributor

MLA can't use the "FLASH_ATTN" backend, so the default value should not be set.

@gjc0824 gjc0824 requested a review from pavanimajety as a code owner October 31, 2025 02:11