Commit d2f9129

refactor to avoid incorrect rope dispatch
Signed-off-by: Yan Ma <yan.ma@intel.com>
1 parent 7aaed2b commit d2f9129
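
vLLM registers rotary embedding as a CustomOp, and CustomOp routes forward() to a platform-specific implementation (forward_cuda, forward_xpu, and so on). When a scaled rope variant subclasses the fully fused RotaryEmbedding, it also inherits those platform forwards, so any platform entry point the variant forgets to re-alias keeps running the base class's fused kernel instead of the variant's own math. That is presumably the "incorrect rope dispatch" this refactor guards against: the shared state moves into a platform-free RotaryEmbeddingBase, and the variants subclass that instead. The sketch below is a minimal, hypothetical illustration of the hazard and the fix; the *Sketch classes, the platform argument, and the dispatch logic are simplifications invented here, not vLLM's actual CustomOp API.

# Hypothetical, simplified sketch; not vLLM's actual CustomOp implementation.

class CustomOpSketch:
    """Resolves a platform-specific forward_* method once, at construction time."""

    def __init__(self, platform: str) -> None:
        # e.g. "cuda" -> forward_cuda, "xpu" -> forward_xpu; fall back to forward_native
        self._impl = getattr(self, f"forward_{platform}", self.forward_native)

    def forward(self, x: str) -> str:
        return self._impl(x)

    def forward_native(self, x: str) -> str:
        return f"reference rope math on {x}"


# Before this commit: scaled variants subclassed the fused implementation.
class RotaryEmbeddingSketch(CustomOpSketch):
    def forward_cuda(self, x: str) -> str:
        return f"fused CUDA rope kernel on {x}"

    def forward_xpu(self, x: str) -> str:
        return f"fused XPU rope kernel on {x}"


class ScaledRopeSketchOld(RotaryEmbeddingSketch):
    # Overrides only the math; the fused platform forwards are silently inherited.
    def forward_native(self, x: str) -> str:
        return f"YaRN-scaled rope math on {x}"


print(ScaledRopeSketchOld(platform="xpu").forward("q"))
# -> "fused XPU rope kernel on q": the scaling override never runs.


# After this commit: scaled variants subclass a platform-free base instead.
class RotaryEmbeddingBaseSketch(CustomOpSketch):
    """Shared __init__ / cos-sin cache handling would live here; no fused forwards."""


class ScaledRopeSketchNew(RotaryEmbeddingBaseSketch):
    def forward_native(self, x: str) -> str:
        return f"YaRN-scaled rope math on {x}"

    def forward_cuda(self, x: str) -> str:
        # Pin CUDA explicitly to the native math, as the DeepSeek diff below does.
        return self.forward_native(x)


print(ScaledRopeSketchNew(platform="xpu").forward("q"))
# -> "YaRN-scaled rope math on q": no inherited fused XPU kernel to mis-dispatch to.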

4 files changed: +36 -30 lines changed

vllm/model_executor/layers/rotary_embedding/base.py

Lines changed: 16 additions & 1 deletion
@@ -14,7 +14,7 @@
 
 
 @CustomOp.register("rotary_embedding")
-class RotaryEmbedding(CustomOp):
+class RotaryEmbeddingBase(CustomOp):
     """Original rotary positional embedding."""
 
     def __init__(
@@ -86,6 +86,21 @@ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
         ):
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
 
+
+class RotaryEmbedding(RotaryEmbeddingBase):
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+    ) -> None:
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
     def forward_native(
         self,
         positions: torch.Tensor,

vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py

Lines changed: 10 additions & 4 deletions
@@ -7,7 +7,7 @@
 
 from vllm.platforms import current_platform
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import (
     rotate_gptj,
     rotate_neox,
@@ -22,7 +22,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
     """RotaryEmbedding extended with YaRN method.
 
     Credits to Peng et al. github.com/jquesnelle/yarn
@@ -146,5 +146,11 @@ def forward_native(
                 key = key_rot
         return query, key
 
-    forward_cuda = forward_native
-    forward_xpu = forward_native
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor | None = None,
+        offsets: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        return self.forward_native(positions, query, key, offsets)

vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py

Lines changed: 8 additions & 5 deletions
@@ -5,10 +5,10 @@
 
 import torch
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 
 
-class Llama4VisionRotaryEmbedding(RotaryEmbedding):
+class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
     def __init__(
         self,
         head_size: int,
@@ -72,6 +72,9 @@ def forward_native(  # type: ignore[override]
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
 
-    forward_cuda = forward_native
-    forward_hip = forward_native
-    forward_xpu = forward_native
+    def forward_cuda(  # type: ignore[override]
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        return self.forward_native(query, key)

vllm/model_executor/layers/rotary_embedding/mrope.py

Lines changed: 2 additions & 20 deletions
@@ -7,7 +7,7 @@
 
 from vllm.triton_utils import tl, triton
 
-from .base import RotaryEmbedding
+from .base import RotaryEmbeddingBase
 from .common import apply_rotary_emb_dispatch
 from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
 
@@ -199,7 +199,7 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
     return x_t
 
 
-class MRotaryEmbedding(RotaryEmbedding):
+class MRotaryEmbedding(RotaryEmbeddingBase):
     """Rotary Embedding with Multimodal Sections."""
 
     def __init__(
@@ -357,24 +357,6 @@ def forward_cuda(
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
-    def forward_xpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
-    def forward_cpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor | None = None,
-        offsets: torch.Tensor | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        return self.forward_native(positions, query, key, offsets)
-
     @staticmethod
     def get_next_input_positions(
         mrope_position_delta: int,
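
Net effect across the four diffs: the three rope variants touched here (DeepseekScalingRotaryEmbedding, Llama4VisionRotaryEmbedding, MRotaryEmbedding) now derive from RotaryEmbeddingBase rather than RotaryEmbedding, and the class-level platform aliases (forward_cuda/forward_hip/forward_xpu = forward_native) are either replaced by a single explicit forward_cuda that delegates to forward_native or dropped entirely. A summary skeleton of the resulting hierarchy, with bodies elided and a placeholder standing in for vLLM's CustomOp base:

class CustomOp:  # placeholder for vLLM's CustomOp base, elided in this sketch
    ...


class RotaryEmbeddingBase(CustomOp):
    """base.py: shared __init__ and cos_sin_cache handling; no platform forwards."""


class RotaryEmbedding(RotaryEmbeddingBase):
    """base.py: forward_native plus the platform-specific forwards."""


class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
    """deepseek_scaling_rope.py: explicit forward_cuda -> forward_native; forward_xpu alias removed."""


class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
    """llama4_vision_rope.py: explicit forward_cuda -> forward_native; hip/xpu aliases removed."""


class MRotaryEmbedding(RotaryEmbeddingBase):
    """mrope.py: keeps its own forward_cuda; forward_xpu/forward_cpu overrides removed."""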
