
Conversation

@BoyuanFeng (Contributor) commented on Oct 13, 2025

This PR enables horizontal fusion from Inductor. It helps fuse q-norm and k-norm into one kernel, and q-rope and k-rope into one kernel. These were not fused before because q and k have different shapes, which prevented the fusion from happening.

The following traces come from Qwen3-0.6B.

Before:

[kernel trace screenshot]

After (together with #26680):

[kernel trace screenshot]

After qkv_proj, the kernel count drops from 5 to 2: one for qk-norm and one for qk-rope.
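
As a rough illustration of the pattern (a minimal standalone sketch, not vLLM's actual q/k-norm code; shapes and eps are made up), the two norms are independent ops on tensors with different head counts:

```python
# Minimal sketch: q-norm and k-norm are independent RMSNorms over tensors with
# different head counts, so Inductor lowers them to separate kernels unless the
# combo-kernel (horizontal fusion) pass groups them into one.
import torch

def qk_norm(q, k, w_q, w_k, eps=1e-6):
    # Two independent RMSNorms over the head dim; q and k differ in head count.
    q = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + eps) * w_q
    k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + eps) * w_k
    return q, k

if torch.cuda.is_available():
    q = torch.randn(32, 16, 128, device="cuda")  # e.g. 16 query heads
    k = torch.randn(32, 8, 128, device="cuda")   # e.g. 8 KV heads
    w_q = torch.ones(128, device="cuda")
    w_k = torch.ones(128, device="cuda")
    out_q, out_k = torch.compile(qk_norm)(q, k, w_q, w_k)
```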

Performance

Applying this PR + #26680

Qwen/Qwen3-0.6B

Before:
[benchmark screenshot]

After:
[benchmark screenshot]

Signed-off-by: Boyuan Feng <boyuan@meta.com>
@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a new configuration option, use_horizontal_fusion, to enable horizontal fusion for qk-norm and qk-rope operations in PyTorch Inductor. This is achieved by setting combo_kernels and benchmark_combo_kernel in the Inductor configuration when the feature is enabled and the PyTorch version is 2.9.0.dev or newer. My main feedback is regarding the default value of the new flag. For stability, it would be safer to disable this experimental feature by default and allow users to opt-in.
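
A rough sketch of the opt-in wiring described above (the helper name and signature are illustrative, not the PR's exact code):

```python
# Hedged sketch: only set the Inductor combo-kernel options when the flag is on
# and the installed PyTorch is at least 2.9.0.dev.
from packaging import version
import torch

def maybe_enable_horizontal_fusion(inductor_compile_config: dict,
                                   use_horizontal_fusion: bool) -> None:
    if not use_horizontal_fusion:
        return
    if version.parse(torch.__version__) < version.parse("2.9.0.dev"):
        return
    inductor_compile_config["combo_kernels"] = True
    inductor_compile_config["benchmark_combo_kernel"] = True
```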

since we know all keys are in a range [0, max_capture_size],
we can optimize it to list[int] for better lookup performance."""

use_horizontal_fusion = True
@gemini-code-assist (bot) commented:

Severity: high

The use_horizontal_fusion flag is enabled by default. This will automatically enable the combo_kernels feature in PyTorch Inductor for users on versions 2.9.0.dev or newer. Since this relies on a feature in a development version of PyTorch, it may be unstable. It would be safer to set this to False by default to prevent potential issues for users on bleeding-edge PyTorch versions. Users can then explicitly opt-in to enable this experimental feature.

Suggested change
use_horizontal_fusion = True
use_horizontal_fusion = False

@BoyuanFeng (Contributor, Author) replied:

We want to enable this by default since it benefits models in general.
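
For users who prefer to opt out, something along these lines should work (a sketch, assuming the flag is exposed as a field on vLLM's CompilationConfig):

```python
# Sketch: disabling horizontal fusion explicitly; assumes use_horizontal_fusion
# is a CompilationConfig field, per this PR.
from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="Qwen/Qwen3-0.6B",
    compilation_config=CompilationConfig(use_horizontal_fusion=False),
)
```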

@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Signed-off-by: Boyuan Feng <boyuan@meta.com>
@zou3519 (Collaborator) commented on Oct 13, 2025

The main thing here, I think, is testing: do we have a plan for that, or are we yolo-ing this? @BoyuanFeng @ProExpertProg

Comment on lines +506 to +509
# use horizontal fusion, which is useful for fusing qk-norm and
# qk-rope when query and key have different shapes.
self.inductor_compile_config["combo_kernels"] = True
self.inductor_compile_config["benchmark_combo_kernel"] = True
A project member commented:

I would appreciate a doc pointer to how this works so I can understand it for future work. Currently this is very opaque.

@BoyuanFeng (Contributor, Author) replied:

We have some docs here; I will follow up with more PyTorch docs.
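
In the meantime, a tiny standalone illustration of the two Inductor options this PR turns on (illustrative example, not vLLM code; the options dict mirrors how inductor_compile_config reaches torch.compile):

```python
# Illustrative sketch: two independent elementwise ops on different shapes are
# candidates for a single combo kernel when these Inductor options are set.
import torch

def f(x, y):
    return x.relu(), y.sigmoid()

compiled = torch.compile(
    f, options={"combo_kernels": True, "benchmark_combo_kernel": True}
)
if torch.cuda.is_available():
    out = compiled(torch.randn(1024, device="cuda"), torch.randn(2048, device="cuda"))
```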

@BoyuanFeng (Contributor, Author) commented:

@zou3519 Tested with facebook/opt-125m, qwen/qwen3-0.6b, google/gemma-3-4b-it, and openai/gpt-oss-20b; the outputs are correct: https://paste.sh/llBbgv2x#GHNmWaQcfkKA_6ypMb_UGkgo

@ProExpertProg (Collaborator) left a comment

Yolo here seems acceptable; we can add tests in the future.

@mgoin added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Oct 14, 2025
@zou3519 merged commit ca683a2 into vllm-project:main on Oct 14, 2025 (48 checks passed)
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Dhruvil Bhatt <bhattdbh@amazon.com>
bbartels pushed a commit to bbartels/vllm that referenced this pull request Oct 16, 2025
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: bbartels <benjamin@bartels.dev>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: Boyuan Feng <boyuan@meta.com>
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: Boyuan Feng <boyuan@meta.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
@ProExpertProg mentioned this pull request on Oct 28, 2025
@ProExpertProg linked an issue on Oct 28, 2025 that may be closed by this pull request

Labels: ready (ONLY add when PR is ready to merge/full CI is needed)

Projects: None yet

Development: Successfully merging this pull request may close: [Feature]: Optimize RoPE

5 participants