Conversation

@ZJY0516
Contributor

@ZJY0516 ZJY0516 commented Nov 6, 2025

Purpose

As #26680 and #19784 mention, using torch.zeros to allocate the attention output buffer introduces unnecessary kernel overhead.

  • With torch.empty, we allocate but do not initialize; CUDAGraph removes the allocation time.
  • With torch.zeros, we allocate AND initialize, which requires an extra CUDA/Triton kernel that CUDAGraph cannot remove. This kernel adds ~1 us of latency, on par with a layer norm or RoPE kernel (~1.6 us). See the sketch below.
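
A minimal sketch of the difference, with made-up shapes and tensor names (the actual vLLM allocation site differs):

```python
import torch

num_tokens, num_heads, head_dim = 256, 32, 128
device = "cuda" if torch.cuda.is_available() else "cpu"

# torch.zeros allocates AND zero-fills the buffer; the fill launches an extra
# CUDA/Triton kernel that CUDAGraph replay cannot eliminate.
out_zeros = torch.zeros(num_tokens, num_heads, head_dim, dtype=torch.bfloat16, device=device)

# torch.empty only allocates; under CUDAGraph the allocation cost disappears and
# the attention kernel is expected to overwrite every row it actually uses.
out_empty = torch.empty(num_tokens, num_heads, head_dim, dtype=torch.bfloat16, device=device)
```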

Qwen3-Next as an example:

main
[profiler trace screenshot]

this PR
[profiler trace screenshot]

A Triton kernel was eliminated.

Accuracy test

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --enable-expert-parallel -tp 4
lm_eval --model local-completions --model_args model=Qwen/Qwen3-Next-80B-A3B-Instruct,base_url=http://localhost:8000/v1/completions -t gsm8k --num_fewshot 5 --batch_size 250
| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr |
|-------|---------|------------------|--------|-------------|-------|--------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0     | ± 0    |
|       |         | strict-match     | 5      | exact_match | 0     | ± 0    |

Perf test

Qwen3-Next

TL;DR: total token throughput 6340.16 -> 6756.52 tok/s

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct --enable-expert-parallel -tp 4
vllm bench serve \
--model Qwen/Qwen3-Next-80B-A3B-Instruct \
--dataset-name random \
--num-prompts 32 \
--random-input-len 2048 \
--random-output-len 1024

this PR

============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Benchmark duration (s):                  14.55     
Total input tokens:                      65536     
Total generated tokens:                  32768     
Request throughput (req/s):              2.20      
Output token throughput (tok/s):         2252.17   
Peak output token throughput (tok/s):    2912.00   
Peak concurrent requests:                32.00     
Total Token throughput (tok/s):          6756.52   
---------------Time to First Token----------------
Mean TTFT (ms):                          1819.22   
Median TTFT (ms):                        1847.39   
P99 TTFT (ms):                           2960.06   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          12.39     
Median TPOT (ms):                        12.38     
P99 TPOT (ms):                           13.63     
---------------Inter-token Latency----------------
Mean ITL (ms):                           12.40     
Median ITL (ms):                         11.33     
P99 ITL (ms):                            12.00     
==================================================

main

============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Benchmark duration (s):                  15.14     
Total input tokens:                      65536     
Total generated tokens:                  30472     
Request throughput (req/s):              2.11      
Output token throughput (tok/s):         2012.31   
Peak output token throughput (tok/s):    2400.00   
Peak concurrent requests:                32.00     
Total Token throughput (tok/s):          6340.16   
---------------Time to First Token----------------
Mean TTFT (ms):                          1355.49   
Median TTFT (ms):                        1365.46   
P99 TTFT (ms):                           2122.49   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          13.42     
Median TPOT (ms):                        13.32     
P99 TPOT (ms):                           14.34     
---------------Inter-token Latency----------------
Mean ITL (ms):                           13.46     
Median ITL (ms):                         12.74     
P99 ITL (ms):                            13.36     
==================================================

Kimi Linear (performance regression with this PR)
vllm serve moonshotai/Kimi-Linear-48B-A3B-Instruct --trust-remote-code -tp 4 --enable-expert-parallel
vllm bench serve \
--model moonshotai/Kimi-Linear-48B-A3B-Instruct \
--dataset-name random \
--num-prompts 32 \
--random-input-len 2048 \
--random-output-len 1024 \
--trust-remote-code

this PR

============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Benchmark duration (s):                  11.18     
Total input tokens:                      65536     
Total generated tokens:                  31784     
Request throughput (req/s):              2.86      
Output token throughput (tok/s):         2843.43   
Peak output token throughput (tok/s):    3424.00   
Peak concurrent requests:                32.00     
Total Token throughput (tok/s):          8706.36   
---------------Time to First Token----------------
Mean TTFT (ms):                          1028.07   
Median TTFT (ms):                        1031.84   
P99 TTFT (ms):                           1634.19   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.95      
Median TPOT (ms):                        9.89      
P99 TPOT (ms):                           11.42     
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.91      
Median ITL (ms):                         9.32      
P99 ITL (ms):                            10.10     
==================================================

main

============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Benchmark duration (s):                  11.06     
Total input tokens:                      65536     
Total generated tokens:                  32206     
Request throughput (req/s):              2.89      
Output token throughput (tok/s):         2911.00   
Peak output token throughput (tok/s):    3552.00   
Peak concurrent requests:                32.00     
Total Token throughput (tok/s):          8834.59   
---------------Time to First Token----------------
Mean TTFT (ms):                          1179.73   
Median TTFT (ms):                        1196.33   
P99 TTFT (ms):                           1800.05   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.65      
Median TPOT (ms):                        9.62      
P99 TPOT (ms):                           10.65     
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.64      
Median ITL (ms):                         9.05      
P99 ITL (ms):                            9.59      
==================================================
vllm serve moonshotai/Kimi-Linear-48B-A3B-Instruct --trust-remote-code -tp 4
lm_eval --model local-completions --model_args model=moonshotai/Kimi-Linear-48B-A3B-Instruct,base_url=http://localhost:8000/v1/completions,trust_remote_code=True -t gsm8k --num_fewshot 5 --batch_size 250
| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.8939 | ± 0.0085 |
|       |         | strict-match     | 5      | exact_match | 0.8764 | ± 0.0091 |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516 ZJY0516 requested a review from sighingnow as a code owner November 6, 2025 03:53
@mergify mergify bot added the qwen Related to Qwen models label Nov 6, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request replaces torch.zeros with torch.empty for allocating attention output buffers to reduce overhead, which is a good performance optimization. You've correctly identified and handled the case where the buffer might not be written to in profiling runs by explicitly filling it with zeros, thus preserving the original behavior. However, this change highlights a pre-existing critical bug in qwen3_next.py where a buffer is allocated with an unpadded size but may be written to with a padded size when using CUDAGraphs, leading to a potential out-of-bounds write. I've provided a detailed comment and a code suggestion to fix this issue. The changes in vllm/model_executor/layers/kda.py appear correct and do not have this issue.
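
For illustration only, with made-up sizes (this is not the qwen3_next.py code): if the buffer is sized for the unpadded token count but a CUDAGraph replay writes the padded count, the write runs past the end of the buffer.

```python
import torch

num_actual_tokens = 100   # unpadded token count used to size the buffer
num_padded_tokens = 128   # CUDAGraph pads the batch up to a captured size

buf = torch.empty(num_actual_tokens, 64)   # allocated with the unpadded size
src = torch.randn(num_padded_tokens, 64)   # stand-in for a padded-size kernel output

# A tensor-level copy at least fails loudly; a raw kernel writing
# num_padded_tokens rows into this buffer would silently corrupt memory.
try:
    buf.copy_(src)
except RuntimeError as err:
    print("padded write does not fit the unpadded buffer:", err)
```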

@ZJY0516
Contributor Author

ZJY0516 commented Nov 6, 2025

I don't know why there is a performance regression on Kimi Linear. Do you have any idea? @BoyuanFeng

@BoyuanFeng
Contributor

[profiler trace screenshot]

Looks like there are still 3 kernels?

For the perf, would the number be stable if you run it twice?

@ZJY0516
Contributor Author

ZJY0516 commented Nov 6, 2025

Looks like there are still 3 kernels?

I think triton_poi_fused_5 is generated by torch.zeros?

@BoyuanFeng
Contributor

lol did not see triton_poi_fused_4

@ZJY0516
Contributor Author

ZJY0516 commented Nov 6, 2025

For the perf, would the number be stable if you run it twice?

Yes, for Kimi Linear the performance with this PR is consistently lower than main.

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
@ZJY0516
Contributor Author

ZJY0516 commented Nov 6, 2025

I'll leave the Kimi Linear optimization for the next PR since it has a performance regression.

@vadiklyutiy
Collaborator

As I remember, we need torch.zeros here.
With CUDAGraphs there is padding, but the GDN kernel only fills the unpadded rows, so the padded rows would contain garbage.
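
A toy illustration of that concern (made-up sizes, run on CPU for simplicity; the real padding and GDN kernel live inside vLLM): only the unpadded rows are written, so with torch.empty the padded tail keeps whatever bytes happened to be in that memory.

```python
import torch

num_padded, num_actual, hidden = 128, 100, 64

out = torch.empty(num_padded, hidden)
out[:num_actual] = 1.0    # stand-in for the GDN kernel: fills only the unpadded rows

tail = out[num_actual:]   # rows that were never written
# With torch.zeros these rows are guaranteed to be 0; with torch.empty they are arbitrary.
print("padded tail may contain garbage:", bool((tail != 0).any()))
```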

@vadiklyutiy vadiklyutiy self-requested a review November 6, 2025 10:15
@vadiklyutiy
Collaborator

@codex review

@chatgpt-codex-connector

Codex Review: Didn't find any major issues. Swish!

@ZJY0516
Contributor Author

ZJY0516 commented Nov 6, 2025

As I remember, we need torch.zeros here. With CUDAGraphs there is padding, but the GDN kernel only fills the unpadded rows, so the padded rows would contain garbage.

Could you point to the related code? The gsm8k results show no accuracy problem.

@vadiklyutiy
Collaborator

Could you try

lm_eval --model local-chat-completions --model_args model=Qwen/Qwen3-Next-80B-A3B-Instruct,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=250 --tasks gsm8k --apply_chat_template --num_fewshot 5

and maybe different values of num_concurrent? That should produce more variation in the shapes seen during the forward step.

@ZJY0516
Contributor Author

ZJY0516 commented Nov 6, 2025

Yes, I think you are right.

The results are quite low now:

lm_eval --model local-chat-completions --model_args model=Qwen/Qwen3-Next-80B-A3B-Instruct,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=250 --tasks gsm8k --apply_chat_template --num_fewshot 5
| Tasks | Version | Filter           | n-shot | Metric      | Value | Stderr |
|-------|---------|------------------|--------|-------------|-------|--------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0     | ± 0    |
|       |         | strict-match     | 5      | exact_match | 0     | ± 0    |

@ZJY0516
Contributor Author

ZJY0516 commented Nov 6, 2025

@vadiklyutiy Maybe we should add a comment to explain why we should not use torch.empty here?
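
One possible shape for such a comment, sketched with hypothetical names (not the final wording or code location):

```python
import torch

num_padded_tokens, num_heads, head_dim = 128, 16, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: keep torch.zeros here. Under CUDAGraph the batch is padded, but the GDN
# kernel only writes the unpadded rows; torch.empty would leave garbage in the
# padded rows and corrupt downstream results (see the gsm8k regression above).
output = torch.zeros(num_padded_tokens, num_heads, head_dim, dtype=torch.bfloat16, device=device)
```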

@ZJY0516
Contributor Author

ZJY0516 commented Nov 6, 2025

I ran tests on the main branch and observed inconsistent results across different commands.

lm_eval --model local-completions --model_args model=Qwen/Qwen3-Next-80B-A3B-Instruct,base_url=http://localhost:8000/v1/completions -t gsm8k --num_fewshot 5 --batch_size 250

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.8537 | ± 0.0097 |
|       |         | strict-match     | 5      | exact_match | 0.8074 | ± 0.0109 |

lm_eval --model local-chat-completions --model_args model=Qwen/Qwen3-Next-80B-A3B-Instruct,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=280 --tasks gsm8k --apply_chat_template --num_fewshot 5

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.7877 | ± 0.0113 |
|       |         | strict-match     | 5      | exact_match | 0.6566 | ± 0.0131 |

lm_eval --model local-completions --model_args model=Qwen/Qwen3-Next-80B-A3B-Instruct,base_url=http://localhost:8000/v1/completions -t gsm8k --num_fewshot 5 --batch_size 250 --apply_chat_template

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr  |
|-------|---------|------------------|--------|-------------|--------|---------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.7998 | ± 0.011 |
|       |         | strict-match     | 5      | exact_match | 0.6664 | ± 0.013 |

@ZJY0516 ZJY0516 marked this pull request as draft November 7, 2025 05:10
@vadiklyutiy
Collaborator

I think all 3 options generate different prompts, so some difference in scores is expected.
