
Conversation

gshtras (Collaborator) commented Oct 22, 2025

Starting with #25135, RoPE's forward_hip falls back to forward_cuda while assuming that the function modifies the values in place, which is not the case for DeepSeek's deepseek_scaling_rope implementation.
Instead of relying on an in-place modification, we should return the values from the forward function implementation.
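
A minimal sketch of the shape of the fix (signatures mirrored from the snippets quoted later in this thread, not the literal diff): forward_hip should propagate whatever its forward_cuda fallback returns, rather than returning its own query/key arguments on the assumption that they were modified in place.

    def forward_hip(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Return the fallback's result directly; this works whether the
        # underlying implementation is in-place or out-of-place.
        return self.forward_cuda(positions, query, key, offsets)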

…-place

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@gshtras gshtras added the rocm Related to AMD ROCm label Oct 22, 2025
@mergify mergify bot added the deepseek Related to DeepSeek models label Oct 22, 2025
gemini-code-assist bot (Contributor) left a comment
Code Review

This pull request addresses a bug in the forward_hip method within the RotaryEmbedding class for ROCm platforms. The previous implementation incorrectly assumed that its fallback to forward_cuda would always involve an in-place modification of the query and key tensors. This assumption fails for certain implementations like DeepSeek's deepseek_scaling_rope, which do not modify tensors in-place. The fix rectifies this by returning the result from forward_cuda directly, making the forward_hip method robust to both in-place and out-of-place forward_cuda implementations. This change is a clear and correct fix that enhances the reliability of the rotary embedding layer.

@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 23, 2025
SageMoore (Contributor) left a comment
It looks like this is an inheritance bug where the subclass overrides forward_cuda and uses an out-of-place implementation? Nice find.

tjtanaa (Collaborator) commented Oct 23, 2025

@gshtras A question: why not override this at DeepseekScalingRotaryEmbedding by redefining it, like the forward_cuda function? Moreover, overriding at DeepseekScalingRotaryEmbedding could make the code easier to follow and avoid future inheritance bugs.

    def forward_hip(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        return self.forward_native(positions, query, key, offsets)

Is the RotaryEmbedding's forward_cuda the same as the DeepseekScalingRotaryEmbedding's forward_cuda?

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """PyTorch-native implementation equivalent to forward()."""
        assert key is not None
        self._match_cos_sin_cache_dtype(query)
        query_rot = query[..., : self.rotary_dim]
        key_rot = key[..., : self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim :]
            key_pass = key[..., self.rotary_dim :]
        cos_sin = self.cos_sin_cache[
            torch.add(positions, offsets) if offsets is not None else positions
        ]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key

CC @SageMoore

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 23, 2025 16:17
@DarkLight1337 DarkLight1337 merged commit 0825197 into vllm-project:main Oct 23, 2025
51 of 53 checks passed
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
…lm-project#27373)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
kingsmad pushed a commit to kingsmad/vllm that referenced this pull request Oct 25, 2025
…lm-project#27373)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…lm-project#27373)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
…lm-project#27373)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
…lm-project#27373)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
…lm-project#27373)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>