Conversation

@wwl2755
Contributor

@wwl2755 wwl2755 commented May 1, 2025

As mentioned in #15901, we currently support only top-1 selection from the EAGLE draft model's candidates (we call this chain-draft). Both EAGLE and EAGLE-2 claim that selecting the top-k tokens from each forward pass can improve the acceptance rate, so we want to support that as well (we call this tree-draft).

As this would be a big change, I would like to work on it as a WIP PR and would appreciate any comments, suggestions, or discussion during implementation.

Design Doc: https://docs.google.com/document/d/1mMoSicPPMMzaE_T5Zk2SnTderw1OXRUs2T16JxfVGCQ/edit?usp=sharing

cc: @LiuXiaoxuanPKU @WoosukKwon

  • Construct tree structure
  • Select top-k instead of top-1 from the logits from target model
  • Selection in Level-0 and level-1
  • Selection in Level-2 & 2+
  • Node/path selection after expansion
  • Attention metadata & attention mask, customized ROPE embedding
  • Currently, nodes without an ancestor relation can see each other (which is incorrect)
  • Rejection logic
  • KV cache & CUDA graph (future PRs)
  • E2e and unit tests
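
The attention-mask item in the checklist above (nodes without an ancestor relation should not see each other) can be sketched in a few lines. This is only an illustration, not the PR's code: the function name `build_ancestor_mask` and the parent-pointer layout are assumptions made for the example.

```python
def build_ancestor_mask(parents):
    """Build a tree attention mask from parent pointers.

    parents[i] is the index of node i's parent, or -1 for the root
    (hypothetical layout, chosen for this sketch).
    mask[i][j] is True iff node j is node i itself or one of its ancestors,
    so siblings and cousins stay invisible to each other.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk up to the root, marking every node on the path
            mask[i][j] = True
            j = parents[j]
    return mask


# Hypothetical 5-node draft tree: 0 is the root, 1 and 2 are its children,
# 3 is a child of 1, and 4 is a child of 2.
parents = [-1, 0, 0, 1, 2]
mask = build_ancestor_mask(parents)
```

Here node 3 attends to nodes 0, 1, and 3 only; node 2 sits on a different branch and is masked out.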

wwl2755 added 3 commits May 1, 2025 19:53
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
@github-actions

github-actions bot commented May 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label May 1, 2025
@wwl2755 wwl2755 marked this pull request as draft May 1, 2025 20:49
PADDING_SLOT_ID = -1


class TreeArray:
Collaborator

Expanding the tree in real time is going to be very costly. A dynamic tree may be less effective in actual production.

Contributor Author

Good point! That's why I pre-allocate "max_nodes" in advance. Compared with chain drafting, the tree is larger and more tokens are passed to each forward pass. The benefit can be a longer acceptance length, which could reduce the number of forward passes in the target model.
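
A minimal sketch of the pre-allocation idea, assuming fixed-size buffers that are filled in place so nothing grows during drafting. The class name `PreallocatedTree`, its fields, and the `PADDING_TOKEN_ID` sentinel are hypothetical and do not reproduce the `TreeArray` in this PR.

```python
PADDING_TOKEN_ID = -1  # hypothetical sentinel for unused slots


class PreallocatedTree:
    """Fixed-capacity draft tree stored in flat, pre-sized buffers."""

    def __init__(self, max_nodes):
        self.max_nodes = max_nodes
        # Allocate everything once; later inserts only overwrite slots.
        self.token_ids = [PADDING_TOKEN_ID] * max_nodes
        self.parents = [-1] * max_nodes
        self.num_nodes = 0

    def add_node(self, token_id, parent):
        """Store a node and return its slot index. parent == -1 marks a root."""
        if self.num_nodes >= self.max_nodes:
            raise RuntimeError("tree is full")
        if parent != -1 and not 0 <= parent < self.num_nodes:
            raise ValueError("parent must be an existing node")
        idx = self.num_nodes
        self.token_ids[idx] = token_id
        self.parents[idx] = parent
        self.num_nodes += 1
        return idx

    def path_tokens(self, idx):
        """Return the token ids on the path from the root down to node idx."""
        path = []
        while idx != -1:
            path.append(self.token_ids[idx])
            idx = self.parents[idx]
        return path[::-1]


tree = PreallocatedTree(max_nodes=8)
root = tree.add_node(101, parent=-1)
left = tree.add_node(202, parent=root)
right = tree.add_node(303, parent=root)
leaf = tree.add_node(404, parent=left)
```

The buffers keep their length of `max_nodes` no matter how many nodes are added, which is the property that avoids reallocation while the tree expands.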

Collaborator

I understand that. We also need logits-shifting logic for sampling, and a dynamic tree might be less efficient at drafting in actual use. Maybe we could start with support for a static tree as well.

Contributor Author

Yes, the sampling logic should also be changed. It is not included yet. You can see a progress tracker in: https://docs.google.com/document/d/1mMoSicPPMMzaE_T5Zk2SnTderw1OXRUs2T16JxfVGCQ/edit?usp=sharing

And IMO, going from a static tree to a dynamic tree won't introduce much difference (select all vs. top-k to expand, plus rerank logic). The major differences come from the tree structure compared with the chain draft. But I'm open to the community's opinions on which we should target first.
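
To make the "select all vs. top-k to expand, plus rerank" distinction concrete, here is a simplified sketch, not the PR's implementation. `expand_level` and `toy_logprobs` are hypothetical names, and the toy distribution stands in for the draft model. In static mode every frontier node keeps its top-k children; in dynamic mode the children are additionally reranked by cumulative log-probability and only the best k survive.

```python
import heapq
import math


def expand_level(frontier, next_logprobs, k, dynamic):
    """Expand one tree level.

    frontier: list of (path, cumulative_logprob) pairs.
    next_logprobs(path): dict mapping candidate token -> log-probability.
    """
    children = []
    for path, score in frontier:
        dist = next_logprobs(path)
        # Top-k candidates for this node (chain-draft would use k == 1).
        for tok, lp in heapq.nlargest(k, dist.items(), key=lambda kv: kv[1]):
            children.append((path + (tok,), score + lp))
    if dynamic:
        # Rerank across the whole level and keep only the k best paths.
        children = heapq.nlargest(k, children, key=lambda c: c[1])
    return children


def toy_logprobs(path):
    # Hypothetical draft distribution that depends only on the last token.
    if path and path[-1] == "a":
        probs = {"a": 0.7, "b": 0.2, "c": 0.1}
    else:
        probs = {"a": 0.5, "b": 0.4, "c": 0.1}
    return {tok: math.log(p) for tok, p in probs.items()}


root = [((), 0.0)]
level1 = expand_level(root, toy_logprobs, k=2, dynamic=False)
static_level2 = expand_level(level1, toy_logprobs, k=2, dynamic=False)
dynamic_level2 = expand_level(level1, toy_logprobs, k=2, dynamic=True)
```

With k = 2, the static tree grows to four nodes at level 2, while the dynamic variant keeps only the two highest-scoring paths. The two modes share the same expansion loop and differ only in the rerank step, which is the point made in the comment above.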

with set_forward_context(attn_metadata,
                         self.vllm_config,
                         num_tokens=input_batch_size):
    last_hidden_states, output_hidden_states = self.model(
Collaborator

Curious: does the RoPE kernel already take into consideration that positions can be customized?

Contributor Author

Thanks for bringing this up! IIUC, in rotary_embedding.py, we could pass offsets to the forward function.

We have to customize the logic since different paths are mixed together, and I would categorize it under "Attention metadata & attention mask" in the tracker. For now, it is only a placeholder.
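
One way to picture why positions need custom handling: when several paths are flattened into one batch, a node's position should follow its depth in the tree, not its index in the flattened buffer. The sketch below is an assumption-laden illustration. `tree_position_ids` and the parent-pointer layout are hypothetical, and it does not call into rotary_embedding.py.

```python
def tree_position_ids(parents, prefix_len):
    """Compute a position id for each flattened tree node.

    parents[i] is the parent index of node i, or -1 for the root.
    Assumes parents are listed before their children (hypothetical layout).
    The root gets position prefix_len, i.e. the slot right after the prefix.
    """
    depths = []
    for p in parents:
        depths.append(0 if p == -1 else depths[p] + 1)
    return [prefix_len + d for d in depths]


# Same hypothetical tree shape: 0 is the root, 1 and 2 are its children,
# 3 is a child of 1, and 4 is a child of 2.
positions = tree_position_ids([-1, 0, 0, 1, 2], prefix_len=10)
```

Siblings share a position (nodes 1 and 2 both get 11), so the position ids are no longer a simple contiguous range over the flattened tokens.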

@mergify

mergify bot commented May 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wwl2755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 23, 2025
wwl2755 added 2 commits May 27, 2025 06:14
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
@mergify mergify bot removed the needs-rebase label May 27, 2025
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
@mergify

mergify bot commented Jul 31, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wwl2755.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 31, 2025
with set_forward_context(tree_per_layer_attn_metadata,
                         self.vllm_config,
                         num_tokens=input_batch_size):
    last_hidden_states, output_hidden_states = self.model(

I tried to use the tree_draft_propsoe you provided, but I got an error like this:
image

@wwl2755
Contributor Author

wwl2755 commented Sep 1, 2025

Closing because tree attention is now supported in #20401.

@wwl2755 wwl2755 closed this Sep 1, 2025
4 participants