
Conversation

@vanbasten23
Collaborator

@vanbasten23 vanbasten23 commented Feb 17, 2025

This PR integrates the new ragged paged attention kernel with vLLM v1 on TPU. In particular, this PR:

  • Updates the torch_xla pin to the latest version.
  • Updates pallas.py in v1 to use the new ragged paged attention kernel instead of the 3 separate kernels used in v0.
  • Combines the prompt and decode steps into a single step in tpu_model_runner.py, similar to what the GPU backend does today.
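Combining prompt and decode into one step relies on the kernel accepting a ragged batch: every request's new tokens are flattened into a single token dimension and described by per-request lengths. A minimal sketch of that bookkeeping (illustrative only, with hypothetical names; not vLLM's actual code):

```python
# Illustrative sketch (hypothetical names, not vLLM's actual code):
# two prompts (5 and 3 new tokens) and two decodes (1 new token each)
# are handled in one model step as a single ragged batch.

def cu_seqlens(seq_lens):
    """Cumulative start offsets for a flattened (ragged) token batch."""
    starts = [0]
    for n in seq_lens:
        starts.append(starts[-1] + n)
    return starts

mixed_batch = [5, 3, 1, 1]          # new tokens per request this step
offsets = cu_seqlens(mixed_batch)   # [0, 5, 8, 9, 10]
total_tokens = offsets[-1]          # one kernel call covers all 10 tokens
```

A ragged paged attention kernel consumes offsets like these (plus per-request page tables) instead of requiring separate prefill and decode kernel launches.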

Test plan:

  • $ VLLM_USE_V1=1 python vllm/examples/offline_inference/basic.py 2>&1 | tee out.txt
  • $ VLLM_USE_V1=1 pytest -s -v vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine 2>&1 | tee out.txt

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify

mergify bot commented Feb 18, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vanbasten23.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 18, 2025
@vanbasten23 vanbasten23 force-pushed the xiowei/tpu_v1_kernel_integration_take2 branch from f5f21f2 to 2316f14 Compare February 19, 2025 17:55
@mergify mergify bot removed the needs-rebase label Feb 21, 2025
@vanbasten23
Collaborator Author

@alexm-redhat @WoosukKwon , there are currently two issues. One is that running vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine is very slow (~2h, probably due to excessive compilation), which I'm investigating. The other is that the same test actually fails. Any suggestions on how to find a smaller repro to debug with?

self.model = torch.compile(model,
                           backend="openxla",
                           fullgraph=True,
                           dynamic=False)
Collaborator Author


@WoosukKwon , do you remember the reason why we set dynamic=False?

@vanbasten23
Collaborator Author

cc @bvrockwell

@vanbasten23
Collaborator Author

@vllm-v-team@google.com

@vanbasten23
Collaborator Author

Hey @mgoin , it's my first PR in the vLLM repo. I see that "pre-commit / pre-commit (pull_request)" in CI is red; it seems to be complaining about formatting and mypy. For formatting, is there a linter I can use in vLLM?

@mgoin
Member

mgoin commented Feb 27, 2025

Hey @vanbasten23, please install pre-commit using these directions: https://docs.vllm.ai/en/latest/contributing/overview.html#testing

pip install -r requirements-dev.txt
pre-commit install --hook-type pre-commit --hook-type commit-msg

Then it will run on your next commit.

@vanbasten23 vanbasten23 force-pushed the xiowei/tpu_v1_kernel_integration_take2 branch from 3690070 to 1a942d5 Compare February 27, 2025 21:10
@vanbasten23
Collaborator Author

Hey @vanbasten23, please install pre-commit using these directions: https://docs.vllm.ai/en/latest/contributing/overview.html#testing

pip install -r requirements-dev.txt
pre-commit install --hook-type pre-commit --hook-type commit-msg

Then it will run on your next commit.

Thanks, I followed it. Somehow, running pre-commit run --all-files removed https://github.com/vanbasten23/vllm/blob/58d1b2aa772deb166355423997fbf5c1b6b186a1/vllm/v1/attention/backends/pallas.py#L7, which is important even though it is not directly used. I've added it back manually.
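For future reference, imports that are needed only for their side effects (e.g. registering custom ops at import time) can usually be protected from flake8-style linters with a `# noqa: F401` comment. A toy illustration of the side-effect-import pattern (hypothetical names, not the actual vLLM/torch_xla code):

```python
# Toy illustration (hypothetical names): some modules do useful work at
# import time, such as registering custom ops, so the import matters even
# though no symbol from it is referenced directly. Marking such a line with
# "# noqa: F401" tells flake8-style linters not to flag it as unused, e.g.:
#
#   import some_pkg.custom_kernel  # noqa: F401  # registers ops on import

OP_REGISTRY = {}

def register_op(name):
    """Decorator that registers a function at definition/import time."""
    def deco(fn):
        OP_REGISTRY[name] = fn
        return fn
    return deco

@register_op("ragged_paged_attention")
def _attention_stub():
    return "called"

# Merely importing a module containing the decorated function would
# populate OP_REGISTRY, even if the importer never references the module.
```

Whether a given formatter respects `# noqa` depends on the tool; some setups instead need the import listed in an ignore/keep configuration.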

vanbasten23 and others added 4 commits February 27, 2025 21:51
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Member

@mgoin mgoin left a comment


I am able to get good evaluations with Qwen/Qwen2-1.5B-Instruct on gsm8k with a small number of samples and with a small batch size. The attention kernel clearly has issues at the moment with compilation time/memory usage, but this interface is where we should go. So I think we should land this as-is and quickly iterate with new kernels to improve the usability and fix bugs, rather than stay in the emulated V0 style we have on main currently.

Thank you for your great work!

@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed tpu Related to Google TPUs labels Feb 28, 2025
@mgoin mgoin changed the title Integrate the new ragged paged attention kernel with vLLM v1 on TPU [V1][TPU] Integrate the new ragged paged attention kernel with vLLM v1 on TPU Feb 28, 2025
@vanbasten23
Collaborator Author

Thanks @mgoin for reviewing the PR.

@vanbasten23
Collaborator Author

Hi @mgoin , could you help merge the PR? I don't see a merge button on my side.

@mgoin
Member

mgoin commented Feb 28, 2025

Thanks for the ping; yes, you need committer status to merge, so I'll handle it. Let me quickly chat with @alexm-redhat before merging.

Collaborator

@alexm-redhat alexm-redhat left a comment


@vanbasten23 thanks for the hard work on this, the PR looks good to land! Just some small comments, which you can also address after landing.

@mgoin mgoin merged commit c3b6559 into vllm-project:main Feb 28, 2025
47 checks passed
Akshat-Tripathi added a commit to krai/vllm that referenced this pull request Mar 3, 2025
…1 on TPU (vllm-project#13379)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…1 on TPU (vllm-project#13379)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
…1 on TPU (vllm-project#13379)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

Labels

ci/build · documentation (Improvements or additions to documentation) · ready (ONLY add when PR is ready to merge/full CI is needed) · tpu (Related to Google TPUs) · v1


3 participants