diff --git a/vllm/config.py b/vllm/config.py
index 8a839bb6e5ce..d99971670ab7 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3450,6 +3450,10 @@ def __post_init__(self):
             # FIXME(woosuk): Disable inductor to reduce the compilation time
             # and avoid any potential issues with the inductor.
             self.compilation_config.custom_ops = ["none"]
+            if current_platform.is_rocm():
+                self.compilation_config.custom_ops = [
+                    "+rms_norm", "+silu_and_mul"
+                ]
             self.compilation_config.use_cudagraph = True
             self.compilation_config.use_inductor = True
             self.compilation_config.cudagraph_num_of_warmups = 1
diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py
index ce3ab80985bd..2f26bf5c365b 100644
--- a/vllm/model_executor/layers/tuned_gemm.py
+++ b/vllm/model_executor/layers/tuned_gemm.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
 from vllm.platforms import current_platform
 from vllm.utils import is_mi250, is_navi
@@ -68,6 +69,8 @@ def create_ds(self):
         self.solids = solds
 
     def query_sol(self, m, n, k, bias, dtype):
+        if envs.VLLM_USE_V1:
+            return 0, 0
         return self.solids.get((m, n, k, bias, str(dtype)), (0, 0))
 
     def apply_skinny(self, m, n, k, inp_view, weights):
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
index 640c3b3d4fbb..28093646e9fb 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -110,6 +110,7 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
+        fp8_out_scale: Optional[torch.Tensor],
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
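
Note on the config.py hunk: in compilation_config.custom_ops, "none" disables all custom ops and a "+" prefix force-enables a specific op, so on ROCm the patch keeps the custom rms_norm and silu_and_mul kernels while leaving the remaining ops to Inductor.

Note on the tuned_gemm.py hunk: (0, 0) is the existing "no tuned solution" sentinel returned when a shape is missing from the lookup table, so the early return under VLLM_USE_V1 disables tuned-GEMM dispatch entirely and every call falls through to the default path. Below is a minimal standalone sketch of that dispatch pattern; TunedGemmSketch and its methods' signatures are hypothetical names for illustration, not vLLM's actual API.

import torch
import torch.nn.functional as F

class TunedGemmSketch:
    """Illustrative lookup of pre-tuned GEMM 'solutions' keyed by shape."""

    def __init__(self, use_v1: bool):
        self.use_v1 = use_v1
        # Pretend this (m, n, k) shape has a pre-tuned kernel: (soltype, solidx).
        self.solids = {(128, 4096, 4096): (1, 42)}

    def query_sol(self, m: int, n: int, k: int):
        # Mirrors the patch: under the V1 engine, always report "no tuned
        # solution" so every call takes the default path.
        if self.use_v1:
            return 0, 0
        return self.solids.get((m, n, k), (0, 0))

    def mm(self, inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        m, k = inp.shape
        n = weight.shape[0]
        soltype, _solidx = self.query_sol(m, n, k)
        if soltype == 0:
            # Default path; the real code would dispatch to a tuned ROCm
            # kernel here when soltype is nonzero.
            return F.linear(inp, weight)
        raise NotImplementedError("tuned-kernel dispatch elided in this sketch")

gemm = TunedGemmSketch(use_v1=True)
out = gemm.mm(torch.randn(128, 4096), torch.randn(4096, 4096))  # default path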