
Conversation


@xiaohajiayou xiaohajiayou commented Oct 14, 2025

Purpose

Fix two problems observed with DeepSeek V3.2 MTP speculative decoding (see Issue #26711).

  1. Fix CUDA Graph Capture Crash in EAGLE
    Resolve IndexError: list index out of range when accessing self.cudagraph_batch_sizes[-1] during CUDA Graph capture

    • Root Cause:
      • When using PIECEWISE capture, gpu_model_runner passes use_cudagraphs=True, but MTP heads of draft models (e.g., DeepSeek MTP) lack native CUDA graph support.
      • EAGLE sets self.use_cuda_graph = False and initializes self.cudagraph_batch_sizes as an empty list for such models.
      • PR #25109 ([Spec-Decode] Support piecewise cudagraphs for Eagle head) ignored self.use_cuda_graph when handling the dummy_run arguments, leading to an out-of-range access on the empty cudagraph_batch_sizes list.
    • Change: In vllm/v1/spec_decode/eagle.py:dummy_run, gate the padding and cudagraph_runtime_mode on (use_cudagraphs and self.use_cuda_graph) to prevent indexing into the empty list and to align with the runner's behavior (see the sketch after this list).
  2. Fix DeepSeek V3.2 MTP Metadata & Sparse MLA Layer Selection
    Resolve incorrect layer classification and metadata mapping for DeepSeek V3.2 Sparse MLA.

    • Root Cause:
      • PR #25103 (Separate MLAAttention class from Attention) split the MLAAttention class out of Attention, but the Indexer layers were not excluded from the attention layer lists, causing mismatches between MTP metadata and model components.
      • Sparse MLA architectures require strict separation of attention and indexer layers to ensure proper metadata propagation for expert routing.
    • Change: In vllm/v1/spec_decode/eagle.py:load_model, filter the indexer layers out of the attention layer lists via draft_attn_layer_names - draft_indexer_layer_names, ensuring accurate metadata selection for DeepSeek MTP heads (also covered in the sketch below).
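
A minimal sketch of the two changes for orientation (identifier names follow vllm/v1/spec_decode/eagle.py, but the elided bodies are illustrative, not the exact diff):

    class EagleProposer:
        def dummy_run(self, num_tokens: int, use_cudagraphs: bool = True) -> None:
            # Gate on the drafter's own capability as well as the caller's flag,
            # so cudagraph_batch_sizes[-1] is never read from an empty list.
            if (
                use_cudagraphs
                and self.use_cuda_graph
                and num_tokens <= self.cudagraph_batch_sizes[-1]
            ):
                ...  # pad num_tokens to a captured size; run in cudagraph mode
            else:
                ...  # eager fallback for drafters without CUDA graph support

        def load_model(self, target_model):
            ...
            # DeepseekV32IndexerCache also inherits AttentionLayerBase, so the
            # indexer layers must be subtracted before the list is finalized.
            self.attn_layer_names = list(
                draft_attn_layer_names - draft_indexer_layer_names
            )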

Test Plan

For the scenarios described in Issue #26711 (DeepSeek V3.2 MTP metadata anomalies and CUDA graph capture crashes in EAGLE mode), run the same workloads with and without the fix applied and verify whether the issues still reproduce.

Test Result

Verification for Issue #26711 has been completed. After the fix, DeepSeek V3.2 MTP metadata mapping works correctly and CUDA graph capture no longer crashes in EAGLE mode. The original issues are resolved.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the deepseek (Related to DeepSeek models), speculative-decoding, and v1 labels on Oct 14, 2025
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly fixes an IndexError in dummy_run that occurs when a drafter does not support CUDA graphs. The change prevents a crash by ensuring self.use_cuda_graph is checked before accessing CUDA graph configurations. I've added a suggestion to further improve the robustness of this fix by also checking if self.cudagraph_batch_sizes is non-empty, which handles an additional edge case and improves code readability by reducing duplication.

@xiaohajiayou xiaohajiayou force-pushed the fix-eagle-dummy-run branch 3 times, most recently from 746c47f to 47a4376, on October 14, 2025 12:03
@xiaohajiayou xiaohajiayou changed the title from "Fix IndexError in dummy_run when drafter (DeepSeek v32) doesn't support CUDA graphs" to "[Bugfix] DeepSeek V3.2 MTP metadata & CUDA graph issues" on Oct 14, 2025
        ) -> None:
    -       if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]:
    +       # Determine if CUDA graphs should be used for this run.
    +       cudagraphs_enabled = (
Collaborator

I think it would be easier to just set self.use_cuda_graph = self.use_cuda_graph and bool(self.cudagraph_batch_sizes) in __init__, since both values are unchanging over the lifetime of the object.
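
A quick sketch of what I mean (eagle.py attribute names; the rest of __init__ is elided):

    class EagleProposer:
        def __init__(self, vllm_config):
            # ... self.use_cuda_graph and self.cudagraph_batch_sizes are
            # derived from vllm_config earlier in __init__ ...
            # Fold the empty-list check into the flag once at construction;
            # all later gating then only consults self.use_cuda_graph.
            self.use_cuda_graph = self.use_cuda_graph and bool(
                self.cudagraph_batch_sizes
            )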

Collaborator

What is the scenario where use_cuda_graph is True and cudagraph_batch_sizes is empty? I wonder if this might be a symptom of a deeper issue.

Author
@xiaohajiayou xiaohajiayou Oct 16, 2025

Thanks for the suggestion! I applied it and finalized the flag in __init__, and all runtime gating now only checks this flag.

Author
@xiaohajiayou xiaohajiayou Oct 16, 2025

  • Potential scenarios: after configuration initialization, cudagraph_mode is overridden to include PIECEWISE mode while the capture size list remains empty, for example when the model's enforce_eager setting blocks capture-size generation, when users explicitly configure an empty list, or when the sizes are all filtered out and no valid entries remain.

It seems this situation should not occur at present. The added check mainly guards against out-of-bounds access while making the drafter's behavior more consistent and safe. @benchislett

Collaborator

@xiaohajiayou could you check if #26821 solves the issue, or if this additional change is also necessary?

Author
@xiaohajiayou xiaohajiayou Oct 22, 2025

I'd lean toward keeping it. #26821 already handled the "drafter forces eager ⇒ empty cudagraph_capture_sizes" situation. What's left are the cases where self.use_cuda_graph still flips to True while self.cudagraph_batch_sizes ends up empty, for example:

  • someone starts vLLM with --compilation-config '{"level": 3, "cudagraph_capture_sizes": []}';
  • the default capture sizes get filtered away by max_num_batched_tokens, sequence parallelism, or similar constraints.

In both scenarios, the guard self.use_cuda_graph &= bool(self.cudagraph_batch_sizes) automatically disables the drafter's CUDA graph when the list is empty. The drafter simply falls back to eager, which costs the graph speedup but keeps the model running instead of crashing on an index error.

That trade-off is much friendlier than a hard failure.

@xiaohajiayou xiaohajiayou force-pushed the fix-eagle-dummy-run branch 2 times, most recently from f764463 to d92ba15, on October 16, 2025 03:33
Author
@xiaohajiayou xiaohajiayou commented Oct 18, 2025

The issue referenced in #26711 is now fixed, and the tests pass. Could you review whether we can merge this? @benchislett @luccafong

        )
    +   draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names
    -   self.attn_layer_names = list(draft_attn_layer_names)
    +   self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
Collaborator

Could you help me understand how attn_layer_names is used, and why draft_indexer_layer_names must be excluded?

Author
@xiaohajiayou xiaohajiayou Oct 22, 2025

Sure! Here’s how the pieces fit together:

  • When we build the drafter layers in

        draft_attn_layer_names = (
            get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase).keys()
            - target_attn_layer_names
        )
        indexer_layers = get_layers_from_vllm_config(
            self.vllm_config, DeepseekV32IndexerCache
        )

    we grab every module that inherits AttentionLayerBase. DeepSeek's Lightning Indexer (DeepseekV32IndexerCache) does that too, so its layer names get lumped into draft_attn_layer_names.

  • Later, _get_attention_metadata_builder looks at the very first name in self.attn_layer_names, finds the backend for that layer, and caches its metadata builder (which might be the indexer metadata builder) for the whole set:

        def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder:
            """Find and return the attention metadata builders for EAGLE layers.

            Returns:
                The metadata builders for EAGLE layers.

            Raises:
                AssertionError: If no metadata builders are found for EAGLE layers.
            """
            builder = None
            chosen_layer = self.attn_layer_names[0]
            ...

  • After that we loop over self.attn_layer_names and hand that builder's output to every entry:

        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

  • The snag is that Lightning Indexer expects DeepseekV32IndexerMetadata, produced by DeepseekV32IndexerMetadataBuilder:

        class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
            ...

    The standard attention backends expect completely different metadata. If an indexer sneaks into self.attn_layer_names, _get_attention_metadata_builder can lock onto the indexer backend, and the loop then feeds indexer metadata to the true attention layers while the indexer never reaches its dedicated path.

  • So we subtract draft_indexer_layer_names when we finalize self.attn_layer_names:

        self.attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)

    That guarantees the first entry really is an attention layer: _get_attention_metadata_builder picks the correct attention backend, the drafter attention layers share that metadata as intended, and the indexer layers stay in self.indexer_layer_names, where they go through the metadata builder their backend expects.

Hope that clarifies why draft_indexer_layer_names has to be excluded.
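
To make that concrete, here is a tiny toy example (hypothetical layer names, standing in for the real module paths, which look like "model.layers.61.self_attn.attn"):

    # Hypothetical names for illustration only.
    draft_attn_layer_names = {"mtp.attn", "mtp.indexer"}  # indexer leaked in
    draft_indexer_layer_names = {"mtp.indexer"}

    # The subtraction removes the indexer, so the first entry that
    # _get_attention_metadata_builder inspects is a real attention layer.
    attn_layer_names = list(draft_attn_layer_names - draft_indexer_layer_names)
    assert attn_layer_names == ["mtp.attn"]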

Collaborator
@benchislett benchislett left a comment

LGTM, thanks!

@benchislett benchislett added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Oct 23, 2025
Author
@xiaohajiayou

All CI checks have passed, and the issues in Issue #26711 are resolved. Can we merge this PR and close the corresponding issue?
@benchislett
