
Conversation

@yma11
Contributor

@yma11 yma11 commented Sep 18, 2025

Purpose

Fix more dispatch issues on XPU introduced in #24444

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added deepseek Related to DeepSeek models llama Related to Llama models labels Sep 18, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes a dispatch issue for RoPE on XPU for the llama4 and deepseek models by providing a forward_xpu method that falls back to the native PyTorch implementation. This prevents the use of an incorrect specialized kernel from the base class. My review includes suggestions to refactor the newly added methods to reduce code duplication and improve maintainability, which will help prevent potential bugs in the future.

Comment on lines 142 to 149
def forward_xpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    return self.forward_native(positions, query, key, offsets)
Contributor

high

To improve maintainability and reduce code duplication, you can directly alias forward_native to forward_xpu. The current implementation duplicates the body of forward_cuda, and both just delegate to forward_native. Using a direct assignment makes the intent clearer and ensures that any future changes to the signature of forward_native only need to be made in one place, reducing the risk of future bugs.

    forward_xpu = forward_native
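
For illustration, a minimal sketch of the suggested class-level alias in context (the class name and the forward_native body are placeholders, not the actual vLLM code):

    class SomeRotaryEmbedding(RotaryEmbedding):
        def forward_native(
            self,
            positions: torch.Tensor,
            query: torch.Tensor,
            key: Optional[torch.Tensor] = None,
            offsets: Optional[torch.Tensor] = None,
        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
            ...  # pure-PyTorch rotary embedding for this class's layout

        # Class-level alias: forward_xpu is bound to the forward_native
        # defined above when the class is created, so the XPU fallback
        # cannot drift out of sync with the native implementation.
        forward_xpu = forward_native

Note that this pattern only works if forward_native is defined earlier in the same class body, because the name is resolved in the class namespace at class-definition time.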

Collaborator

I think I understand the root cause now.
Before #24444, RotaryEmbedding used forward_xpu(), while its child classes such as Llama4VisionRotaryEmbedding, MRotaryEmbedding, and DeepseekScalingRotaryEmbedding used their own forward methods directly, ignoring the parent class's dispatch.
After #24444, all these child classes inherit the parent RotaryEmbedding's forward_xpu method, which does not match their implementations.
Maybe the best fix is to define a BaseRotaryEmbedding class that does no dispatch at all, and have all RoPE classes extend that base class.
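
To make that proposal concrete, here is a rough sketch of the dispatch-free base-class hierarchy; it is illustrative only (the real classes derive from vLLM's CustomOp, which provides the platform dispatch, and the method bodies are elided):

    from typing import Optional

    import torch
    import torch.nn as nn


    class BaseRotaryEmbedding(nn.Module):
        # Holds the shared RoPE state (cos/sin cache, head size, ...) and the
        # pure-PyTorch path; performs no platform dispatch of its own.
        def forward_native(
            self,
            positions: torch.Tensor,
            query: torch.Tensor,
            key: Optional[torch.Tensor] = None,
            offsets: Optional[torch.Tensor] = None,
        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
            ...


    class RotaryEmbedding(BaseRotaryEmbedding):
        # The plain RoPE layout that the custom kernels support; only this
        # class dispatches to ops.rotary_embedding in forward_cuda/forward_xpu.
        ...


    class DeepseekScalingRotaryEmbedding(BaseRotaryEmbedding):
        # Layouts the custom kernel does not support extend the base class
        # directly, so they never inherit an incompatible forward_xpu.
        ...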

Contributor Author

Yes, that's the root cause. There is already a base class, introduced in the RoPE refactor PR #22192. In that class, the forward_xpu dispatch goes to either forward_native or ops.rotary_embedding, and that default behavior makes sense. But for cases like Llama4VisionRotaryEmbedding, MRotaryEmbedding, and DeepseekScalingRotaryEmbedding, our kernel doesn't support them; we would need to fix that at the kernel level to avoid falling back to forward_native in these child classes.
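
For reference, the base-class dispatch described here looks roughly like the following; the gate condition and the exact argument list of ops.rotary_embedding are illustrative, not copied from the vLLM source:

    def forward_xpu(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        if offsets is not None:
            # Cases the custom kernel cannot handle fall back to pure PyTorch.
            return self.forward_native(positions, query, key, offsets)
        # Custom XPU kernel path; only valid for the plain RotaryEmbedding
        # layout, which is why subclasses that inherited this method needed
        # their own override in this PR.
        ops.rotary_embedding(positions, query, key, self.head_size,
                             self.cos_sin_cache, self.is_neox_style)
        return query, key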

Collaborator

I like the Gemini suggestion here, could you try it out?

Comment on lines 83 to 88
def forward_xpu(  # type: ignore[override]
    self,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    return self.forward_native(query, key)
Contributor

high

To avoid code duplication and enhance maintainability, it's better to alias forward_native for forward_xpu, as both this method and forward_cuda simply call forward_native. This approach is cleaner and less prone to errors if the underlying forward_native implementation or its signature changes in the future.

    forward_xpu = forward_native  # type: ignore[override]

@xuechendi
Copy link
Contributor

Could you add a description to explain the current fix?
It looks like this PR currently copies forward_xpu into all derived RoPE classes, right?
And based on the discussion, would the better plan be to provide a non-custom-op base class and derive from that?

@ProExpertProg
Copy link
Collaborator

I think a BaseRoPE class makes sense!

@frost-intel
Copy link
Contributor

@yma11 Any movement on this? Would love to have Llama4 functional here.

@yma11 yma11 force-pushed the rope-fix branch 4 times, most recently from f23ca75 to a010ec7 on October 15, 2025 02:29
@yma11
Copy link
Contributor Author

yma11 commented Oct 15, 2025

> @yma11 Any movement on this? Would love to have Llama4 functional here.

Updated based on comments. Let's wait for the CI results.

@jikunshang
Copy link
Collaborator

Do we need to consider the other RoPE classes in this folder?

@yma11
Copy link
Contributor Author

yma11 commented Oct 15, 2025

> Do we need to consider the other RoPE classes in this folder?

They should all be covered.

@yma11
Copy link
Contributor Author

yma11 commented Oct 16, 2025

> I think a BaseRoPE class makes sense!

@ProExpertProg can you help review this PR again? A base class has been added.

@jikunshang jikunshang enabled auto-merge (squash) October 27, 2025 02:08
@jikunshang jikunshang added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 27, 2025
auto-merge was automatically disabled October 28, 2025 02:18

Head branch was pushed to by a user without write access

@yma11 yma11 force-pushed the rope-fix branch 2 times, most recently from ae13379 to a9d1af1 on October 29, 2025 02:17
yma11 added 2 commits October 29, 2025 08:12
Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Yan Ma <yan.ma@intel.com>
@jikunshang jikunshang merged commit b798e39 into vllm-project:main Oct 30, 2025
47 checks passed
MatthewBonanni pushed a commit to MatthewBonanni/vllm that referenced this pull request Oct 30, 2025
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
eldarkurtic pushed a commit to eldarkurtic/vllm that referenced this pull request Nov 12, 2025
Signed-off-by: Yan Ma <yan.ma@intel.com>
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>