[TPU] update torch_xla pin #19231
Conversation
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Hello @yaochengji, I'm Gemini Code Assist[^1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello! Gemini here, providing a summary of this pull request. The main purpose of this PR is to update the pinned version of torch_xla in the requirements file. This update is intended to integrate recent improvements, specifically an updated Mixture-of-Experts (MoE) GMM kernel from the torch_xla repository. The author has provided benchmark results demonstrating significant performance gains on the Mixtral model after applying this update. Additionally, the PR includes a minor adjustment to the Pallas GMM kernel test to verify its support for irregular dimension sizes, which is part of the kernel update's capabilities.
Highlights
- MoE GMM Kernel Performance Improvement: This update integrates an optimized MoE GMM kernel from torch_xla, which the author reports leads to substantial performance improvements for models like Mixtral, as evidenced by the included benchmark comparisons showing increased throughput and reduced latency.
- Pallas GMM Test Coverage: A test case in tests/tpu/test_moe_pallas.py is modified to include an irregular dimension size (k=511) in its parameterization, ensuring the updated Pallas GMM kernel correctly handles non-standard dimensions.
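For reviewers less familiar with the kernel being updated: a grouped matmul (GMM) multiplies contiguous row-groups of a token matrix by per-expert weight matrices, where the group sizes can be irregular. The sketch below is a plain PyTorch reference of that contract, with made-up function and argument names; it is not the Pallas implementation.

```python
import torch

def reference_gmm(lhs: torch.Tensor, rhs: torch.Tensor,
                  group_sizes: torch.Tensor) -> torch.Tensor:
    """Reference grouped matmul: contiguous row-groups of `lhs` are each
    multiplied by the matching expert weight matrix in `rhs`.

    lhs:         (total_tokens, k)    -- routed token activations
    rhs:         (num_experts, k, n)  -- per-expert weights
    group_sizes: (num_experts,)       -- rows per expert; may be irregular
    """
    out = torch.empty(lhs.shape[0], rhs.shape[-1], dtype=lhs.dtype)
    start = 0
    for expert, size in enumerate(group_sizes.tolist()):
        if size > 0:
            out[start:start + size] = lhs[start:start + size] @ rhs[expert]
        start += size
    return out

# Tiny check with an irregular inner dimension (k=511), mirroring the new
# test parameterization.
lhs = torch.randn(24, 511)
rhs = torch.randn(3, 511, 128)
group_sizes = torch.tensor([10, 0, 14])
print(reference_gmm(lhs, rhs, group_sizes).shape)  # torch.Size([24, 128])
```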
Changelog
- requirements/tpu.txt
  - Updated the pinned version of torch from 2.8.0.dev20250529 to 2.8.0.dev20250605 (line 21).
  - Updated the pinned version of torchvision from 0.22.0.dev20250529 to 0.23.0.dev20250605 (line 22).
  - Updated the pinned version of torch_xla for Python 3.9, 3.10, and 3.11 to use the 2.8.0.dev20250605 wheel (lines 23-25).
- tests/tpu/test_moe_pallas.py
  - Changed the parameter k in the @pytest.mark.parametrize decorator from 512 to 511 to test irregular dimension sizes for the Pallas GMM kernel (line 30).
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
A pin is updated,
Speed for MoE is created,
Tests pass, code is great.
Footnotes
[^1]: Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
The pull request updates the torch_xla pin and modifies the Pallas GMM kernel test. The performance gain on the Mixtral model is significant, as shown in the benchmark results. The test modification ensures support for irregular dimension sizes. Overall, this is a valuable update. Here are some suggestions for improvement.
Summary of Findings
- Test Parameterization: The change modifies the test parameter k from 512 to 511. It would be better to ensure that the test covers cases where num_tokens * topk is a multiple of 16, as required by the Pallas GMM kernel. Consider adjusting m, n, or topk instead, or adding a specific test case that satisfies this condition (see the sketch after this list).
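To make the suggested constraint concrete, below is a minimal sketch of guarding the parameter grid so that only aligned combinations run. The parameter names, values, and the existence of a separate topk parameter are assumptions for illustration only and may not match tests/tpu/test_moe_pallas.py.

```python
import pytest

# Hypothetical sketch: parameter names and grids are illustrative.
@pytest.mark.parametrize("m", [8, 16, 64, 2048])   # plays the role of num_tokens
@pytest.mark.parametrize("topk", [1, 2, 4])
def test_gmm_alignment_sketch(m: int, topk: int):
    # The review suggests the Pallas GMM kernel requires num_tokens * topk
    # to be a multiple of 16; skip combinations that do not satisfy it.
    if (m * topk) % 16 != 0:
        pytest.skip("num_tokens * topk must be a multiple of 16")
    # ... invoke the MoE Pallas kernel under test here ...
```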
Merge Readiness
The pull request introduces performance improvements and a test modification. The test modification addresses a constraint of the Pallas GMM kernel. However, it would be beneficial to ensure that the test suite explicitly covers cases where the constraint that num_tokens * topk is a multiple of 16 is satisfied. I am unable to directly approve this pull request, and recommend that others review and approve this code before merging. I would recommend addressing the medium-severity issues before merging.
```diff
 @pytest.mark.parametrize("m", [8, 16, 64, 2048])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("k", [128, 511, 1024])
```
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which covers a small and essential subset of tests to quickly catch errors. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge. 🚀
vanbasten23 left a comment
Looks good. Thanks Chengji
Could you also run a sample benchmark (e.g. meta-llama/Meta-Llama-3.1-8B-Instruct) before merging the PR?
mgoin left a comment
LGTM, great result! It would also be nice to verify whether performance is better for MoE with many small experts, like Qwen/Qwen3-30B-A3B.
Please also add
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
@vanbasten23 thanks for catching this! It is updated.
The old kernel's block-size selection logic has an issue and it cannot run that model. So I compared the performance between the case without the kernel and the case with the kernel in this PR. Below is the result. We can observe that the GMM kernel is really critical when there are many experts.

Without kernel:

With kernel:
```diff
 os.environ["LIBTPU_INIT_ARGS"] = (
-    "--xla_tpu_force_1d_allreduce_at_chunk_count=1")
+    os.environ.get("LIBTPU_INIT_ARGS", "") +
+    " --xla_tpu_force_1d_allreduce_at_chunk_count=1")
```
nit: We can add a comment saying the additional libtpu arg is needed due to pytorch/xla#9084
I think it's fine, because here we're not adding any specific libtpu arg; we just inherit all the existing args, if any.
I see, makes sense, thanks!
lsy323 left a comment
Thanks @yaochengji for updating the torch_xla pin!!
Could you check if
Warning: You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!
The command from #19231 (comment) failed with an error (https://gist.github.com/vanbasten23/6772e44bc8b562256c3b184fb403c2b5) on my v6e-8 locally. cc @yaochengji @lsy323. Not sure if you see the same.
Thanks, Xiongfei! Currently I don't have a v6e-8 VM. Do you mind sharing more of the log for the issue? The current gist doesn't have much information.
I'm trying to repro now, but the benchmarking script is running very slowly, and there is not much extra useful info in the log. Also, I couldn't repro using the script from #19231 (comment). What I observed is that running the script takes a very long time and it may fail with the error above. From the log, I remember it seems to be loading the model: 50% -> 60% -> 70%.... If you try a few times, it will load 100% and succeed eventually. I also tried cleaning up the cache. Do we have a shared v6e-8 VM? If not, feel free to ping me and use mine.
Hmm, sorry, I couldn't repro anymore.
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Purpose
To integrate the MoE GMM kernel update from the torch_xla repo. We can observe a significant performance gain on the Mixtral model. This PR also modifies the Pallas GMM kernel test slightly to prove that the kernel can support irregular dimension sizes.
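The PR does not describe how the kernel handles irregular sizes internally. Purely as an illustration of why an irregular dimension such as k=511 is awkward for tiled kernels, here is a toy PyTorch sketch that pads the contraction dimension up to a tile multiple; zero padding keeps the product exact. This is not the torch_xla Pallas implementation, and the tile size of 128 is an assumption for the example.

```python
import torch

def matmul_with_padded_k(a: torch.Tensor, b: torch.Tensor,
                         tile: int = 128) -> torch.Tensor:
    """Toy illustration: pad an irregular contraction dim (e.g. k=511) up to
    a tile multiple before multiplying. Zero padding leaves the product exact."""
    k = a.shape[-1]
    pad = (-k) % tile  # e.g. k=511 -> pad by 1 to reach 512
    if pad:
        a = torch.nn.functional.pad(a, (0, pad))        # pad last dim of a
        b = torch.nn.functional.pad(b, (0, 0, 0, pad))  # pad first dim of b
    return a @ b

a = torch.randn(64, 511)
b = torch.randn(511, 128)
torch.testing.assert_close(matmul_with_padded_k(a, b), a @ b,
                           rtol=1e-4, atol=1e-4)
```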
Before the update:
After the update:
Test Plan
Test Result
Passed.