[V1][Spec Decode] Enable spec decode for top-p & top-k sampling #15063
Conversation
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Looks good. Wondering how we tested it?
Good point. I just wanted to get some initial feedback before adding tests. Will update the PR.
@houseroad @LiuXiaoxuanPKU I've added the tests, and they're passing locally. Could you please review? I'd appreciate including this PR in v0.8.2 if possible.
LGTM, just minor QQ about the test
```python
num_tokens = batch_size * num_draft_tokens
...
# Randomly create unmasked indices.
num_top_p_tokens = int(vocab_size * top_p)
```
A bit confused by the definition of top-p sampling here. Shouldn't it restrict sampling to the smallest set of most probable tokens whose cumulative probability exceeds p, rather than sampling a fixed percentage of the vocabulary as done here?
@LiuXiaoxuanPKU Good catch. It only makes sense when the `int(vocab_size * top_p)` tokens all have equally high logits (e.g., 100) while the others have -100, so this is definitely not general enough.
I've updated it to test top-p more precisely. Could you please take another look?
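For reference, a precise top-p (nucleus) mask can be built by sorting the probabilities and taking an exclusive cumulative sum. The following is a minimal sketch of that idea, not the exact test code added in this PR:

```python
import torch

def top_p_mask(logits: torch.Tensor, p: float) -> torch.Tensor:
    # Keep the smallest set of most-probable tokens whose cumulative
    # probability exceeds p (nucleus sampling). Returns a bool mask
    # over the last (vocab) dimension.
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum = sorted_probs.cumsum(dim=-1)
    # Exclusive cumsum: a token stays if the mass of strictly more
    # probable tokens has not yet reached p. The most probable token
    # is therefore always kept.
    keep_sorted = (cum - sorted_probs) <= p
    mask = torch.zeros_like(keep_sorted)
    mask.scatter_(-1, sorted_idx, keep_sorted)
    return mask
```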
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
LGTM
This PR enables speculative decoding for requests with top-p & top-k sampling.
It is implemented using `apply_top_k_top_p` to mask the logits of the target model. While this is more expensive than FlashInfer's sorting-free sampling, I think it's a good first step.
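As a rough illustration of that approach (a sketch under assumed semantics, not vLLM's actual `apply_top_k_top_p` kernel), top-k and top-p can be applied as successive masks on the target logits before sampling:

```python
import torch

def mask_top_k_top_p(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    # logits: [num_tokens, vocab_size]. Masked-out entries become -inf,
    # so they receive zero probability after the softmax.

    # Top-k: drop everything strictly below the k-th largest logit.
    kth = logits.topk(k, dim=-1).values[..., -1:]
    logits = logits.masked_fill(logits < kth, float("-inf"))

    # Top-p: among the survivors, keep the smallest prefix of the
    # probability-sorted tokens whose cumulative mass exceeds p.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    probs = sorted_logits.softmax(dim=-1)
    cum = probs.cumsum(dim=-1)
    drop = (cum - probs) > p  # exclusive cumsum already reached p
    sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))

    # Undo the sort to restore the original token order.
    return logits.scatter(-1, sorted_idx, sorted_logits)
```

The per-token sort and cumulative sum are what make this more expensive than a sorting-free sampler, which is the trade-off the description mentions.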