@LucasWilkinson (Collaborator) commented Nov 7, 2025

Temporary fix for #28207

For now, just make sure that when spec-decode is enabled, the cudagraph shapes are evenly divisible by 1 + num_speculative_tokens; see #28207 for more details.
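
As a rough illustration of the rounding involved (not the exact code in this PR; the round_up helper and the example capture sizes below are assumptions), rounding each capture size up to a multiple of 1 + num_speculative_tokens looks like this:

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

num_speculative_tokens = 2
multiple_of = 1 + num_speculative_tokens       # 3 tokens per request per decode step
capture_sizes = [1, 2, 4, 8, 16, 32]           # assumed example capture sizes
adjusted = sorted({round_up(s, multiple_of) for s in capture_sizes})
print(adjusted)                                # [3, 6, 9, 18, 33]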

Test 1

Tested with

MODEL=deepseek-ai/DeepSeek-R1
VLLM_ATTENTION_BACKEND="FLASHINFER_MLA" \
VLLM_USE_V1=1 \
vllm serve $MODEL \
--tensor-parallel-size 8 \
--disable-log-requests \
--no-enable-prefix-caching \
--compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
--trust-remote-code \
--block-size 64 \
--kv-cache-dtype fp8 \
--speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 2}'

MODEL=deepseek-ai/DeepSeek-R1
lm_eval \
    --model local-completions \
    --tasks gsm8k \
    --model_args model=${MODEL},base_url=http://127.0.0.1:8000/v1/completions \
    --batch_size 100

Hits an IMA (illegal memory access) on main but not on this branch.

Test 2

vllm serve deepseek-ai/DeepSeek-R1 \
  --tensor-parallel-size 8 \
  --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE", "pass_config": {"enable_sequence_parallelism": true}}' \
  --speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 2}'

Correctly fails with

ValueError: Can't determine cudagraph shapes that are both a multiple of 3 
(num_speculative_tokens + 1) required by spec-decode and 8 (tensor_parallel_size) 
required by sequence parallelism please adjust num_speculative_tokens or disable 
sequence parallelism
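
For context on the arithmetic behind this error: a capture size that satisfied both requirements at once would have to be a common multiple of 3 and 8, i.e. a multiple of 24, which is presumably why the PR raises here rather than trying to reconcile the two. A quick check of that arithmetic (illustrative only, not code from this PR):

import math

spec_multiple = 2 + 1   # num_speculative_tokens + 1
sp_multiple = 8         # tensor_parallel_size required by sequence parallelism
print(math.lcm(spec_multiple, sp_multiple))   # 24: smallest size satisfying both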

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a temporary fix for an issue with multi-token prediction and full CUDA graphs by adjusting CUDA graph capture sizes. The core logic change is in a new method, adjust_cudagraph_sizes_to_be_multipe_of, which unfortunately contains a critical bug that can lead to runtime errors and incorrect behavior. I've provided a detailed review comment with a suggested fix for this issue.

Comment on lines 212 to 222
def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
    new_sizes = sorted(
        [
            round_up(size, multiple_of)
            for size in self.compilation_config.cudagraph_capture_sizes
        ]
    )
    if new_sizes[-1] > self.compilation_config.max_cudagraph_capture_size:
        new_sizes = new_sizes[:-1]
    self.compilation_config.max_cudagraph_capture_size = new_sizes[-1]
    self.compilation_config.cudagraph_capture_sizes = new_sizes

critical

The current implementation of adjust_cudagraph_sizes_to_be_multipe_of has several critical issues that can lead to incorrect behavior or runtime errors:

  1. Potential IndexError: If all cudagraph_capture_sizes, when rounded up, exceed max_cudagraph_capture_size, the new_sizes list can become empty after the if condition, leading to an IndexError on new_sizes[-1]. For example, if cudagraph_capture_sizes is [16], max_cudagraph_capture_size is 16, and multiple_of is 20, new_sizes becomes [20]. The if condition is met, and new_sizes is modified to [], causing a crash on the next line.

  2. Incorrect Filtering: The logic if new_sizes[-1] > ...: new_sizes = new_sizes[:-1] only checks and removes the largest element. If multiple rounded-up sizes exceed max_cudagraph_capture_size, the smaller ones will incorrectly remain in the list.

  3. Incorrect max_cudagraph_capture_size update: The max_cudagraph_capture_size can be updated to a value larger than its original value, which seems to contradict its purpose as a hard limit derived from scheduler and token configurations.

I suggest a more robust implementation that correctly filters the sizes and handles edge cases gracefully.

Additionally, there is a typo in the method name (multipe_of should be multiple_of). I've kept it in the suggestion to match the current code, but it should be corrected here and at the call site.

    def adjust_cudagraph_sizes_to_be_multipe_of(self, multiple_of: int):
        max_size = self.compilation_config.max_cudagraph_capture_size
        # Use a set to handle duplicates from rounding up
        rounded_sizes = {
            round_up(size, multiple_of)
            for size in self.compilation_config.cudagraph_capture_sizes
        }
        new_sizes = sorted([s for s in rounded_sizes if s <= max_size])

        if not new_sizes:
            # All rounded-up sizes exceeded the max size.
            # Disable cudagraphs by setting sizes to empty.
            self.compilation_config.max_cudagraph_capture_size = 0
            self.compilation_config.cudagraph_capture_sizes = []
            return

        self.compilation_config.max_cudagraph_capture_size = new_sizes[-1]
        self.compilation_config.cudagraph_capture_sizes = new_sizes
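
As a quick sanity check of the suggestion against the IndexError scenario from point 1 (capture sizes [16], max size 16, multiple_of 20), the filtered list comes out empty and cudagraphs are disabled instead of crashing; a small walk-through, reusing an assumed round_up helper:

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

capture_sizes, max_size, multiple_of = [16], 16, 20            # edge case from point 1
rounded = {round_up(s, multiple_of) for s in capture_sizes}    # {20}
new_sizes = sorted(s for s in rounded if s <= max_size)        # []
# The suggested code then takes the `if not new_sizes:` branch:
# max_cudagraph_capture_size -> 0, cudagraph_capture_sizes -> []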

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request refactors the computation of bs_to_padded_graph_size and introduces logic to adjust CUDA graph capture sizes. While the intent is to fix an issue with speculative decoding, the changes introduce two critical bugs. First, the refactoring of bs_to_padded_graph_size computation breaks the model initialization order, as it's now computed after profile_run which depends on it. Second, the new method to adjust capture sizes contains a typo and is vulnerable to an IndexError if it results in an empty list of sizes. I have provided detailed comments and suggestions to fix these critical issues.
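
For readers unfamiliar with that table: bs_to_padded_graph_size maps a batch size to the captured graph size it is padded up to, roughly along the lines of the sketch below (assumed semantics for illustration, not the exact vLLM implementation):

def build_bs_to_padded_graph_size(capture_sizes: list[int]) -> list[int]:
    # For every batch size up to the largest capture size, record the smallest
    # captured size that can hold it.
    sizes = sorted(capture_sizes)
    return [next(s for s in sizes if s >= bs) for bs in range(sizes[-1] + 1)]

# With capture sizes already rounded to multiples of 3 (num_speculative_tokens = 2):
print(build_bs_to_padded_graph_size([3, 6, 9]))   # [3, 3, 3, 3, 6, 6, 6, 9, 9, 9]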

@gemini-code-assist bot (Contributor):

Warning: Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@LucasWilkinson LucasWilkinson changed the title from "[WIP] Tmp fix for IMA with MTP = 2 and full-cg" to "[BugFix] Temporary fix for IMA with MTP = 2 and full-cg" on Nov 12, 2025
@LucasWilkinson LucasWilkinson marked this pull request as ready for review November 12, 2025 04:41
Temp fix for vllm-project#28207

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/tmp-full-cg-mtp-2-fix branch from 5c36137 to 529078e on November 12, 2025 at 04:41

@LucasWilkinson LucasWilkinson added the "ready" (ONLY add when PR is ready to merge/full CI is needed) label Nov 13, 2025
LucasWilkinson and others added 2 commits November 14, 2025 00:23
LucasWilkinson and others added 4 commits November 14, 2025 16:45
@LucasWilkinson LucasWilkinson added this to the v0.11.1 milestone Nov 15, 2025
@ProExpertProg (Collaborator) left a comment

This might not be compatible with sequence parallelism, but that's for high-throughput cases anyway; it just might be worth adding a warning.

@LucasWilkinson (Collaborator, Author) commented:

This might not be compatible with sequence parallelism, but that's for high-throughput cases anyway; it just might be worth adding a warning.

Done 👍

@mgoin mgoin merged commit 64e39d6 into vllm-project:main Nov 17, 2025
49 checks passed