Commit f23ca75
refactor to avoid incorrect rope dispatch
Signed-off-by: Yan Ma <yan.ma@intel.com>
1 parent acadc2b commit f23ca75

File tree: 4 files changed (+47, -43 lines)

vllm/model_executor/layers/rotary_embedding/base.py
Lines changed: 27 additions & 14 deletions

@@ -14,17 +14,17 @@
 
 
 @CustomOp.register("rotary_embedding")
-class RotaryEmbedding(CustomOp):
+class RotaryEmbeddingBase(CustomOp):
     """Original rotary positional embedding."""
 
     def __init__(
-        self,
-        head_size: int,
-        rotary_dim: int,
-        max_position_embeddings: int,
-        base: float,
-        is_neox_style: bool,
-        dtype: torch.dtype,
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        dtype: torch.dtype,
     ) -> None:
         super().__init__()
         self.head_size = head_size
@@ -59,10 +59,10 @@ def _compute_inv_freq(self, base: float) -> torch.Tensor:
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
         inv_freq = 1.0 / (
-            base
-            ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
-            )
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+            )
         )
         return inv_freq
 
@@ -81,11 +81,24 @@ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
         # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
         # is expensive, so avoid calling it if possible
         if (
-            self.cos_sin_cache.device != query.device
-            or self.cos_sin_cache.dtype != query.dtype
+            self.cos_sin_cache.device != query.device
+            or self.cos_sin_cache.dtype != query.dtype
         ):
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
 
+
+class RotaryEmbedding(RotaryEmbeddingBase):
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+    ) -> None:
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype)
+
     def forward_native(
         self,
         positions: torch.Tensor,
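Why split the class? CustomOp-style layers dispatch to a per-platform forward_* entry point, and the commit title suggests the problem being fixed is a scaling variant that overrides only forward_native while still inheriting the parent's CUDA path, so it gets routed to the wrong kernel. Below is a minimal sketch of that hazard; ToyCustomOp is a hypothetical stand-in for vLLM's actual CustomOp (vllm/model_executor/custom_op.py), not its real dispatch logic.

class ToyCustomOp:
    def forward(self, *args):
        # Hypothetical platform dispatch: prefer the CUDA entry point.
        return self.forward_cuda(*args)


class ToyRope(ToyCustomOp):
    def forward_native(self, x):
        return x + 1  # stand-in for the generic rope math

    def forward_cuda(self, x):
        return x + 1  # stand-in for a fused CUDA rope kernel


class ToyScaledRope(ToyRope):
    # Only the native path is overridden, so forward() still dispatches to
    # ToyRope.forward_cuda and silently skips the scaling on "CUDA".
    def forward_native(self, x):
        return (x + 1) * 2


assert ToyScaledRope().forward(1) == 2         # wrong math via the inherited CUDA path
assert ToyScaledRope().forward_native(1) == 4  # the intended scaled math

With the split, variants such as DeepseekScalingRotaryEmbedding, Llama4VisionRotaryEmbedding, and MRotaryEmbedding (see the diffs below) extend RotaryEmbeddingBase and inherit only the shared setup, not RotaryEmbedding's platform-specific forward paths.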

vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
Lines changed: 10 additions & 4 deletions

@@ -7,7 +7,7 @@
 
 from vllm.platforms import current_platform
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import (
     rotate_gptj,
     rotate_neox,
@@ -22,7 +22,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
     """RotaryEmbedding extended with YaRN method.
 
     Credits to Peng et al. github.com/jquesnelle/yarn
@@ -146,5 +146,11 @@ def forward_native(
             key = key_rot
         return query, key
 
-    forward_cuda = forward_native
-    forward_xpu = forward_native
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor | None = None,
+        offsets: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        return self.forward_native(positions, query, key, offsets)
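Replacing the forward_cuda = forward_native class-attribute alias with an explicit delegating method is not purely cosmetic. In Python, the alias binds the function object present in that class body at class-creation time, while a delegating method resolves forward_native on self at call time, so later subclass overrides are honored. A small, self-contained illustration with toy classes (not vLLM code):

class A:
    def forward_native(self):
        return "A.native"

    # The alias captures A.forward_native (a function object) when the class
    # body executes; overrides in subclasses are not seen through it.
    forward_cuda = forward_native

    # Explicit delegation looks forward_native up on self at call time.
    def forward_cuda_delegating(self):
        return self.forward_native()


class B(A):
    def forward_native(self):
        return "B.native"


b = B()
assert b.forward_cuda() == "A.native"             # alias ignores B's override
assert b.forward_cuda_delegating() == "B.native"  # delegation respects it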

vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
Lines changed: 8 additions & 5 deletions

@@ -5,10 +5,10 @@
 
 import torch
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 
 
-class Llama4VisionRotaryEmbedding(RotaryEmbedding):
+class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
     def __init__(
         self,
         head_size: int,
@@ -72,6 +72,9 @@ def forward_native(  # type: ignore[override]
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
 
-    forward_cuda = forward_native
-    forward_hip = forward_native
-    forward_xpu = forward_native
+    def forward_cuda(  # type: ignore[override]
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        return self.forward_native(query, key)
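The # type: ignore[override] carried over to the new forward_cuda is there because the vision rope narrows the signature (no positions or offsets), which mypy flags as an incompatible override. A toy reproduction of the warning being suppressed; the class names here are illustrative, not vLLM's:

import torch


class Base:
    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        return query, key


class VisionRope(Base):
    # Dropping `positions` breaks the override-compatibility rule mypy
    # enforces, hence the explicit per-line suppression.
    def forward_native(  # type: ignore[override]
        self,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        return query, key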

vllm/model_executor/layers/rotary_embedding/mrope.py
Lines changed: 2 additions & 20 deletions

@@ -8,7 +8,7 @@
 
 from vllm.triton_utils import tl, triton
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import apply_rotary_emb_dispatch
 from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
 
@@ -200,7 +200,7 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.T
     return x_t
 
 
-class MRotaryEmbedding(RotaryEmbedding):
+class MRotaryEmbedding(RotaryEmbeddingBase):
     """Rotary Embedding with Multimodal Sections."""
 
     def __init__(
@@ -358,24 +358,6 @@ def forward_cuda(
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
-    def forward_xpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
-    def forward_cpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
     @classmethod
     def get_input_positions(
         cls,
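The deleted forward_xpu and forward_cpu bodies were one-line delegations to forward_native. They become redundant if the base op's default platform entry points already fall back to the native path, which appears to be the assumption behind this cleanup (check vLLM's CustomOp in vllm/model_executor/custom_op.py to confirm). A minimal sketch of that fallback behavior with toy classes:

class ToyOp:
    def forward_native(self, x):
        raise NotImplementedError

    # Default platform entry points: reuse the native path unless a subclass
    # provides a dedicated kernel.
    def forward_cpu(self, x):
        return self.forward_native(x)

    def forward_xpu(self, x):
        return self.forward_native(x)


class ToyMRope(ToyOp):
    def forward_native(self, x):
        return x * 2  # stand-in for the multimodal-sections rope math

    # No forward_cpu/forward_xpu overrides needed: the inherited defaults
    # resolve forward_native on self, so they already reach this method.


assert ToyMRope().forward_cpu(3) == 6
assert ToyMRope().forward_xpu(3) == 6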
