From 474014a8ae53f0aa66e1afcc628ba4b69a623eb5 Mon Sep 17 00:00:00 2001
From: maleksan85
Date: Mon, 10 Mar 2025 18:39:30 +0000
Subject: [PATCH 1/4] Initial commit for V1 successful compilation

---
 vllm/model_executor/layers/activation.py | 4 +++-
 vllm/model_executor/layers/layernorm.py  | 1 +
 vllm/model_executor/layers/linear.py     | 3 +--
 vllm/v1/attention/backends/rocm_attn.py  | 1 +
 4 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 6087cd76de98..35b8123bd731 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -70,7 +70,9 @@ def __init__(self):
             from vllm._ipex_ops import ipex_ops
             self.op = ipex_ops.silu_and_mul
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self,
+                       x: torch.Tensor,
+                       scale: Optional[torch.Tensor] = None) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 93ea30938ff0..9654e811bc71 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -39,6 +39,7 @@ def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
+        scale: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 079f6184e94b..d43fb4fa882f 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -15,7 +15,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.tuned_gemm import tgemm
 # yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            BlockQuantScaleParameter,
@@ -138,7 +137,7 @@ def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return tgemm.mm(x, layer.weight, bias)
+        return torch.mm(x, torch.transpose(layer.weight, 0, 1))
 
 
 class LinearBase(torch.nn.Module):
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index a625d99f4a15..56798f09aa4b 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -109,6 +109,7 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
+        fp8_out_scale: Optional[torch.Tensor],
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
From 622f1c432a4b12853d180cb18cd487d1ffcfe9fc Mon Sep 17 00:00:00 2001
From: maleksan85
Date: Mon, 10 Mar 2025 19:29:48 +0000
Subject: [PATCH 2/4] Small improvement for linear

---
 vllm/model_executor/layers/linear.py     | 3 ++-
 vllm/model_executor/layers/tuned_gemm.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index d43fb4fa882f..079f6184e94b 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -15,6 +15,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.tuned_gemm import tgemm
 # yapf: disable
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            BlockQuantScaleParameter,
@@ -137,7 +138,7 @@ def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return torch.mm(x, torch.transpose(layer.weight, 0, 1))
+        return tgemm.mm(x, layer.weight, bias)
 
 
 class LinearBase(torch.nn.Module):
diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py
index ce3ab80985bd..7685b8f87f5a 100644
--- a/vllm/model_executor/layers/tuned_gemm.py
+++ b/vllm/model_executor/layers/tuned_gemm.py
@@ -8,12 +8,13 @@
 import torch.nn.functional as F
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
 from vllm.platforms import current_platform
 from vllm.utils import is_mi250, is_navi
 
 support_tuned_gemms = False
-if current_platform.is_rocm():
+if current_platform.is_rocm() and not envs.VLLM_USE_V1:
     import vllm._gradlib_C  # noqa: F401
     support_tuned_gemms = True
 

From 3d859889901a33203a5d2c9923ddc832e98dcd88 Mon Sep 17 00:00:00 2001
From: maleksan85
Date: Mon, 10 Mar 2025 20:03:47 +0000
Subject: [PATCH 3/4] Small improvement for linear

---
 vllm/model_executor/layers/tuned_gemm.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py
index 7685b8f87f5a..2f26bf5c365b 100644
--- a/vllm/model_executor/layers/tuned_gemm.py
+++ b/vllm/model_executor/layers/tuned_gemm.py
@@ -14,7 +14,7 @@
 from vllm.utils import is_mi250, is_navi
 
 support_tuned_gemms = False
-if current_platform.is_rocm() and not envs.VLLM_USE_V1:
+if current_platform.is_rocm():
     import vllm._gradlib_C  # noqa: F401
     support_tuned_gemms = True
 
@@ -69,6 +69,8 @@ def create_ds(self):
         self.solids = solds
 
     def query_sol(self, m, n, k, bias, dtype):
+        if envs.VLLM_USE_V1:
+            return 0, 0
         return self.solids.get((m, n, k, bias, str(dtype)), (0, 0))
 
     def apply_skinny(self, m, n, k, inp_view, weights):

From 75a2709d64a6bebd93b282903bf2d8ca1335ff58 Mon Sep 17 00:00:00 2001
From: maleksan85
Date: Mon, 10 Mar 2025 22:06:14 +0000
Subject: [PATCH 4/4] making use of forward_cuda for all except ROPE in llama

---
 vllm/config.py                           | 4 ++++
 vllm/model_executor/layers/activation.py | 4 +---
 vllm/model_executor/layers/layernorm.py  | 1 -
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 818138ee38ff..c0345710775c 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3436,6 +3436,10 @@ def __post_init__(self):
             # FIXME(woosuk): Disable inductor to reduce the compilation time
             # and avoid any potential issues with the inductor.
             self.compilation_config.custom_ops = ["none"]
+            if current_platform.is_rocm():
+                self.compilation_config.custom_ops = [
+                    "+rms_norm", "+silu_and_mul"
+                ]
             self.compilation_config.use_cudagraph = True
             self.compilation_config.use_inductor = True
             self.compilation_config.cudagraph_num_of_warmups = 1
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 35b8123bd731..6087cd76de98 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -70,9 +70,7 @@ def __init__(self):
             from vllm._ipex_ops import ipex_ops
             self.op = ipex_ops.silu_and_mul
 
-    def forward_native(self,
-                       x: torch.Tensor,
-                       scale: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 9654e811bc71..93ea30938ff0 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -39,7 +39,6 @@ def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-        scale: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype