
Conversation

Isotr0py (Member) commented on Oct 24, 2025

Purpose

Test Plan

Test Result



Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request correctly fixes a critical bug in InternSdpaAttention's forward pass for interns1-vit models with QK normalization enabled. The previous implementation would crash because it attempted to unpack a 3D tensor into four variables, and it also misused flatten. The fix applies the QK normalization directly to the query and key tensors, which is the correct approach. Removing the unused B, N, C variables is also a good cleanup.

I've added one comment pointing out a related latent bug when num_dummy_heads > 0, which will also cause a crash. This seems to stem from incorrectly adapted logic for tensor parallelism dummy heads. While the current PR fixes the most obvious issue, addressing the related bug would make the implementation more robust.
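As a rough sketch of the crash and the fix described above (the shapes are invented, and torch.nn.RMSNorm from PyTorch 2.4+ stands in for vLLM's RMSNorm; this is illustrative, not code copied from the repository):

```python
import torch

batch, seq_len, embed_dim = 2, 16, 1024
q = torch.randn(batch, seq_len, embed_dim)  # flattened query: (B, N, C), a 3-D tensor

# Before: the old code unpacked four values from a 3-D shape, which raises
# "ValueError: not enough values to unpack (expected 4, got 3)".
try:
    B_, H_, N_, D_ = q.shape
except ValueError as err:
    print(err)

# After (this PR): the norm acts on the last dimension, so it is applied
# to the flattened tensor directly, as in `q = self.q_norm(q)`.
q_norm = torch.nn.RMSNorm(embed_dim)  # stand-in for vLLM's RMSNorm
q = q_norm(q)
print(q.shape)  # torch.Size([2, 16, 1024])
```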

Comment on lines +227 to +228
q = self.q_norm(q)
k = self.k_norm(k)

critical

While this change correctly fixes the immediate crash, there is a latent critical bug when num_dummy_heads > 0 that will cause crashes both in this block and further downstream.

  1. q_norm and k_norm initialization: In __init__, self.q_norm is initialized as RMSNorm(hidden_size=self.dummy_dim, ..., var_hidden_size=self.embed_dim). If num_dummy_heads > 0, self.dummy_dim > self.embed_dim. The input q has a last dimension of self.embed_dim. RMSNorm expects the input's last dimension to match its hidden_size (self.dummy_dim), so self.q_norm(q) will raise a ValueError. The same applies to k_norm.

  2. projection_layer input shape: The output of self.attn(q, k, v) will have a shape of (..., self.embed_dim). However, self.projection_layer is initialized as nn.Linear(self.dummy_dim, self.embed_dim). If num_dummy_heads > 0, the call to self.projection_layer on a subsequent line will fail due to a shape mismatch.

The MultiHeadAttention implementation used here doesn't seem to account for dummy heads. The entire dummy_dim logic within InternSdpaAttention may need to be re-evaluated: either implemented correctly, or removed if it isn't applicable to this ViT attention layer. A potential fix for the normalization part would be to initialize RMSNorm with self.embed_dim. Both shape mismatches are illustrated in the sketch below.
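To make both failure modes concrete, here is a minimal self-contained reproduction. torch.nn.RMSNorm (PyTorch 2.4+) again stands in for vLLM's RMSNorm, and the dimensions are invented for illustration:

```python
import torch
import torch.nn as nn

embed_dim, head_dim, num_dummy_heads = 1024, 64, 2
dummy_dim = embed_dim + num_dummy_heads * head_dim  # 1152 > embed_dim

x = torch.randn(2, 16, embed_dim)  # q/k and the attention output end in embed_dim

# Issue 1: a norm sized for dummy_dim rejects an embed_dim input.
q_norm = nn.RMSNorm(dummy_dim)
try:
    q_norm(x)
except RuntimeError as err:
    print(err)  # normalized_shape mismatch

# Issue 2: the projection expects dummy_dim input features, but the
# attention output's last dimension is embed_dim.
projection_layer = nn.Linear(dummy_dim, embed_dim)
try:
    projection_layer(x)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied

# The suggested fix for the normalization half: size the norm by embed_dim.
out = nn.RMSNorm(embed_dim)(x)  # works
print(out.shape)
```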

DarkLight1337 enabled auto-merge (squash) on October 24, 2025 at 15:25
github-actions bot added the ready label on Oct 24, 2025
DarkLight1337 merged commit acc78ae into vllm-project:main on Oct 24, 2025
56 checks passed
kingsmad pushed a commit to kingsmad/vllm that referenced this pull request Oct 25, 2025
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Isotr0py deleted the fix-interns1 branch on October 25, 2025 at 03:35
rohin-garg pushed a commit to rohin-garg/vllm that referenced this pull request Oct 25, 2025
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Development

Successfully merging this pull request may close these issues.

vllm deploy error

2 participants