Commit bc05166

Triton MLA parameter tweak for AMD GPU
1 parent: be404be

File tree

1 file changed (+10, -5 lines)


vllm/attention/ops/triton_decode_attention.py

Lines changed: 10 additions & 5 deletions
@@ -178,7 +178,8 @@ def _decode_att_m_fwd(
     page_size,
     logit_cap,
 ):
-    BLOCK = 64
+    BLOCK = 64 if not is_hip_ else 8
+
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
@@ -188,7 +189,9 @@ def _decode_att_m_fwd(
     grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[-2]
 
-    num_warps = 4 if kv_group_num == 1 else 2
+    num_warps = 4
+    if kv_group_num != 1:
+        num_warps = 1 if is_hip_ else 2
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
@@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd(
     )
 
     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
-        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {
-            "waves_per_eu": 4,
+            "waves_per_eu": 1,
             "matrix_instr_nonkdim": 16,
             "kpack": 2
         }
+        num_stages = 1
 
     _fwd_grouped_kernel_stage1[grid](
         q,
@@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd(
         PAGE_SIZE=page_size,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=2,
+        num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
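
Taken together, the commit routes every backend-dependent launch parameter (BLOCK, num_warps, num_stages, and the HIP-only extra_kargs) through is_hip_ checks. Below is a minimal, runnable sketch of that selection logic for illustration only; the helper name select_decode_launch_params is hypothetical and not part of vLLM, but the values mirror the ones chosen in this diff.

def select_decode_launch_params(kv_group_num: int, is_hip_: bool):
    """Pick KV block size, warps, pipeline stages, and AMD compiler hints."""
    # Smaller KV block on ROCm: 8 instead of 64 (per _decode_att_m_fwd).
    block = 8 if is_hip_ else 64

    # Grouped KV heads use fewer warps, and fewer still on ROCm.
    num_warps = 4
    if kv_group_num != 1:
        num_warps = 1 if is_hip_ else 2

    # Shallower software pipelining plus AMD-only kernel kwargs
    # (per _decode_grouped_att_m_fwd; see the ROCm docs linked in the diff).
    num_stages = 2
    extra_kargs = {}
    if is_hip_:
        num_stages = 1
        extra_kargs = {
            "waves_per_eu": 1,           # occupancy hint: waves per execution unit
            "matrix_instr_nonkdim": 16,  # MFMA instruction shape hint
            "kpack": 2,
        }
    return block, num_warps, num_stages, extra_kargs


# Example: grouped KV heads on a ROCm (HIP) device.
print(select_decode_launch_params(kv_group_num=8, is_hip_=True))
# -> (8, 1, 1, {'waves_per_eu': 1, 'matrix_instr_nonkdim': 16, 'kpack': 2})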
