
Conversation

@yaochengji (Collaborator) commented Jun 5, 2025

Purpose

To integrate the MoE GMM kernel update from the torch_xla repo. We observe a significant performance gain on the Mixtral model. This PR also modifies the Pallas GMM kernel test slightly to show that the kernel supports irregular dimension sizes.
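
For context, the GMM here is a grouped matrix multiplication: tokens are sorted and grouped by expert, and each contiguous group is multiplied by that expert's weight matrix. Below is a minimal plain-PyTorch reference sketch of those semantics, for illustration only; it is not the Pallas kernel from torch_xla, and the function name and shapes are assumptions.

import torch

def gmm_reference(lhs: torch.Tensor, rhs: torch.Tensor,
                  group_sizes: torch.Tensor) -> torch.Tensor:
    # lhs: (m, k) tokens already sorted by expert.
    # rhs: (num_experts, k, n) per-expert weight matrices.
    # group_sizes: (num_experts,) number of tokens routed to each expert.
    out = torch.zeros(lhs.shape[0], rhs.shape[-1],
                      dtype=lhs.dtype, device=lhs.device)
    start = 0
    for e, size in enumerate(group_sizes.tolist()):
        if size > 0:
            out[start:start + size] = lhs[start:start + size] @ rhs[e]
        start += size
    return out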

Before the update:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  121.44    
Total input tokens:                      1672653   
Total generated tokens:                  128000    
Request throughput (req/s):              8.23      
Output token throughput (tok/s):         1054.00   
Total Token throughput (tok/s):          14827.30  
---------------Time to First Token----------------
Mean TTFT (ms):                          57163.68  
Median TTFT (ms):                        56839.68  
P99 TTFT (ms):                           113953.49 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          116.61    
Median TPOT (ms):                        121.36    
P99 TPOT (ms):                           122.98    
---------------Inter-token Latency----------------
Mean ITL (ms):                           116.61    
Median ITL (ms):                         126.01    
P99 ITL (ms):                            127.24    
==================================================

After the update:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  83.38     
Total input tokens:                      1672653   
Total generated tokens:                  128000    
Request throughput (req/s):              11.99     
Output token throughput (tok/s):         1535.22   
Total Token throughput (tok/s):          21596.85  
---------------Time to First Token----------------
Mean TTFT (ms):                          39235.74  
Median TTFT (ms):                        38956.00  
P99 TTFT (ms):                           78323.49  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          80.03     
Median TPOT (ms):                        83.56     
P99 TPOT (ms):                           84.11     
---------------Inter-token Latency----------------
Mean ITL (ms):                           80.03     
Median ITL (ms):                         86.29     
P99 ITL (ms):                            87.79     
==================================================

Test Plan

pytest -s -v tests/tpu/test_moe_pallas.py

Test Result

passed.

Signed-off-by: Chengji Yao <chengjiyao@google.com>
@yaochengji yaochengji requested review from mgoin and vanbasten23 June 5, 2025 18:29
@gemini-code-assist bot (Contributor) left a comment

Hello @yaochengji, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! Gemini here, providing a summary of this pull request. The main purpose of this PR is to update the pinned version of torch_xla in the requirements file. This update is intended to integrate recent improvements, specifically an updated Mixture-of-Experts (MoE) GMM kernel from the torch_xla repository. The author has provided benchmark results demonstrating significant performance gains on the Mixtral model after applying this update. Additionally, the PR includes a minor adjustment to the Pallas GMM kernel test to verify its support for irregular dimension sizes, which is part of the kernel update's capabilities.

Highlights

  • MoE GMM Kernel Performance Improvement: This update integrates an optimized MoE GMM kernel from torch_xla, which the author reports leads to substantial performance improvements for models like Mixtral, as evidenced by the included benchmark comparisons showing increased throughput and reduced latency.
  • Pallas GMM Test Coverage: A test case in tests/tpu/test_moe_pallas.py is modified to include an irregular dimension size (k=511) in its parameterization, ensuring the updated Pallas GMM kernel correctly handles non-standard dimensions.

Changelog

  • requirements/tpu.txt
    • Updated the pinned version of torch from 2.8.0.dev20250529 to 2.8.0.dev20250605 (line 21).
    • Updated the pinned version of torchvision from 0.22.0.dev20250529 to 0.23.0.dev20250605 (line 22).
    • Updated the pinned version of torch_xla for Python 3.9, 3.10, and 3.11 to use the 2.8.0.dev20250605 wheel (lines 23-25).
  • tests/tpu/test_moe_pallas.py
    • Changed the parameter k in the @pytest.mark.parametrize decorator from 512 to 511 to test irregular dimension sizes for the Pallas GMM kernel (line 30).

A pin is updated,
Speed for MoE is created,
Tests pass, code is great.


@yaochengji yaochengji requested a review from lsy323 June 5, 2025 18:30
@mergify mergify bot added ci/build tpu Related to Google TPUs labels Jun 5, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

The pull request updates the torch_xla pin and modifies the Pallas GMM kernel test. The performance gain on the Mixtral model is significant, as shown in the benchmark results. The test modification ensures support for irregular dimension sizes. Overall, this is a valuable update. Here are some suggestions for improvement.

Summary of Findings

  • Test Parameterization: The change modifies a test parameter k from 512 to 511. It would be better to ensure that the test covers cases where num_tokens * topk is a multiple of 16, as required by the Pallas GMM kernel. Consider adjusting m, n, or topk instead, or adding a specific test case that satisfies this condition.

Merge Readiness

The pull request introduces performance improvements and a test modification. The test modification addresses a constraint of the Pallas GMM kernel. However, it would be beneficial to ensure that the test suite explicitly covers the cases where the constraint num_tokens * topk is satisfied. I am unable to directly approve this pull request, and recommend that others review and approve this code before merging. I would recommend addressing the medium severity issues before merging.

@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
gemini-code-assist bot commented (severity: medium):

Consider adding a test case where k is chosen such that num_tokens * topk is a multiple of 16, given the constraint mentioned in line 27. This would provide more confidence in the kernel's correctness under the required condition.
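
For illustration, here is a hedged sketch of what that suggestion could look like; the topk parameter, test name, and body below are assumptions for this example and are not taken from the actual tests/tpu/test_moe_pallas.py:

import pytest

# Illustrative sketch only. It keeps the irregular k (511) case while also
# guaranteeing combinations where num_tokens * topk is a multiple of 16,
# as the constraint discussed above requires: with topk=2, every m listed
# here yields m * topk divisible by 16.
@pytest.mark.parametrize("m", [8, 16, 64, 2048])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("topk", [2])
def test_gmm_irregular_dims(m, n, k, topk):
    # The actual MoE/GMM kernel invocation is elided in this sketch.
    assert (m * topk) % 16 == 0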

github-actions bot commented Jun 5, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@vanbasten23 (Collaborator) left a comment

Looks good. Thanks Chengji

@vanbasten23 (Collaborator)

Could you also run a sample benchmark (e.g., meta-llama/Meta-Llama-3.1-8B-Instruct) before merging the PR?

@vanbasten23 vanbasten23 added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 5, 2025
@mgoin (Member) left a comment

LGTM, great result! It would also be nice to verify whether performance is better for MoE with many small experts, like Qwen/Qwen3-30B-A3B.

@vanbasten23 (Collaborator)

Please also add --xla_tpu_use_enhanced_launch_barrier=false before merging the PR. Without the flag, running on multi-chip TPU may hang (pytorch/xla#9084).

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
@yaochengji (Collaborator, Author)

> Please also add --xla_tpu_use_enhanced_launch_barrier=false before merging the PR. Without the flag, running on multi-chip TPU may hang (pytorch/xla#9084).

@vanbasten23 thanks for catching this! It is updated.

@yaochengji (Collaborator, Author)

> LGTM, great result! It would also be nice to verify whether performance is better for MoE with many small experts, like Qwen/Qwen3-30B-A3B.

The old kernel's block-size selection logic has an issue, so it cannot run Qwen/Qwen3-30B-A3B with tp=4.

So I compared performance without the kernel against performance with the kernel from this PR. Below are the results. We can see that the GMM kernel is really critical when there are many experts.

Without kernel:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  1050.18   
Total input tokens:                      1777281   
Total generated tokens:                  128000    
Request throughput (req/s):              0.95      
Output token throughput (tok/s):         121.88    
Total Token throughput (tok/s):          1814.24   
---------------Time to First Token----------------
Mean TTFT (ms):                          466472.33 
Median TTFT (ms):                        458438.13 
P99 TTFT (ms):                           1031099.56
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          1114.18   
Median TPOT (ms):                        1256.88   
P99 TPOT (ms):                           1259.46   
---------------Inter-token Latency----------------
Mean ITL (ms):                           1114.18   
Median ITL (ms):                         1257.31   
P99 ITL (ms):                            1269.14   
==================================================

With kernel:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  110.76    
Total input tokens:                      1777281   
Total generated tokens:                  128000    
Request throughput (req/s):              9.03      
Output token throughput (tok/s):         1155.65   
Total Token throughput (tok/s):          17201.93  
---------------Time to First Token----------------
Mean TTFT (ms):                          52239.41  
Median TTFT (ms):                        52112.22  
P99 TTFT (ms):                           104398.86 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          109.61    
Median TPOT (ms):                        114.86    
P99 TPOT (ms):                           115.05    
---------------Inter-token Latency----------------
Mean ITL (ms):                           109.61    
Median ITL (ms):                         114.81    
P99 ITL (ms):                            115.99    
==================================================

@yaochengji yaochengji enabled auto-merge (squash) June 6, 2025 02:33
os.environ["LIBTPU_INIT_ARGS"] = (
"--xla_tpu_force_1d_allreduce_at_chunk_count=1")
os.environ.get("LIBTPU_INIT_ARGS", "") +
" --xla_tpu_force_1d_allreduce_at_chunk_count=1")
Collaborator

nit: We can add a comment saying the additional libtpu arg is needed due to pytorch/xla#9084

Collaborator Author

I think it's fine, because here we're not adding any specific libtpu arg; we just inherit the existing args, if any.

Collaborator

I see, makes sense, thanks!
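
For clarity, here is a minimal sketch of the append-rather-than-overwrite pattern discussed in this thread; the explanatory comment (including the reference to pytorch/xla#9084) is illustrative and not necessarily part of the PR's actual code:

import os

# Append to any existing LIBTPU_INIT_ARGS instead of overwriting it, so flags
# set elsewhere (e.g. --xla_tpu_use_enhanced_launch_barrier=false, needed to
# avoid a multi-chip TPU hang, see pytorch/xla#9084) are preserved.
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "") +
    " --xla_tpu_force_1d_allreduce_at_chunk_count=1")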

@lsy323 (Collaborator) left a comment

Thanks @yaochengji for updating the torch_xla pin!!

@vanbasten23 (Collaborator)

Could you check if VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-70B-Instruct --disable-log-requests --gpu-memory-utilization 0.98 --max-num-batched-tokens 2048 --max-num-seqs 128 --max-model-len 2048 --tensor-parallel-size 8 --no-enable-prefix-caching works on your v6e-8 VM?

@yaochengji yaochengji merged commit b61dc5f into vllm-project:main Jun 6, 2025
65 checks passed

@vanbasten23 (Collaborator)

The command in #19231 (comment) failed with an error (https://gist.github.com/vanbasten23/6772e44bc8b562256c3b184fb403c2b5) on my v6e-8 locally. cc @yaochengji @lsy323. Not sure if you see the same.

@yaochengji (Collaborator, Author)

> The command in #19231 (comment) failed with an error (https://gist.github.com/vanbasten23/6772e44bc8b562256c3b184fb403c2b5) on my v6e-8 locally. cc @yaochengji @lsy323. Not sure if you see the same.

Thanks, Xiongfei! I don't currently have a v6e-8 VM. Do you mind sharing more of the log for the issue? The current gist doesn't have much information.

@vanbasten23 (Collaborator)

> Thanks, Xiongfei! I don't currently have a v6e-8 VM. Do you mind sharing more of the log for the issue? The current gist doesn't have much information.

I'm trying to repro now, but the benchmarking script is running very slowly, and there is not much extra useful info in the log.

Also, I couldn't repro using the script from #19231 (comment). What I observed is that running the script takes a very long time and may fail with the error above. From the log, it seems it was loading the model: 50% -> 60% -> 70%... If you try a few times, it eventually loads to 100% and succeeds. I tried cleaning up the ~/.cache/huggingface cache but couldn't repro the error. Not sure if there is another vLLM cache.

Do we have a shared v6e-8 VM? If not, feel free to ping me and use mine.

@vanbasten23 (Collaborator)

Hmm, sorry, I couldn't repro anymore.

leoli1208 pushed a commit to leoli1208/vllm that referenced this pull request Jul 22, 2025
Signed-off-by: Chengji Yao <chengjiyao@google.com>

Labels

ci/build, ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1
