
Commit 474014a

Author: maleksan85
Commit message: Initial commit for V1 successful compilation
1 parent: caa2810

File tree (4 files changed: +6 additions, -3 deletions)

vllm/model_executor/layers/activation.py
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/linear.py
vllm/v1/attention/backends/rocm_attn.py

vllm/model_executor/layers/activation.py

Lines changed: 3 additions & 1 deletion

@@ -70,7 +70,9 @@ def __init__(self):
             from vllm._ipex_ops import ipex_ops
             self.op = ipex_ops.silu_and_mul
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self,
+                       x: torch.Tensor,
+                       scale: Optional[torch.Tensor] = None) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
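
The native body is unchanged, so the new scale argument is accepted but never read here; it appears to exist only to keep the signature aligned with a fused kernel that can quantize its own output. A minimal sketch of how such a scale could be consumed, assuming fp8 output quantization (the function name and the scale branch are illustrative, not from this commit):

import torch
import torch.nn.functional as F
from typing import Optional

def silu_and_mul_native(x: torch.Tensor,
                        scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Same math as the committed native path: split the last dimension
    # in half, apply SiLU to the first half, multiply by the second.
    d = x.shape[-1] // 2
    out = F.silu(x[..., :d]) * x[..., d:]
    # Hypothetical use of `scale`: divide by the scale and cast to an
    # fp8 dtype, mirroring what a fused fp8-output kernel would do.
    # The committed forward_native ignores `scale` entirely.
    if scale is not None:
        out = (out / scale).to(torch.float8_e4m3fn)
    return out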

vllm/model_executor/layers/layernorm.py

Lines changed: 1 addition & 0 deletions

@@ -39,6 +39,7 @@ def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
+        scale: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
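
As in activation.py, scale is threaded through forward_native without being used by the PyTorch-native math. For reference, a self-contained sketch of the RMSNorm native computation with the extended signature (written as a standalone function rather than the actual method; the scale parameter is carried for signature parity only):

import torch
from typing import Optional, Tuple, Union

def rms_norm_native(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    residual: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        # Fused residual add; the updated residual is returned
        # alongside the normalized output.
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    out = (x * weight.to(torch.float32)).to(orig_dtype)
    # `scale` is accepted for parity with a fused fp8 path; as in the
    # commit, the native computation does not consume it.
    if residual is None:
        return out
    return out, residual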

vllm/model_executor/layers/linear.py

Lines changed: 1 addition & 2 deletions

@@ -15,7 +15,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.tuned_gemm import tgemm
 # yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            BlockQuantScaleParameter,
@@ -138,7 +137,7 @@ def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return tgemm.mm(x, layer.weight, bias)
+        return torch.mm(x, torch.transpose(layer.weight, 0, 1))
 
 
 class LinearBase(torch.nn.Module):
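
One thing to watch: tgemm.mm received the bias argument, while the torch.mm replacement silently ignores it. Whether that matters depends on the call sites; if bias support is needed, F.linear reproduces the transposed matmul and keeps the bias path. A minimal sketch (not part of the commit):

import torch
import torch.nn.functional as F
from typing import Optional

def apply_unquantized(weight: torch.Tensor,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # F.linear computes x @ weight.T (+ bias), which matches
    # torch.mm(x, torch.transpose(weight, 0, 1)) when bias is None
    # and additionally handles the bias case.
    return F.linear(x, weight, bias)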

vllm/v1/attention/backends/rocm_attn.py

Lines changed: 1 addition & 0 deletions

@@ -109,6 +109,7 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
+        fp8_out_scale: Optional[torch.Tensor],
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
