[Model][Speculative Decoding] Expand DeepSeek MTP code to support k > n_predict #13626
@@ -184,8 +184,7 @@ def create_worker(
     elif draft_model_config.hf_config.model_type == "medusa":
         proposer_worker = MedusaWorker(**draft_worker_kwargs)
     else:
-        if draft_tp == 1 or draft_model_config.hf_config.model_type ==\
-                "deepseek_mtp":
+        if draft_tp == 1:

Collaborator
What's the concern with using TP1DraftModelRunner for deepseek_mtp where tp > 1 here?
Collaborator (Author)
@luccafong Using the plain ModelRunner for this branch makes the most sense to me, as we are using it as intended for single-step execution.
Collaborator
I was planning to make TP1DraftModelRunner compatible with both cases, so the name is probably confusing. Actually, changing the model runner would make it harder to support real multi-step k > 1 MTP. Good for this PR's use case.
Collaborator (Author)
That sounds challenging. Best of luck!
Contributor
@luccafong QQ: does TP1DraftModelRunner work when the draft model has TP=8? If not, then this PR also has the benefit of greatly relieving memory pressure, I believe, since otherwise the full MTP module (~16GB of weights) would all reside on the first device, which limits the max context length that can be supported.
Collaborator
The previous implementation supports TP=8 in the draft model runner for MTP in the case of k=1, and we attached the benchmark for TP=8 there: #12755.
Contributor
@luccafong Ah, this makes total sense now. Thanks for the clarification!
Collaborator (Author)
@pyc96 Yes, there is some special handling there and in a few other places. Right now, multi-step is hard-coded in several ways to be specific to the FlashAttention backend. The advance_step operation handles some attention-specific state processing, so its behaviour is specific to the attention backend being used. I expect that much of this logic could be reused for MLA, but that will have to be its own contribution. In the past I have tried to forcibly enable this code path for MLA, and there were several mismatches in state that would need to be handled specifically. I don't think adding this compatibility is a trivial change. If you disable MLA, then the multi-step code is usable.
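As a rough illustration of why this is backend-specific, here is a toy sketch of the kind of in-place metadata update that an advance_step-style operation performs between draft steps. All names here are hypothetical; this is not vLLM's actual API.

```python
# Illustrative toy only: hypothetical names, not vLLM's advance_step.
from dataclasses import dataclass
from typing import List

@dataclass
class ToyAttnMetadata:
    seq_lens: List[int]      # per-sequence lengths as seen by the attention kernel
    slot_mapping: List[int]  # KV-cache slot where each sequence's newest token goes

def toy_advance_step(meta: ToyAttnMetadata) -> None:
    """Advance attention metadata in place after one draft step.

    A real backend also updates kernel-specific buffers (block tables,
    on-device context lengths, etc.); that kernel-specific part is what
    would need separate handling for MLA.
    """
    for i in range(len(meta.seq_lens)):
        meta.seq_lens[i] += 1       # each sequence grew by one draft token
        meta.slot_mapping[i] += 1   # simplification: assumes the next slot is contiguous
```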
Contributor
@benchislett Thank you for your reply! On top of your PR, I changed
and it just worked fine for the drafter TP=1 case. Actually, without this change, for the TP=1 case, I believe the current code will throw this assertion error during request serving. Is that intended?
Collaborator (Author)
@pyc96 I can confirm that it does seem to work with multi-step now. Even so, for this PR I would prefer to merge sooner and have follow-up work (cc @luccafong) add the multi-step functionality for DeepSeek with some thorough testing and sign-off. For now, I've pushed a commit which removes the MLA backend from
            if current_platform.is_cuda_alike():
                draft_worker_kwargs[
                    "model_runner_cls"] = TP1DraftModelRunner

@@ -203,7 +202,8 @@ def create_worker(
         proposer_worker = MultiStepWorker(**draft_worker_kwargs)
         if draft_model_config.hf_config.model_type == "deepseek_mtp":
-            num_spec_prefill_steps = num_speculative_tokens
+            num_spec_prefill_steps = \
+                draft_model_config.hf_config.n_predict

     proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker(
         proposer_worker, draft_tp, target_tp)
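This hunk ties the number of speculative prefill steps to the checkpoint's n_predict (the number of MTP layers the draft model actually has) rather than to the requested number of speculative tokens k. The sketch below is a hypothetical illustration of the wrap-around idea that makes k > n_predict workable; it is a toy module, not the actual DeepSeek MTP implementation.

```python
# Hypothetical toy model, not DeepSeek's MTP module: shows only the
# index-wrapping idea for k > n_predict.
import torch
import torch.nn as nn

class ToyMTPStack(nn.Module):
    def __init__(self, n_predict: int, hidden_size: int):
        super().__init__()
        self.n_predict = n_predict
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(n_predict))

    def forward(self, x: torch.Tensor, spec_step_idx: int) -> torch.Tensor:
        # With k speculative steps but only n_predict MTP layers, step
        # indices >= n_predict wrap around onto existing layers.
        return self.layers[spec_step_idx % self.n_predict](x)

# Usage: with n_predict=1 and k=3, steps 0, 1, and 2 all reuse layer 0.
stack = ToyMTPStack(n_predict=1, hidden_size=16)
h = torch.randn(2, 16)
for step in range(3):
    h = stack(h, spec_step_idx=step)
```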
|
|