Skip to content

Conversation

@AlonKejzman
Copy link
Contributor

@AlonKejzman AlonKejzman commented Sep 11, 2025

…g draft model length

Purpose

Enable running Eagle Speculative Decoding in environments where the input may exceed the drafter model length but not the verifier's.

Test Plan

Running with inputs that are below the drafter model length, and then between the drafter and verifier.

Test Result

Successful.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a bugfix to prevent crashes in Eagle Speculative Decoding when the input sequence length exceeds the draft model's capacity. The fix correctly adds a check to early-exit from the propose method.

My review identifies a critical issue with the implementation of this early exit. The returned empty tensor has a hardcoded shape and incorrect data type, which can lead to crashes or incorrect behavior in batched scenarios. I've provided a suggestion to fix this by dynamically creating the tensor with the correct shape and dtype. I've also recommended removing the newly introduced constant, as it becomes obsolete with the suggested fix.

@AlonKejzman AlonKejzman force-pushed the main branch 3 times, most recently from 332386b to 58adccd Compare September 11, 2025 14:23
@tomasruizt
Copy link
Contributor

Could you provide commands to reproduce the issue?

@AlonKejzman
Copy link
Contributor Author

Sure!

Serving the model

vllm serve meta-llama/Meta-Llama-3-8B-Instruct --speculative-config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 4, "max_model_len": 2048}'

Request that makes it crash

import requests

url = "http://localhost:8000/v1/chat/completions"

headers = {
    "Authorization": "Bearer EMPTY", 
    "Content-Type": "application/json",
}

payload = {
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!" * 2049}
    ],
    "temperature": 0.0,
    "max_tokens": 1,
}

requests.post(url, headers=headers, json=payload, timeout=60)

@tomasruizt
Copy link
Contributor

Is it intended behavior for the draft model to have a shorter model length compared to the target model?

If I understand the use case correctly, it's to use spec decoding only on short sequences and not on longer sequences.

If that is intended, then perhaps we should not be even calling the drafting method in the GPUModelRunner at all, but rather skipping it altogether.

@benchislett
Copy link
Collaborator

See also #22935 which is a more restrictive approach that doesn't let the serving engine launch with max_model_len larger than that of the drafter. I think the proper strategy is to simply skip drafting entirely from gpu_model_runner.propose_draft_token_ids() and not in Eagle.propose() since logic like prepare_inputs and other setup is would still run.

@@ -159,6 +159,14 @@ def propose(
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
# do not attempt to forward if the input size is too big
if common_attn_metadata.seq_lens.max(
) + self.num_speculative_tokens > self.draft_model_config.max_model_len:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feels more like a general spec decoding issue instead of eagle specific issue, could we add it to gpu_model_runner?

self._draft_token_ids = self.propose_draft_token_ids(

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@mergify
Copy link

mergify bot commented Sep 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @AlonKejzman.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@AlonKejzman
Copy link
Contributor Author

@tomasruizt @luccafong You are right, I adjusted the fix accordingly
@benchislett WDYT? Maybe this is more flexible than #22935 since it allows for smaller drafter models while bypassing the drafter when needed?

@AlonKejzman AlonKejzman force-pushed the main branch 2 times, most recently from 54e7eab to abaf634 Compare September 21, 2025 16:13
@mergify mergify bot added documentation Improvements or additions to documentation frontend multi-modality Related to multi-modality (#4194) qwen Related to Qwen models structured-output labels Sep 25, 2025
@AlonKejzman AlonKejzman force-pushed the main branch 4 times, most recently from 90e74e5 to 515a675 Compare September 25, 2025 12:45
@benchislett benchislett enabled auto-merge (squash) September 25, 2025 13:17
@benchislett benchislett merged commit e04a1b6 into vllm-project:main Sep 25, 2025
42 checks passed
sergiopaniego pushed a commit to sergiopaniego/vllm that referenced this pull request Sep 29, 2025
sergiopaniego pushed a commit to sergiopaniego/vllm that referenced this pull request Sep 29, 2025
vllm-project#24662)

Signed-off-by: AlonKejzman <alonkeizman@gmail.com>
Signed-off-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
#24662)

Signed-off-by: AlonKejzman <alonkeizman@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
vllm-project#24662)

Signed-off-by: AlonKejzman <alonkeizman@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
vllm-project#24662)

Signed-off-by: AlonKejzman <alonkeizman@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation Improvements or additions to documentation frontend multi-modality Related to multi-modality (#4194) qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding structured-output v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants