
Conversation

Contributor

@seven-mile seven-mile commented Oct 4, 2025

Purpose

I encountered an issue similar to #26198, but within EAGLE3, as they share the same code path.

The root cause is an "optimization" in the calculation of valid_mask for padded speculation, introduced by #24539:

# Generate a mask for all valid tokens within those requests
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
    valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
                                 dtype=torch.bool)
else:
    valid_mask = (
        (valid_sampled_token_ids_gpu != -1) &
        (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))

max_gen_len == 1 usually means no speculation happened during the decode phase, so a fast path that returns an all-ones valid_mask seems reasonable. Note, however, that any prefill-only input batch also satisfies this condition.

When a prefill request is chunked because it exceeds the token budget (8192), it is added to discard_request_indices, since it must not be sampled before all of its prompt tokens have been consumed. The fast path nevertheless produces an all-ones valid_mask, which is incorrect: sentinel values leak through and lead to out-of-bounds accesses downstream.
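For reference, a minimal sketch of the resulting behavior (using the variable names from the snippet above; the actual diff may differ in detail) is to drop the fast path and always compute the mask, so the -1 sentinels written for discarded requests are never treated as valid:

# Sketch: always take the general masking path so that sentinel (-1)
# entries from discarded chunked-prefill requests stay masked out.
valid_mask = (
    (valid_sampled_token_ids_gpu != -1) &
    (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))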

Test Plan

This change removes the problematic fast path and reverts to the general masking behavior. I believe no complex tests are necessary.

I tested the fix on 2xH100 using the following commands:

vllm serve \
    Qwen/Qwen3-30B-A3B \
    --host 0.0.0.0 \
    --port 7000 \
    --seed 42 \
    -dp 2 \
    --enable-expert-parallel \
    --enforce-eager \
    --max-model-len 4096 \
    --gpu_memory_utilization 0.8 \
    --speculative-config '{"model":"Tengyunw/qwen3_30b_moe_eagle3","num_speculative_tokens":4}'
vllm bench serve \
  --backend vllm --model Qwen/Qwen3-30B-A3B \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --num-prompts 200 \
  --host localhost --port 7000

Test Result

The benchmark completes without any crash.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

… chunked prefill occurs

Signed-off-by: seven-mile <i@7li.moe>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request addresses a critical bug in speculative decoding where a faulty optimization for padded speculation during chunked prefill could lead to incorrect valid_mask generation. The author correctly identifies that the max_gen_len == 1 condition is not exclusive to the decode phase and can also occur during chunked prefills, causing sentinel values to be treated as valid. The proposed fix, which removes the problematic fast path and reverts to the more general and robust masking logic, is correct and effectively resolves the issue. This change prevents potential out-of-bounds errors and ensures the stability of speculative decoding under these conditions. The reasoning is sound, and the fix is well-justified.

@seven-mile seven-mile changed the title [Bugfix][SpecDecode] Fix wrong valid_mask for padded speculation when chunked prefill occurs [Bugfix][Spec Decode] Fix wrong valid_mask for padded speculation when chunked prefill occurs Oct 5, 2025
Collaborator

@benchislett benchislett left a comment

LGTM, thanks for the fix!

An alternative way to fix this might be to change the index_fill_ to apply to the mask in the max_gen_len == 1 case. The fix implemented here is also fine and likely just as performant.
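For illustration, a hedged sketch of that alternative (assuming discard_request_indices is an index tensor of the discarded request rows, as described in the PR body; names are taken from there, not from the actual patch):

# Sketch of the alternative: keep the all-ones fast path, then explicitly
# clear the rows of requests that were discarded (e.g. chunked prefills).
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
valid_mask.index_fill_(0, discard_request_indices, False)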

@benchislett
Collaborator

Please fix the conflicts with main

@mergify

mergify bot commented Oct 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @seven-mile.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@benchislett benchislett added the bug Something isn't working label Oct 6, 2025
@mergify mergify bot added the needs-rebase label Oct 6, 2025
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
@mergify mergify bot removed the needs-rebase label Oct 6, 2025
@benchislett benchislett added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 6, 2025
@benchislett benchislett enabled auto-merge (squash) October 6, 2025 16:07
@benchislett benchislett merged commit b2ea5ba into vllm-project:main Oct 6, 2025
48 checks passed
southfreebird pushed a commit to southfreebird/vllm that referenced this pull request Oct 7, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025

Labels

bug (Something isn't working) · ready (ONLY add when PR is ready to merge/full CI is needed) · speculative-decoding · v1
