
Commit 4aa2389

[Bugfix] Fix platform-specific routing in CustomOp implementations (#24444)
Signed-off-by: Konrad Zawora <kzawora@habana.ai>
1 parent 1fdd5c4 commit 4aa2389

File tree

8 files changed (+53, -30 lines)
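
Context for the change (not part of the diff): layers that subclass vLLM's CustomOp are expected to expose platform-specific entry points such as forward_native and forward_cuda, and the base class routes forward() to the appropriate one for the current platform. The layers touched here overrode forward() directly, which bypassed that routing and prevented non-CUDA platforms from substituting their own implementations; the commit renames the overrides to forward_native and adds thin forward_cuda delegates. Below is a minimal, hypothetical sketch of that dispatch pattern under those assumptions; it is not the actual vLLM base class (the real CustomOp covers more platforms and op registration), and CustomOp, dispatch_forward, and SquareOp are illustrative names here.

import torch
import torch.nn as nn


class CustomOp(nn.Module):
    """Sketch of a platform-dispatching op (assumed interface, not vLLM's).

    Subclasses implement forward_native() and optionally forward_cuda(),
    forward_cpu(), etc. forward() always goes through the dispatcher, so
    overriding forward() directly would bypass platform routing -- which
    is the bug this commit fixes.
    """

    def __init__(self) -> None:
        super().__init__()
        # Pick the platform-specific implementation once, at build time.
        self._forward_method = self.dispatch_forward()

    def dispatch_forward(self):
        # Simplified: only distinguishes CUDA vs. everything else.
        if torch.cuda.is_available():
            return self.forward_cuda
        return self.forward_native

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default: CUDA falls back to the native implementation,
        # mirroring the thin delegates added in this commit.
        return self.forward_native(*args, **kwargs)


class SquareOp(CustomOp):
    """Toy subclass following the pattern used by the patched layers."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * x

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)


if __name__ == "__main__":
    op = SquareOp()
    print(op(torch.arange(4.0)))  # routed through the dispatcher

With routing handled in the base class, the manual use_triton branch removed from mrope.py further down becomes unnecessary: forward() already selects forward_cuda on CUDA-like platforms and forward_native elsewhere.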

vllm/model_executor/layers/activation.py

Lines changed: 4 additions & 1 deletion

@@ -454,7 +454,7 @@ def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
         )
         return result.view(original_shape)
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
         if self._xielu_cuda_obj is not None and input.is_cuda:
             if not torch._dynamo.is_compiling():
                 return self._xielu_cuda_fn(input)
@@ -464,6 +464,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 )
         return self._xielu_python(input)
 
+    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(input)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 8 additions & 1 deletion

@@ -1593,7 +1593,7 @@ def maybe_all_reduce_tensor_model_parallel(
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)
 
-    def forward(
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1627,6 +1627,13 @@ def forward(
         return (shared_output[..., :og_hidden_states],
                 fused_output[..., :og_hidden_states])
 
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.forward_native(hidden_states, router_logits)
+
     def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,

vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py

Lines changed: 10 additions & 1 deletion

@@ -88,7 +88,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         cache = torch.cat((cos, sin), dim=-1)
         return cache
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -129,3 +129,12 @@ def forward(
             query = query_rot
             key = key_rot
         return query, key
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key, offsets)

vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py

Lines changed: 10 additions & 1 deletion

@@ -111,7 +111,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
                              device=self.device)
         return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
 
-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -161,6 +161,15 @@ def forward(
                          dim=-1)
         return query, key
 
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(positions, query, key, offsets)
+
     def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if self.is_neox_style:

vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py

Lines changed: 9 additions & 1 deletion

@@ -12,7 +12,7 @@
 class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
     """3D rotary positional embedding. 3D is t:time h:height w:width"""
 
-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -70,3 +70,11 @@ def forward(
                                       self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key)

vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py

Lines changed: 8 additions & 1 deletion

@@ -53,7 +53,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
             torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         return cache
 
-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
@@ -72,3 +72,10 @@ def forward(
         query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(query, key)

vllm/model_executor/layers/rotary_embedding/mrope.py

Lines changed: 0 additions & 23 deletions

@@ -8,7 +8,6 @@
 import torch
 from transformers import PretrainedConfig
 
-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 
 from .base import RotaryEmbedding
@@ -202,28 +201,6 @@ def __init__(
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
-        self.use_triton = current_platform.is_cuda_alike()
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """MRope forward.
-
-        Args:
-            positions:
-                [num_tokens,] (text only) or
-                [3, num_tokens] (T/H/W positions with multimodal inputs)
-            query: [num_tokens, num_heads * head_size]
-            key: [num_tokens, num_kv_heads * head_size]
-        """
-        if self.use_triton:
-            return self.forward_cuda(positions, query, key)
-        else:
-            return self.forward_native(positions, query, key)
-
     def forward_native(
         self,
         positions: torch.Tensor,

vllm/model_executor/layers/vocab_parallel_embedding.py

Lines changed: 4 additions & 1 deletion

@@ -399,7 +399,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
             param[loaded_weight.shape[0]:].data.fill_(0)
 
-    def forward(self, input_):
+    def forward_native(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
@@ -420,6 +420,9 @@ def forward(self, input_):
             output = tensor_model_parallel_all_reduce(output_parallel)
         return output
 
+    def forward_cuda(self, input_):
+        return self.forward_native(input_)
+
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"
