-
Notifications
You must be signed in to change notification settings - Fork 637
[feature] support pcp + mtp in full graph #4572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces support for PCP (Prefill Context Parallelism) and MTP (Multi-Token Prediction) in full graph mode, along with several related bug fixes. The changes correctly generalize PCP-only logic to accommodate DCP (Decode Context Parallelism) as well. A notable improvement is the handling of variable query lengths in speculative decoding batches, which replaces assumptions of fixed lengths with more robust logic. However, I've identified one critical issue in the implementation that needs to be addressed.
b08104c to
b330b75
Compare
| # prefill target_hidden_states: pcp split | ||
| num_tokens_d = num_decode_reqs * self.decode_threshold | ||
| query_lens_d = self.runner.query_lens[:num_decode_reqs] | ||
| num_tokens_d = query_lens_d.sum().item() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is query_lens_d a device tensor? if so, you call query_lens_d.sum().item() will incur cpu blocking, please fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a host tensor, refer to model_runner_v1.py: self.query_lens = torch.from_numpy(num_scheduled_tokens), so it will not influence host device sync. The _d is abbreviation of _decode.
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
ea969e8 to
f359b7a
Compare
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Uh oh!
There was an error while loading. Please reload this page.