CUDA Graph Capture Issue: Unexpected Prefill Branches in Uniform Decode Graphs when mtp=2
Labels: bug, cuda-graphs, mla, attention-backend, Mtp=2
Summary
When capturing uniform decode CUDA graphs for MLA backends with `QueryLenSupport.UNIFORM` (e.g., ROCM_AITER_MLA, FLASHINFER_MLA, FLASHMLA), if `uniform_decode_query_len > 1` (e.g., `q_len=3` with MTP=2) and `num_tokens` is not divisible by `uniform_decode_query_len`, the padding applied during capture produces a request with a shorter query length, which is incorrectly classified as prefill. This leads to unexpected prefill branches in the captured uniform decode full graph.
Environment
Affected Scenarios
- Affected Backends: All MLA backends with `QueryLenSupport.UNIFORM`:
  - `FLASHINFER_MLA` (`query_len_support = QueryLenSupport.UNIFORM`)
  - `FLASHMLA` (`query_len_support = QueryLenSupport.UNIFORM`)
  - `ROCM_AITER_MLA` (`query_len_support = QueryLenSupport.UNIFORM`) – currently not merged
- Affected Modes: CUDA Graph capture with uniform decode batches (e.g., MTP=2, q_len=3)
Server:

```bash
MODEL=deepseek-ai/DeepSeek-R1
VLLM_ATTENTION_BACKEND="FLASHINFER_MLA" \
VLLM_USE_V1=1 \
vllm serve $MODEL \
    --tensor-parallel-size 8 \
    --disable-log-requests \
    --no-enable-prefix-caching \
    --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
    --trust-remote-code \
    --block-size 64 \
    --kv-cache-dtype fp8 \
    --speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 2}'
```

Client:

```bash
MODEL=deepseek-ai/DeepSeek-R1
lm_eval \
    --model local-completions \
    --tasks gsm8k \
    --model_args model=${MODEL},base_url=http://127.0.0.1:8000/v1/completions \
    --batch_size 100
```

Problem Description
When using MLA attention backend with CUDA graph capture enabled, if uniform_decode_query_len > 1 (e.g., q_len=3 in speculative decoding scenarios) and num_tokens is not divisible by uniform_decode_query_len, the following issues occur:
- Capture Phase: When capturing a uniform decode full graph, the last request has an incomplete token count (e.g., `num_tokens=8, q_len=3` creates a batch of `[3, 3, 2]`), causing that request to be identified as prefill, which incorrectly includes a prefill branch in a graph that should only contain uniform decode.
- Forward Phase: When actually running with only complete decode requests (e.g., 2 requests with `q_len=3`), using the captured graph may incorrectly trigger the prefill branch.
Example Scenario
Capture Phase:
- `uniform_decode_query_len = 3` (MTP=2)
- Default CUDA graph capture sizes: `[1, 2, 4, 8, 16, 24, ...]`
- When capturing a uniform decode graph with `num_tokens=8`: `num_reqs = cdiv(8, 3) = 3`, so `num_scheduled_tokens_list = [3, 3, 2]`
- First two requests have `q_len=3`, the last request has `q_len=2`
- With `require_uniform=True`: the request with `query_len=2` is classified as prefill
- The graph is captured with both decode and prefill branches
Runtime Phase:
- Actual batch: `[3, 3]` requests (6 tokens, padded to 8)
- Graph execution tries to access prefill metadata/branches
Root Cause Analysis
1. Issue in Capture Phase (_dummy_run)
In the _dummy_run method in vllm/v1/worker/gpu_model_runner.py, the code for uniform decode mode:
```python
elif uniform_decode:
    num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
    num_scheduled_tokens_list = [max_query_len] * num_reqs
    if num_tokens % max_query_len != 0:
        num_scheduled_tokens_list[-1] = num_tokens % max_query_len
```

When `num_tokens=8` and `max_query_len=3`, this creates a batch of `[3, 3, 2]`, causing inconsistent query lengths.
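For reference, a minimal standalone reproduction of this arithmetic (`cdiv` is re-implemented here for illustration, and `max_num_reqs=1024` is an assumed value):

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, standing in for vLLM's helper.
    return -(-a // b)

max_num_reqs, num_tokens, max_query_len = 1024, 8, 3
num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))   # 3
num_scheduled_tokens_list = [max_query_len] * num_reqs          # [3, 3, 3]
if num_tokens % max_query_len != 0:
    num_scheduled_tokens_list[-1] = num_tokens % max_query_len  # last request gets 2
print(num_scheduled_tokens_list)  # [3, 3, 2] -> non-uniform capture batch
```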
2. Issue in MLA Attention Metadata Construction
In the build method in vllm/v1/attention/backends/mla/common.py:
```python
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
    split_decodes_and_prefills(
        common_attn_metadata,
        decode_threshold=self.reorder_batch_threshold,
        require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
    )
)
```

Since MLA's `query_len_support = QueryLenSupport.UNIFORM`, `require_uniform=True`. In `split_decodes_and_prefills`:
```python
if require_uniform:
    is_prefill = query_lens != query_lens[0]
else:
    is_prefill = query_lens > decode_threshold
```

When `query_lens=[3, 3, 2]`, the last request is identified as prefill, causing the captured uniform decode full graph to include a prefill branch.
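A small sketch of how this check classifies the padded capture batch (NumPy used here as a stand-in for the tensor ops; the `decode_threshold` value is illustrative):

```python
import numpy as np

query_lens = np.array([3, 3, 2])   # capture batch produced above
require_uniform = True             # QueryLenSupport.UNIFORM backends
decode_threshold = 3               # illustrative value

if require_uniform:
    is_prefill = query_lens != query_lens[0]
else:
    is_prefill = query_lens > decode_threshold

print(is_prefill)  # [False False  True] -> last request counted as prefill
```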
Expected Behavior
For uniform decode CUDA graph capture, all requests should have consistent query_len matching uniform_decode_query_len. The captured graph should contain only decode branches.
Actual Behavior
The captured graph contains unexpected prefill branches due to incorrect classification of requests with different query lengths, leading to runtime memory access faults.
Proposed Solutions
Solution 1: Filter Out Non-Divisible Capture Sizes ⭐ (Recommended)
In the _set_cudagraph_sizes method in vllm/config/vllm.py, filter out capture sizes that are not divisible by uniform_decode_query_len.
Implementation Location: Add filtering logic after update_sizes_for_sequence_parallelism:
```python
# Calculate uniform_decode_query_len
uniform_decode_query_len = 1  # default value
if self.speculative_config and self.speculative_config.num_speculative_tokens:
    uniform_decode_query_len = 1 + self.speculative_config.num_speculative_tokens

# Filter out sizes that are not divisible by uniform_decode_query_len
if uniform_decode_query_len > 1:
    original_sizes = cudagraph_capture_sizes.copy()
    cudagraph_capture_sizes = [
        size for size in cudagraph_capture_sizes
        if size % uniform_decode_query_len == 0
    ]
    if len(cudagraph_capture_sizes) < len(original_sizes):
        removed_sizes = set(original_sizes) - set(cudagraph_capture_sizes)
        logger.warning(
            "Batch sizes %s are removed because they are not "
            "multiples of uniform_decode_query_len %d when "
            "capturing uniform decode CUDA graphs",
            sorted(removed_sizes),
            uniform_decode_query_len,
        )
```

Pros:
- Solves the problem at the source, ensuring all capture sizes are divisible by `uniform_decode_query_len`
- Small scope of changes; only requires modifying configuration logic
- Clear logic, similar to the sequence parallelism filtering logic
- Prevents `num_tokens % max_query_len != 0` situations in the capture phase
- If there is no speculative config, `uniform_decode_query_len=1`, every size is divisible by 1, and no filtering is needed
Cons:
- Reduces selectable range of CUDA graph buckets; may impact performance.
- Needs special handling when the user explicitly specifies the capture size list; a warning should be issued in that case.
Example:
- If `uniform_decode_query_len=3`, the default capture sizes `[1, 2, 4, 8, 16, 24, ...]` would be filtered to `[24, 48, 72, ...]` (see the quick check below)
- This ensures that when capturing `num_tokens=24` or `num_tokens=48`, all requests have a consistent query_len
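A quick check of that filtering behavior (the capture-size list below is a truncated stand-in for vLLM's default list):

```python
uniform_decode_query_len = 3  # MTP=2
capture_sizes = [1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64, 72]  # truncated example
filtered = [s for s in capture_sizes if s % uniform_decode_query_len == 0]
print(filtered)  # [24, 48, 72]
```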
Other options:
- Alternatively, modify the default CUDA graph size logic: when `num_speculative_tokens=2` (so `q_len=3`), increment the capture sizes in multiples of 6 or 9 (sketched below).
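A rough sketch of that alternative, using a hypothetical helper that builds uniform-decode capture sizes directly; the step of 2 requests and the 512-token cap are illustrative assumptions, not vLLM defaults:

```python
def uniform_decode_capture_sizes(q_len: int, req_step: int = 2,
                                 max_tokens: int = 512) -> list[int]:
    # Step in whole requests so every size is a multiple of q_len
    # (e.g., q_len=3, req_step=2 -> 6, 12, 18, 24, ...).
    step = q_len * req_step
    return list(range(step, max_tokens + 1, step))

print(uniform_decode_capture_sizes(3)[:5])  # [6, 12, 18, 24, 30]
```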
Solution 2: Special Handling in build_for_cudagraph_capture
Force all requests to be treated as decode in build_for_cudagraph_capture for uniform decode mode, bypassing the require_uniform check.
Pros: Small scope of changes, only affects capture phase
Cons: Requires modifying metadata construction logic, may introduce edge cases
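A hedged sketch of what Solution 2 might look like (hypothetical helper, not vLLM's actual `build_for_cudagraph_capture` signature): during uniform decode capture, classify every request as decode so the split never produces a prefill branch.

```python
def split_all_as_decode(query_lens: list[int]):
    # Hypothetical capture-time stand-in for split_decodes_and_prefills:
    # treat every request as decode, regardless of the padded tail request.
    num_decodes = len(query_lens)
    num_decode_tokens = sum(query_lens)
    # Returns (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
    return num_decodes, 0, num_decode_tokens, 0

print(split_all_as_decode([3, 3, 2]))  # (3, 0, 8, 0) -> no prefill branch captured
```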
Impact
- Severity: High (causes runtime crashes)
- Scope: Affects all MLA backends with `QueryLenSupport.UNIFORM` using CUDA graphs with MTP/speculative decoding
- Workaround: Disable CUDA graphs or ensure capture sizes are divisible by `uniform_decode_query_len`
Discussion Welcome
I understand there may be multiple solutions to this problem, each with its own trade-offs. I prefer Solution 1 (filtering non-divisible capture sizes) because it solves the problem at the configuration stage, has clear logic, and requires minimal changes.
Feel free to propose other approaches or improvements. We can discuss the best implementation strategy together. I'm happy to help implement the solution once we agree on the approach.
Additional Context
This issue is particularly relevant for:
- Models using Multi-Token Prediction (MTP) or speculative decoding where `uniform_decode_query_len > 1`
- CUDA graph capture sizes that don't align with `uniform_decode_query_len`
- Performance-critical scenarios where CUDA graphs are essential