vllm/model_executor/layers/activation.py (4 additions, 1 deletion)

@@ -454,7 +454,7 @@ def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
         )
         return result.view(original_shape)

-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
         if self._xielu_cuda_obj is not None and input.is_cuda:
             if not torch._dynamo.is_compiling():
                 return self._xielu_cuda_fn(input)
@@ -464,6 +464,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             )
         return self._xielu_python(input)

+    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(input)
+

 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
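The same refactor is applied to every layer below: the module's forward is renamed to forward_native, and a thin forward_cuda that simply delegates to it is added, presumably so that a CustomOp-style base class can own forward and dispatch it to the appropriate per-platform method (any base-class changes are outside this excerpt). A minimal sketch of that dispatch idea, assuming a simplified base class with no automatic CUDA-to-native fallback; DispatchedOp and ToyXIELU are illustrative names, not vLLM code:

# Hedged sketch: a simplified stand-in for a dispatching base class, not
# vLLM's actual CustomOp implementation.
import torch
import torch.nn as nn


class DispatchedOp(nn.Module):
    """Owns forward() and routes it to a platform-specific method."""

    def __init__(self):
        super().__init__()
        # Decide once, at construction time, which implementation to run.
        self._forward_method = (self.forward_cuda
                                if torch.cuda.is_available() else
                                self.forward_native)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # No automatic CUDA fallback in this sketch: each op supplies its own
        # forward_cuda, even if that is just a delegation to forward_native.
        raise NotImplementedError


class ToyXIELU(DispatchedOp):
    """Toy activation following the shape of the diffs in this PR."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder math; the real xIELU kernel lives in the vLLM source.
        return torch.nn.functional.silu(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Mirrors the stubs added in this diff: delegate to the native path.
        return self.forward_native(x)


y = ToyXIELU()(torch.randn(4, 8))  # callers never branch; they just call the module

Under this assumption, a CUDA machine picks forward_cuda and everything else falls back to forward_native, which is why a delegating stub is enough to keep behaviour unchanged while satisfying the dispatcher.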
vllm/model_executor/layers/fused_moe/layer.py (8 additions, 1 deletion)

@@ -1593,7 +1593,7 @@ def maybe_all_reduce_tensor_model_parallel(
         else:
             return tensor_model_parallel_all_reduce(final_hidden_states)

-    def forward(
+    def forward_native(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
@@ -1627,6 +1627,13 @@ def forward(
             return (shared_output[..., :og_hidden_states],
                     fused_output[..., :og_hidden_states])

+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        return self.forward_native(hidden_states, router_logits)
+
     def forward_impl_chunked(
         self,
         full_hidden_states: torch.Tensor,
(file name not captured)

@@ -88,7 +88,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         cache = torch.cat((cos, sin), dim=-1)
         return cache

-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -129,3 +129,12 @@ def forward(
             query = query_rot
             key = key_rot
         return query, key
+
+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key, offsets)
vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py (10 additions, 1 deletion)

@@ -111,7 +111,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
             device=self.device)
         return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache

-    def forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -161,6 +161,15 @@ def forward(
             dim=-1)
         return query, key

+    def forward_cuda(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_native(positions, query, key, offsets)
+
     def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
         cos, sin = cos_sin.chunk(2, dim=-1)
         if self.is_neox_style:
vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py (9 additions, 1 deletion)

@@ -12,7 +12,7 @@
 class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
     """3D rotary positional embedding. 3D is t:time h:height w:width"""

-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -70,3 +70,11 @@ def forward(
             self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(positions, query, key)
(file name not captured)

@@ -53,7 +53,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
             torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
         return cache

-    def forward(
+    def forward_native(  # type: ignore[override]
         self,
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
@@ -72,3 +72,10 @@ def forward(
         query_out = torch.view_as_real(query_ * freqs_ci).flatten(3)
         key_out = torch.view_as_real(key_ * freqs_ci).flatten(3)
         return query_out.type_as(query), key_out.type_as(key)
+
+    def forward_cuda(  # type: ignore[override]
+        self,
+        query: torch.Tensor,
+        key: Optional[torch.Tensor] = None,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        return self.forward_native(query, key)
vllm/model_executor/layers/rotary_embedding/mrope.py (0 additions, 23 deletions)

@@ -8,7 +8,6 @@
 import torch
 from transformers import PretrainedConfig

-from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton

 from .base import RotaryEmbedding
@@ -202,28 +201,6 @@ def __init__(
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2

-        self.use_triton = current_platform.is_cuda_alike()
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: Optional[torch.Tensor] = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """MRope forward.
-
-        Args:
-            positions:
-                [num_tokens,] (text only) or
-                [3, num_tokens] (T/H/W positions with multimodal inputs)
-            query: [num_tokens, num_heads * head_size]
-            key: [num_tokens, num_kv_heads * head_size]
-        """
-        if self.use_triton:
-            return self.forward_cuda(positions, query, key)
-        else:
-            return self.forward_native(positions, query, key)
-
     def forward_native(
         self,
         positions: torch.Tensor,
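MRotaryEmbedding previously made this choice by hand: forward branched on use_triton and called forward_cuda or forward_native on every call. With dispatch handled by the base class, that branch is redundant, so the method and the current_platform import are deleted rather than renamed. A small, self-contained sketch (toy classes with placeholder math, not vLLM's MRotaryEmbedding) of why the two styles select the same implementation:

# Hedged comparison sketch: per-call branching (the deleted forward) versus
# construction-time selection (the pattern the PR moves to). Illustrative only.
import torch
import torch.nn as nn

IS_CUDA_LIKE = torch.cuda.is_available()  # stand-in for current_platform.is_cuda_alike()


class PerCallDispatch(nn.Module):
    """Shape of the deleted forward(): choose a path on every call."""

    def forward(self, positions, query, key=None):
        if IS_CUDA_LIKE:
            return self.forward_cuda(positions, query, key)
        return self.forward_native(positions, query, key)

    def forward_native(self, positions, query, key=None):
        return query + 1, key  # placeholder for the portable rotary math

    def forward_cuda(self, positions, query, key=None):
        return self.forward_native(positions, query, key)  # placeholder for the Triton path


class ConstructionTimeDispatch(PerCallDispatch):
    """Choose the path once in __init__; reuses the toy math methods above."""

    def __init__(self):
        super().__init__()
        self._impl = self.forward_cuda if IS_CUDA_LIKE else self.forward_native

    def forward(self, positions, query, key=None):
        return self._impl(positions, query, key)


positions = torch.arange(4)
query, key = torch.randn(4, 8), torch.randn(4, 8)
q1, k1 = PerCallDispatch()(positions, query, key)
q2, k2 = ConstructionTimeDispatch()(positions, query, key)
assert torch.equal(q1, q2) and torch.equal(k1, k2)

Assuming the platform does not change at runtime, the two are equivalent; the decision simply moves from every call to construction time.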
vllm/model_executor/layers/vocab_parallel_embedding.py (4 additions, 1 deletion)

@@ -399,7 +399,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
         param[loaded_weight.shape[0]:].data.fill_(0)

-    def forward(self, input_):
+    def forward_native(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
@@ -420,6 +420,9 @@ def forward(self, input_):
         output = tensor_model_parallel_all_reduce(output_parallel)
         return output

+    def forward_cuda(self, input_):
+        return self.forward_native(input_)
+
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"