From 3d217606b5170657ffcfdbac32399f344ccde5be Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Mon, 8 Sep 2025 15:51:56 +0300
Subject: [PATCH 1/4] [Bugfix] Fix platform-specific routing in CustomOp implementations

Signed-off-by: Konrad Zawora
---
 vllm/model_executor/layers/activation.py      |  5 +++-
 vllm/model_executor/layers/fused_moe/layer.py |  9 +++++++-
 .../rotary_embedding/deepseek_scaling_rope.py | 11 ++++++++-
 .../rotary_embedding/dual_chunk_rope.py       | 11 ++++++++-
 .../rotary_embedding/ernie45_vl_rope.py       | 10 +++++++-
 .../rotary_embedding/llama4_vision_rope.py    |  9 +++++++-
 .../layers/rotary_embedding/mrope.py          | 23 -------------------
 .../phi3_long_rope_scaled_rope.py             | 11 ++++++++-
 .../layers/vocab_parallel_embedding.py        |  2 +-
 9 files changed, 60 insertions(+), 31 deletions(-)

diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 319fa938d400..235df1a77c5c 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -454,7 +454,7 @@ def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
         )
         return result.view(original_shape)
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
         if self._xielu_cuda_obj is not None and input.is_cuda:
             if not torch._dynamo.is_compiling():
                 return self._xielu_cuda_fn(input)
@@ -464,6 +464,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             )
         return self._xielu_python(input)
 
+    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(input)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 272ad3956537..0a3cd1e84177 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1588,7 +1588,7 @@ def maybe_all_reduce_tensor_model_parallel(
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)
 
-    def forward(
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1622,6 +1622,13 @@ def forward(
         return (shared_output[..., :og_hidden_states],
                 fused_output[..., :og_hidden_states])
 
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.forward_native(hidden_states, router_logits)
+
     def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index cd888b733426..7ac2e4bb6c34 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -88,7 +88,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -129,3 +129,12 @@ def forward(
             query = query_rot
             key = key_rot
         return query, key
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key, offsets)
diff --git a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py
index 3d8da0fa9d8f..27e41dd0fa97 100644
--- a/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py
@@ -111,7 +111,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
                                      device=self.device)
         return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -161,6 +161,15 @@ def forward(
                           dim=-1)
         return query, key
 
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(positions, query, key, offsets)
+
     def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if self.is_neox_style:
diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
index 05322e56f262..2ca670b74e1c 100644
--- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
@@ -12,7 +12,7 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
     """3D rotary positional embedding.
     3D is t:time h:height w:width"""
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -70,3 +70,11 @@ def forward(
                                   self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key)
\ No newline at end of file
diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
index 415a85ab698b..e6821753aa9e 100644
--- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
@@ -53,7 +53,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
             torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         return cache
 
-    def forward(
+    def forward_native(
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
@@ -72,3 +72,10 @@ def forward(
         query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
+
+    def forward_cuda(
+        self,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(query, key)
\ No newline at end of file
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index 0ab4bc5375da..0acb5ea74245 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -8,7 +8,6 @@
 import torch
 from transformers import PretrainedConfig
 
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 
 from .base import RotaryEmbedding
@@ -202,28 +201,6 @@ def __init__(
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
-        self.use_triton = current_platform.is_cuda_alike()
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """MRope forward.
-
-        Args:
-            positions:
-                [num_tokens,] (text only) or
-                [3, num_tokens] (T/H/W positions with multimodal inputs)
-            query: [num_tokens, num_heads * head_size]
-            key: [num_tokens, num_kv_heads * head_size]
-        """
-        if self.use_triton:
-            return self.forward_cuda(positions, query, key)
-        else:
-            return self.forward_native(positions, query, key)
-
     def forward_native(
         self,
         positions: torch.Tensor,
diff --git a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
index 9c36d633e2a9..241b96774726 100644
--- a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
@@ -93,7 +93,7 @@ def _compute_cos_sin_cache(
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -127,3 +127,12 @@ def forward(
         key = torch.cat((key_rot, key_pass), dim=-1)
 
         return query.flatten(-2), key.flatten(-2)
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key)
\ No newline at end of file
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index c92a7978195b..fb3271db1623 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -399,7 +399,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
             param[loaded_weight.shape[0]:].data.fill_(0)
 
-    def forward(self, input_):
+    def forward_native(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(

From 1964d724d1addef5c69084bfa8f35eb91797d4e1 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Mon, 8 Sep 2025 15:56:43 +0300
Subject: [PATCH 2/4] missing forward_cuda in vocab parallel embedding

Signed-off-by: Konrad Zawora
---
 vllm/model_executor/layers/vocab_parallel_embedding.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index fb3271db1623..b882b1a9a4af 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -420,6 +420,9 @@ def forward_native(self, input_):
         output = tensor_model_parallel_all_reduce(output_parallel)
         return output
 
+    def forward_cuda(self, input_):
+        return self.forward_native(input_)
+
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"

From 8e1a529d07797abfcb9ce801de2cc3c8b2eeb1c2 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Mon, 8 Sep 2025 17:17:56 +0300
Subject: [PATCH 3/4] phi3 rope fix

Signed-off-by: Konrad Zawora
---
 .../rotary_embedding/phi3_long_rope_scaled_rope.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
index 241b96774726..9c36d633e2a9 100644
--- a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
@@ -93,7 +93,7 @@ def _compute_cos_sin_cache(
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward_native(
+    def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -127,12 +127,3 @@ def forward_native(
         key = torch.cat((key_rot, key_pass), dim=-1)
 
         return query.flatten(-2), key.flatten(-2)
-
-    def forward_cuda(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        return self.forward_native(positions, query, key)
\ No newline at end of file

From d9384c53224d227d007116c5db29a4f10951e0b3 Mon Sep 17 00:00:00 2001
From: Konrad Zawora
Date: Thu, 11 Sep 2025 15:24:05 +0300
Subject: [PATCH 4/4] ignore Liskov substitution principle

Signed-off-by: Konrad Zawora
---
 .../layers/rotary_embedding/ernie45_vl_rope.py    | 4 ++--
 .../layers/rotary_embedding/llama4_vision_rope.py | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
index 2ca670b74e1c..4960c20f4060 100644
--- a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py
@@ -12,7 +12,7 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
     """3D rotary positional embedding.
     3D is t:time h:height w:width"""
 
-    def forward_native(
+    def forward_native(  # type: ignore[override]
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -71,7 +71,7 @@ def forward_native(
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
-    def forward_cuda(
+    def forward_cuda(  # type: ignore[override]
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
diff --git a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
index e6821753aa9e..37ead43e22bc 100644
--- a/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py
@@ -53,7 +53,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
             torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         return cache
 
-    def forward_native(
+    def forward_native(  # type: ignore[override]
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
@@ -73,9 +73,9 @@ def forward_native(
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
 
-    def forward_cuda(
+    def forward_cuda(  # type: ignore[override]
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        return self.forward_native(query, key)
\ No newline at end of file
+        return self.forward_native(query, key)
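
Note (illustration, not part of the patches): the fix works because CustomOp-style layers route their public forward() to a platform-specific method, so an op that overrides forward() directly bypasses that routing on every platform. The sketch below is a minimal, hypothetical rendering of that dispatch pattern; the class and helper names (CustomOpSketch, SiluOpSketch, platform_is_cuda_like) are stand-ins and this is not vLLM's actual CustomOp implementation.

# Minimal sketch of the dispatch pattern the patches rely on (hypothetical names).
import torch
import torch.nn as nn
import torch.nn.functional as F


def platform_is_cuda_like() -> bool:
    # Hypothetical platform probe; real code would consult the runtime platform layer.
    return torch.cuda.is_available()


class CustomOpSketch(nn.Module):
    """Base op whose forward() is routed to a platform-specific implementation."""

    def __init__(self) -> None:
        super().__init__()
        # Choose the implementation once, when the op is constructed.
        self._forward_method = (self.forward_cuda if platform_is_cuda_like()
                                else self.forward_native)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        # Portable reference implementation; subclasses must provide it.
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default CUDA path: fall back to the portable implementation.
        return self.forward_native(*args, **kwargs)


class SiluOpSketch(CustomOpSketch):
    """Example op: implements forward_native/forward_cuda instead of forward()."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return F.silu(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op could call a fused CUDA kernel here; this sketch reuses the native path.
        return self.forward_native(x)


if __name__ == "__main__":
    op = SiluOpSketch()
    print(op(torch.randn(4)))  # Routed through the platform-selected method.

With dispatch like this in place, renaming an op's forward to forward_native and adding a forward_cuda delegate, as these patches do, keeps the op on the routed path instead of shadowing it.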