
[Bug]: CUDA Graph Capture Issue: Unexpected Prefill Branches in Uniform Decode Graphs when MTP=2 #28207

@zhyajie

Description


Labels: bug, cuda-graphs, mla, attention-backend, Mtp=2

Summary

When capturing uniform decode CUDA graphs for MLA backends with QueryLenSupport.UNIFORM (e.g., ROCM_AITER_MLA, FLASHINFER_MLA, FLASHMLA), if uniform_decode_query_len > 1 (e.g., q_len=3 with MTP=2) and num_tokens is not divisible by uniform_decode_query_len, the padding applied during capture gives the last request a shorter query length. That request is then classified as prefill, so the captured uniform decode full graph unexpectedly contains a prefill branch.

Environment

Affected Scenarios

  • Affected Backends: All MLA backends with QueryLenSupport.UNIFORM
    • FLASHINFER_MLA (query_len_support = QueryLenSupport.UNIFORM)
    • FLASHMLA (query_len_support = QueryLenSupport.UNIFORM)
    • ROCM_AITER_MLA (query_len_support = QueryLenSupport.UNIFORM) – Currently not merged
  • Affected Modes: CUDA Graph capture with uniform decode batches (e.g., MTP=2, q_len=3)

Server

MODEL=deepseek-ai/DeepSeek-R1
VLLM_ATTENTION_BACKEND="FLASHINFER_MLA" \
VLLM_USE_V1=1 \
vllm serve $MODEL \
--tensor-parallel-size 8 \
--disable-log-requests \
--no-enable-prefix-caching \
--compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
--trust-remote-code \
--block-size 64 \
--kv-cache-dtype fp8 \
--speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 2}' 

Client

MODEL=deepseek-ai/DeepSeek-R1
lm_eval \
    --model local-completions \
    --tasks gsm8k \
    --model_args model=${MODEL},base_url=http://127.0.0.1:8000/v1/completions \
    --batch_size 100

Problem Description

When using MLA attention backend with CUDA graph capture enabled, if uniform_decode_query_len > 1 (e.g., q_len=3 in speculative decoding scenarios) and num_tokens is not divisible by uniform_decode_query_len, the following issues occur:

  1. Capture Phase: When capturing a uniform decode full graph, the last request receives an incomplete token count (e.g., num_tokens=8 with q_len=3 creates a batch of [3, 3, 2]), so that request is classified as prefill and a prefill branch is incorrectly captured into a graph that should contain only uniform decode.

  2. Forward Phase: When actually running with only complete decode requests (e.g., 2 requests with q_len=3), using the captured graph may incorrectly trigger the prefill branch.

Example Scenario

Capture Phase:

  • uniform_decode_query_len = 3 (MTP=2)
  • Default CUDA graph capture sizes: [1, 2, 4, 8, 16, 24, ...]
  • When capturing a uniform decode graph with num_tokens=8:
    • num_reqs = cdiv(8, 3) = 3
    • num_scheduled_tokens_list = [3, 3, 2]
    • First two requests have q_len=3, last request has q_len=2
    • With require_uniform=True: Request with query_len=2 is classified as prefill
    • Graph captured with both decode and prefill branches

Runtime Phase:

  • Actual batch: [3, 3] requests (6 tokens, padded to 8)
  • Graph execution tries to access prefill metadata/branches

Root Cause Analysis

1. Issue in Capture Phase (_dummy_run)

In the _dummy_run method in vllm/v1/worker/gpu_model_runner.py, the code for uniform decode mode:

elif uniform_decode:
    num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
    num_scheduled_tokens_list = [max_query_len] * num_reqs
    if num_tokens % max_query_len != 0:
        num_scheduled_tokens_list[-1] = num_tokens % max_query_len

When num_tokens=8, max_query_len=3, this creates a batch of [3, 3, 2], causing inconsistent query lengths.
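
For illustration, a standalone sketch of this split (the helper names here are illustrative rather than the actual vLLM utilities; cdiv is ceiling division):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as used in the snippet above.
    return -(a // -b)

def uniform_decode_split(num_tokens: int, max_query_len: int, max_num_reqs: int = 1024):
    # Mirror of the uniform-decode branch of _dummy_run shown above.
    num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
    num_scheduled_tokens = [max_query_len] * num_reqs
    if num_tokens % max_query_len != 0:
        num_scheduled_tokens[-1] = num_tokens % max_query_len
    return num_scheduled_tokens

print(uniform_decode_split(8, 3))   # [3, 3, 2] -> non-uniform batch, triggers the bug
print(uniform_decode_split(24, 3))  # [3, 3, 3, 3, 3, 3, 3, 3] -> uniform batch, safe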

2. Issue in MLA Attention Metadata Construction

In the build method in vllm/v1/attention/backends/mla/common.py:

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
    split_decodes_and_prefills(
        common_attn_metadata,
        decode_threshold=self.reorder_batch_threshold,
        require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
    )
)

Since MLA's query_len_support = QueryLenSupport.UNIFORM, require_uniform=True. In split_decodes_and_prefills:

if require_uniform:
    is_prefill = query_lens != query_lens[0]
else:
    is_prefill = query_lens > decode_threshold

When query_lens=[3, 3, 2], the last request is identified as prefill, causing the captured uniform decode full graph to include a prefill branch.
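
A minimal illustration of that classification, assuming query_lens is a torch tensor shaped like the capture batch above:

import torch

query_lens = torch.tensor([3, 3, 2])
decode_threshold = 3

# require_uniform=True (UNIFORM support): any request whose query_len
# differs from the first one is treated as prefill.
is_prefill_uniform = query_lens != query_lens[0]
# require_uniform=False (VARLEN support): only requests longer than the
# threshold are treated as prefill.
is_prefill_varlen = query_lens > decode_threshold

print(is_prefill_uniform)  # tensor([False, False,  True]) -> last request becomes prefill
print(is_prefill_varlen)   # tensor([False, False, False]) -> all decode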

Expected Behavior

For uniform decode CUDA graph capture, all requests should have consistent query_len matching uniform_decode_query_len. The captured graph should contain only decode branches.

Actual Behavior

The captured graph contains unexpected prefill branches due to incorrect classification of requests with different query lengths, leading to runtime memory access faults.

Proposed Solutions

Solution 1: Filter Out Non-Divisible Capture Sizes ⭐ (Recommended)

In the _set_cudagraph_sizes method in vllm/config/vllm.py, filter out capture sizes that are not divisible by uniform_decode_query_len.

Implementation Location: Add filtering logic after update_sizes_for_sequence_parallelism:

# Calculate uniform_decode_query_len
uniform_decode_query_len = 1  # default value
if self.speculative_config and self.speculative_config.num_speculative_tokens:
    uniform_decode_query_len = 1 + self.speculative_config.num_speculative_tokens

# Filter out sizes that are not divisible by uniform_decode_query_len
if uniform_decode_query_len > 1:
    original_sizes = cudagraph_capture_sizes.copy()
    cudagraph_capture_sizes = [
        size for size in cudagraph_capture_sizes 
        if size % uniform_decode_query_len == 0
    ]
    if len(cudagraph_capture_sizes) < len(original_sizes):
        removed_sizes = set(original_sizes) - set(cudagraph_capture_sizes)
        logger.warning(
            "Batch sizes %s are removed because they are not "
            "multiple of uniform_decode_query_len %d when "
            "capturing uniform decode CUDA graphs",
            sorted(removed_sizes),
            uniform_decode_query_len,
        )

Pros:

  • Solves the problem at the source, ensuring all capture sizes are divisible by uniform_decode_query_len
  • Small scope of changes, only requires modifying configuration logic
  • Clear logic, mirroring the existing sequence-parallelism size filtering
  • Prevents num_tokens % max_query_len != 0 situations in the capture phase
  • Without a speculative config, uniform_decode_query_len=1, so every size is divisible by 1 and no filtering is applied

Cons:

  • Reduces selectable range of CUDA graph buckets; may impact performance.
  • When the user specifies the capture size list explicitly, the same filtering must be applied and a warning issued.

Example:

  • If uniform_decode_query_len=3, the default capture sizes [1, 2, 4, 8, 16, 24, ...] would be filtered down to [24, 48, 72, ...]
  • This ensures that when capturing num_tokens=24 or num_tokens=48, all requests have a consistent query_len (see the sketch below)
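
A quick standalone check of that filtering; the candidate size list here is illustrative, not the exact default list:

# Illustrative default capture sizes; the real list comes from vLLM's config.
default_sizes = [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72]
uniform_decode_query_len = 3  # MTP=2 -> q_len = 1 + 2

kept = [s for s in default_sizes if s % uniform_decode_query_len == 0]
removed = sorted(set(default_sizes) - set(kept))

print(kept)     # [24, 48, 72]
print(removed)  # [1, 2, 4, 8, 16, 32, 40, 56, 64]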

Other ideas

  • The default CUDA graph size progression could also be made MTP-aware: with num_speculative_tokens=2 (q_len=3), generate capture sizes in increments that are multiples of the query length (e.g., steps of 6 or 9).

Solution 2: Special Handling in build_for_cudagraph_capture

Force all requests to be treated as decode in build_for_cudagraph_capture for uniform decode mode, bypassing the require_uniform check.

Pros: Small scope of changes; only affects the capture phase.
Cons: Requires modifying the metadata construction logic and may introduce edge cases. (A hypothetical sketch follows below.)
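
A hypothetical sketch of this approach; this is not the actual vLLM code, it only illustrates the idea of skipping the uniformity check during capture, reusing the variable names from split_decodes_and_prefills above:

import torch

def split_for_uniform_decode_capture(query_lens: torch.Tensor):
    # During uniform-decode graph capture, treat every request as decode,
    # even if padding made the tail request shorter (e.g. [3, 3, 2]),
    # so the captured graph never gains a prefill branch.
    num_decodes = query_lens.numel()
    num_prefills = 0
    num_decode_tokens = int(query_lens.sum())
    num_prefill_tokens = 0
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

print(split_for_uniform_decode_capture(torch.tensor([3, 3, 2])))
# (3, 0, 8, 0) -> decode-only metadata during capture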

Impact

  • Severity: High (causes runtime crashes)
  • Scope: Affects all MLA backends with QueryLenSupport.UNIFORM using CUDA graphs with MTP/speculative decoding
  • Workaround: Disable CUDA graphs, or ensure all capture sizes are divisible by uniform_decode_query_len (see the sketch below)
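
One way to apply the second workaround today, sketched under the assumption that the compilation config accepts a user-specified cudagraph_capture_sizes list (the same user-specified capture batch list mentioned in the cons of Solution 1), is to precompute a divisible size list and pass it explicitly:

import json

num_speculative_tokens = 2
uniform_decode_query_len = 1 + num_speculative_tokens  # q_len = 3 for MTP=2

# Keep only sizes divisible by the uniform decode query length
# (the candidate sizes below are illustrative).
sizes = [s for s in (1, 2, 4, 8, 16, 24, 32, 40, 48, 64, 72, 96)
         if s % uniform_decode_query_len == 0]

print("--compilation-config " + json.dumps({
    "cudagraph_mode": "FULL_AND_PIECEWISE",
    "cudagraph_capture_sizes": sizes,
}))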

Discussion Welcome

I understand there may be multiple solutions to this problem, each with its own trade-offs. I prefer Solution 1 (filtering non-divisible capture sizes) because it solves the problem at the configuration stage, has clear logic, and requires minimal changes.

Feel free to propose other approaches or improvements. We can discuss the best implementation strategy together. I'm happy to help implement the solution once we agree on the approach.

Additional Context

This issue is particularly relevant for:

  • Models using Multi-Token Prediction (MTP) or speculative decoding where uniform_decode_query_len > 1
  • CUDA graph capture sizes that don't align with uniform_decode_query_len
  • Performance-critical scenarios where CUDA graphs are essential
