Skip to content

Conversation

@luyuzhe111
Copy link
Contributor

@luyuzhe111 luyuzhe111 commented Mar 17, 2025

This PR extends #10132 and adds chunked prefill support for EAGLE. The complexity involves two main aspects:

  1. Deal with mixed batch scenarios: when chunked prefill is not enabled, vllm's scheduler prioritizes prefill requests and doesn’t put prefill and decode to the same batch. With chunked prefill, we will need to deal with mixed batches and in particular, preserve prefill hidden states from target model for EAGLE (in vllm/spec_decode/batch_expansion.py, vllm/spec_decode/interfaces.py), and prefill EAGLE model properly (in vllm/spec_decode/spec_decode_worker.py).

  2. Save and pass around last token's hidden states in non-terminal chunks: unlike Medusa, EAGLE utilizes hidden states of all previous tokens. This means the last token's hidden states in a non-terminal chunks has to be preserved. To this end, I register a new attribute to SamplerOutput called non_terminal_hidden_states as in vllm/model_executor/layers/sampler.py. I extract these non_terminal_hidden_states in both prefill and (mixed-batch) decoding stage as in vllm/spec_decode/spec_decode_worker.py. The function used to extract those hidden states is in vllm/spec_decode/util.py. Finally, since these non-terminal hidden states might be needed when prefilling the current chunk, I extended prepare_prefill_hidden_states to take in these hidden states when preparing the prefill hidden states.

Limitation: the current implementation only supports batch expansion scorer.

cc @LiuXiaoxuanPKU @comaniac Would appreciate your review. Thanks!

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would like to note here that the original code has a bug. speculative_model is the directory of the draft model, not the type of it.

pyc96 added a commit to pyc96/vllm that referenced this pull request Mar 21, 2025
Signed-off-by: pyc96 <pychen96@gmail.com>
@luyuzhe111
Copy link
Contributor Author

To leverage chunked prefill implemented here for DeepSeek MTP, one can simply convert MTP weights to EAGLE format. Script can be found here and some simple changes included in this PR #14990.

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
@luyuzhe111
Copy link
Contributor Author

luyuzhe111 commented Mar 25, 2025

@LiuXiaoxuanPKU @WoosukKwon also added a missed corner case of the refactoring in #14434 to decide the method of the speculative model (for EAGLE, we should also check for model_type).

@github-actions
Copy link

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Jun 24, 2025
@mergify
Copy link

mergify bot commented Jun 24, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @luyuzhe111.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants