17 changes: 16 additions & 1 deletion vllm/model_executor/layers/rotary_embedding/base.py
@@ -14,7 +14,7 @@


@CustomOp.register("rotary_embedding")
class RotaryEmbedding(CustomOp):
class RotaryEmbeddingBase(CustomOp):
"""Original rotary positional embedding."""

def __init__(
@@ -86,6 +86,21 @@ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
):
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)


class RotaryEmbedding(RotaryEmbeddingBase):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: float,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)

def forward_native(
self,
positions: torch.Tensor,
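For orientation (not part of the diff): the hunks above split the old `RotaryEmbedding` into a shared `RotaryEmbeddingBase`, which keeps the constructor and cos/sin-cache handling, and a thin `RotaryEmbedding` subclass that retains the default forward implementations. Below is a minimal sketch of how a downstream variant would now build on the base class; the subclass name is hypothetical, and the constructor arguments simply mirror the signature shown in the hunk.

```python
# Sketch only: a hypothetical variant built on the refactored base class.
import torch

from vllm.model_executor.layers.rotary_embedding.base import RotaryEmbeddingBase


class MyRotaryEmbedding(RotaryEmbeddingBase):  # hypothetical name
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        # Shared setup, including the cos/sin cache, lives in RotaryEmbeddingBase.
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # A real variant would apply its own rotary math here (and may take
        # extra arguments such as offsets), reusing base-class helpers like
        # _match_cos_sin_cache_dtype.
        raise NotImplementedError
```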
@@ -7,7 +7,7 @@

from vllm.platforms import current_platform

from .base import RotaryEmbedding
from .base import RotaryEmbeddingBase
from .common import (
rotate_gptj,
rotate_neox,
@@ -22,7 +22,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
"""RotaryEmbedding extended with YaRN method.

Credits to Peng et al. github.com/jquesnelle/yarn
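The `yarn_get_mscale` helper touched in this hunk scales attention logarithmically with the YaRN context-extension factor. A quick standalone check of the visible formula; the `scale <= 1` early return is assumed from the standard YaRN formulation and is outside the lines shown:

```python
import math


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        # Assumed guard: no extra scaling when the context is not extended.
        return 1.0
    # Visible in the hunk above: logarithmic growth with the scale factor.
    return 0.1 * mscale * math.log(scale) + 1.0


print(yarn_get_mscale(4.0))   # ~1.1386
print(yarn_get_mscale(40.0))  # ~1.3689
```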
@@ -5,10 +5,10 @@

import torch

from .base import RotaryEmbedding
from .base import RotaryEmbeddingBase


class Llama4VisionRotaryEmbedding(RotaryEmbedding):
class Llama4VisionRotaryEmbedding(RotaryEmbeddingBase):
def __init__(
self,
head_size: int,
@@ -78,10 +78,3 @@ def forward_cuda( # type: ignore[override]
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(query, key)

def forward_hip( # type: ignore[override]
self,
query: torch.Tensor,
key: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(query, key)
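The deleted `forward_hip` override merely delegated to `forward_native`, exactly as the `forward_cuda` override above it still does. The removal relies on the `CustomOp` machinery already routing HIP (and other backends without a dedicated override) to an equivalent fallback. The sketch below illustrates that assumed default-dispatch pattern; it is not code from vLLM or from this PR, and the real `CustomOp` may differ.

```python
# Illustrative only: an assumed CustomOp-style fallback chain that would make
# the deleted forward_hip override redundant.
class CustomOpLike:
    def forward_native(self, *args, **kwargs):
        # Reference implementation, provided by subclasses.
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default: fall back to the reference implementation.
        return self.forward_native(*args, **kwargs)

    def forward_hip(self, *args, **kwargs):
        # Default: reuse the CUDA path, which itself falls back to
        # forward_native, so a per-class forward_hip that only delegates
        # adds nothing.
        return self.forward_cuda(*args, **kwargs)
```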
22 changes: 2 additions & 20 deletions vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -7,7 +7,7 @@

from vllm.triton_utils import tl, triton

from .base import RotaryEmbedding
from .base import RotaryEmbeddingBase
from .common import apply_rotary_emb_dispatch
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale

@@ -199,7 +199,7 @@ def apply_interleaved_rope(x: torch.Tensor, mrope_section: list[int]) -> torch.Tensor:
return x_t


class MRotaryEmbedding(RotaryEmbedding):
class MRotaryEmbedding(RotaryEmbeddingBase):
"""Rotary Embedding with Multimodal Sections."""

def __init__(
@@ -357,24 +357,6 @@ def forward_cuda(
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

def forward_xpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)

def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)

@staticmethod
def get_next_input_positions(
mrope_position_delta: int,
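The overrides removed from `MRotaryEmbedding` were likewise pure pass-throughs to `forward_native` with the same arguments. Under the assumption, as above, that the inherited CPU/XPU defaults delegate the same way, the deletion is behavior-preserving. A tiny standalone stand-in (not vLLM code) makes the point:

```python
# Illustrative stand-in: a pass-through override is indistinguishable from an
# inherited default that already delegates to forward_native.
class Base:
    def forward_native(self, positions, query, key=None, offsets=None):
        return query, key

    def forward_cpu(self, positions, query, key=None, offsets=None):
        # Assumed default, mirroring what the deleted override did explicitly.
        return self.forward_native(positions, query, key, offsets)


class WithOverride(Base):
    def forward_cpu(self, positions, query, key=None, offsets=None):
        # The deleted override: a pure pass-through.
        return self.forward_native(positions, query, key, offsets)


class WithoutOverride(Base):
    pass


# Both paths produce the same result, so removing the override changes nothing.
assert WithOverride().forward_cpu(0, "q") == WithoutOverride().forward_cpu(0, "q")
```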