
Commit 34dbe31

maleksan85 and a co-author authored
V1 rocm support (#469)
* Initial commit for V1 successful compilation
* Small improvement for linear
* Small improvement for linear
* making use of forward_cuda for all except ROPE in llama

Co-authored-by: maleksan85 <maleksan@amd.com>
1 parent 1095cff commit 34dbe31

File tree

3 files changed (+8, -0 lines):
  vllm/config.py
  vllm/model_executor/layers/tuned_gemm.py
  vllm/v1/attention/backends/rocm_attn.py

vllm/config.py

Lines changed: 4 additions & 0 deletions

@@ -3450,6 +3450,10 @@ def __post_init__(self):
             # FIXME(woosuk): Disable inductor to reduce the compilation time
             # and avoid any potential issues with the inductor.
             self.compilation_config.custom_ops = ["none"]
+            if current_platform.is_rocm():
+                self.compilation_config.custom_ops = [
+                    "+rms_norm", "+silu_and_mul"
+                ]
             self.compilation_config.use_cudagraph = True
             self.compilation_config.use_inductor = True
             self.compilation_config.cudagraph_num_of_warmups = 1
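For context: on the V1 engine path this block normally disables every custom op ("none") so Inductor compiles plain PyTorch implementations; the change overrides that on ROCm so the fused rms_norm and silu_and_mul kernels stay enabled (in vLLM's custom_ops syntax, a leading "+" force-enables an op). A minimal sketch of the resulting selection logic, using a hypothetical select_custom_ops helper rather than vLLM's actual __post_init__:

# Hypothetical helper illustrating the platform-dependent op selection above;
# `is_rocm` stands in for current_platform.is_rocm().
def select_custom_ops(is_rocm: bool) -> list:
    custom_ops = ["none"]  # V1 default: disable all custom ops for Inductor
    if is_rocm:
        # Keep only the fused RMSNorm and SiLU-and-mul custom kernels enabled.
        custom_ops = ["+rms_norm", "+silu_and_mul"]
    return custom_ops

assert select_custom_ops(is_rocm=False) == ["none"]
assert select_custom_ops(is_rocm=True) == ["+rms_norm", "+silu_and_mul"]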

vllm/model_executor/layers/tuned_gemm.py

Lines changed: 3 additions & 0 deletions

@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
 from vllm.platforms import current_platform
 from vllm.utils import is_mi250, is_navi

@@ -68,6 +69,8 @@ def create_ds(self):
         self.solids = solds
 
     def query_sol(self, m, n, k, bias, dtype):
+        if envs.VLLM_USE_V1:
+            return 0, 0
         return self.solids.get((m, n, k, bias, str(dtype)), (0, 0))
 
     def apply_skinny(self, m, n, k, inp_view, weights):
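In other words, when the V1 engine is active query_sol now returns the same (0, 0) sentinel that the lookup returns on a table miss, so the tuned per-shape GEMM solution lookup is skipped entirely and callers fall back to the untuned path. A standalone sketch of that behavior (a toy stand-in, not the real tuned_gemm class; use_v1 stands in for envs.VLLM_USE_V1):

class SolutionTableSketch:
    """Toy stand-in for the tuned-GEMM solution table lookup."""

    def __init__(self, solids: dict, use_v1: bool):
        self.solids = solids  # {(m, n, k, bias, dtype_str): solution tuple}
        self.use_v1 = use_v1  # stands in for envs.VLLM_USE_V1

    def query_sol(self, m, n, k, bias, dtype):
        if self.use_v1:
            return 0, 0  # same sentinel as a table miss: no tuned solution
        return self.solids.get((m, n, k, bias, str(dtype)), (0, 0))

table = SolutionTableSketch(
    solids={(32, 4096, 4096, False, "torch.float16"): (1, 7)}, use_v1=True)
print(table.query_sol(32, 4096, 4096, False, "torch.float16"))  # (0, 0) under V1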

vllm/v1/attention/backends/rocm_attn.py

Lines changed: 1 addition & 0 deletions

@@ -110,6 +110,7 @@ def forward(
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
+        fp8_out_scale: Optional[torch.Tensor],
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
