[V1][TPU] Integrate the new ragged paged attention kernel with vLLM v1 on TPU #13379
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from f5f21f2 to 2316f14.
…eal attn_metadata in dummy_run and basic.py is still working fine.
…elp much about the dynamo compilation
@alexm-redhat @WoosukKwon, there are 2 issues currently. One is running
self.model = torch.compile(model,
                           backend="openxla",
                           fullgraph=True,
                           dynamic=False)
@WoosukKwon, do you remember the reason why we set dynamic=False?
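For context, a minimal, CPU-runnable sketch of the trade-off behind `dynamic=False` (this is not the actual vLLM TPU model runner; the toy model, bucket sizes, and `pad_to_bucket` helper are illustrative assumptions): with static shapes, every distinct input shape compiles a separate graph, so a runner will typically pad inputs to a small set of bucket sizes to bound recompilation.

```python
# Sketch only: illustrates static-shape compilation plus shape bucketing.
# The toy model and bucket sizes are made up for this example.
import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(128, 128)

    def forward(self, x):
        return self.proj(x)

# dynamic=False bakes concrete shapes into each compiled graph, so every new
# input shape means a new compilation (on TPU/XLA this is especially costly).
compiled = torch.compile(ToyModel(), fullgraph=True, dynamic=False)

def pad_to_bucket(x: torch.Tensor, buckets=(8, 16, 32)) -> torch.Tensor:
    """Pad the batch dimension up to the nearest bucket so that at most
    len(buckets) graphs are ever compiled."""
    target = next(b for b in buckets if b >= x.shape[0])
    return torch.nn.functional.pad(x, (0, 0, 0, target - x.shape[0]))

compiled(pad_to_bucket(torch.randn(5, 128)))  # compiles the 8-row graph
compiled(pad_to_bucket(torch.randn(7, 128)))  # reuses the 8-row graph
```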
cc @bvrockwell
hey @mgoin, it's my first PR in the vLLM repo. I see "pre-commit / pre-commit (pull_request)" in the CI is red; it seems to be complaining about formatting and mypy. For formatting, is there a linter I can use in vLLM?
Hey @vanbasten23, please install pre-commit using these directions: https://docs.vllm.ai/en/latest/contributing/overview.html#testing Then on your next commit it will apply automatically.
Force-pushed from 3690070 to 1a942d5.
Thanks. I followed it. Somehow, running
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
I am able to get good evaluations with Qwen/Qwen2-1.5B-Instruct on gsm8k with a small number of samples and with a small batch size. The attention kernel clearly has issues at the moment with compilation time/memory usage, but this interface is where we should go. So I think we should land this as-is and quickly iterate with new kernels to improve the usability and fix bugs, rather than stay in the emulated V0 style we have on main currently.
Thank you for your great work!
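For reference, a hedged sketch of the kind of small gsm8k-style smoke test described above, using vLLM's offline `LLM` API. The `VLLM_USE_V1` environment variable, the engine arguments, and the prompt are assumptions for illustration, not the exact command used for this PR.

```python
# Hedged sketch of a small-batch smoke test on TPU; flags are assumptions.
import os
os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine (assumed flag)

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",
    max_model_len=512,   # keep sequences short for a quick check
    max_num_seqs=16,     # small batch size, as in the comment above
)

prompts = [
    "Natalia sold clips to 48 of her friends in April, and then she sold "
    "half as many clips in May. How many clips did Natalia sell altogether?",
]
params = SamplingParams(temperature=0.0, max_tokens=128)
for out in llm.generate(prompts, params):
    print(out.outputs[0].text)
```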
Thanks @mgoin for reviewing the PR.
Hi @mgoin, could you help merge the PR? I don't see a merge button on my side.
Thanks for the ping. Yes, you need committer status to merge, which I'll handle. Let me quickly chat with @alexm-redhat before merging.
@vanbasten23 thanks for the hard work on this, the PR looks good to land! Just some small comments, but you can address them also after landing.
…1 on TPU (vllm-project#13379) Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com>
…1 on TPU (vllm-project#13379) Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: mgoin <mgoin64@gmail.com> Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
This PR integrates the new ragged paged attention kernel with vLLM v1 on TPU. In particular, this PR
Test plan:
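As background on what a ragged paged attention kernel computes, here is a simplified pure-PyTorch reference sketch. The argument names, layouts, and mask handling are illustrative assumptions, not the actual Pallas kernel signature: query tokens from all sequences are packed into one ragged tensor, and each sequence attends to its own pages of the paged KV cache.

```python
# Simplified reference sketch of ragged paged attention; not the Pallas TPU
# kernel. Assumes num_kv_heads == num_heads and well-formed integer indices.
import torch

def ragged_paged_attention_ref(
    q,           # [total_q_tokens, num_heads, head_dim], all sequences packed
    k_pages,     # [num_pages, page_size, num_heads, head_dim]
    v_pages,     # [num_pages, page_size, num_heads, head_dim]
    kv_lens,     # [num_seqs] effective KV length of each sequence
    page_table,  # [num_seqs, max_pages_per_seq] page indices per sequence
    cu_q_lens,   # [num_seqs + 1] cumulative query lengths (ragged boundaries)
):
    page_size = k_pages.shape[1]
    scale = q.shape[-1] ** -0.5
    outputs = []
    for s in range(kv_lens.shape[0]):
        q_s = q[int(cu_q_lens[s]):int(cu_q_lens[s + 1])]  # this sequence's queries
        n_kv, q_len = int(kv_lens[s]), q_s.shape[0]
        # Gather this sequence's KV pages and trim to its true KV length.
        pages = page_table[s][: (n_kv + page_size - 1) // page_size]
        k_s = k_pages[pages].reshape(-1, *k_pages.shape[2:])[:n_kv]
        v_s = v_pages[pages].reshape(-1, *v_pages.shape[2:])[:n_kv]
        # Scores per head: [num_heads, q_len, n_kv].
        attn = torch.einsum("qhd,khd->hqk", q_s, k_s) * scale
        # Causal mask: query i corresponds to position n_kv - q_len + i of the
        # sequence and may attend to KV positions 0 .. n_kv - q_len + i.
        kv_pos = torch.arange(n_kv)
        q_pos = n_kv - q_len + torch.arange(q_len)
        mask = kv_pos[None, :] <= q_pos[:, None]
        attn = attn.masked_fill(~mask[None], float("-inf")).softmax(dim=-1)
        outputs.append(torch.einsum("hqk,khd->qhd", attn, v_s))
    return torch.cat(outputs, dim=0)
```

Presumably the real kernel fuses the page gather, masking, and softmax instead of looping per sequence in Python; the loop form here is only to make the semantics explicit.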