Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions vllm/attention/ops/triton_decode_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def _decode_att_m_fwd(
page_size,
logit_cap,
):
BLOCK = 64
BLOCK = 64 if not is_hip_ else 8

NUM_KV_SPLITS = num_kv_splits
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
Expand All @@ -188,7 +189,9 @@ def _decode_att_m_fwd(
grid = (batch, head_num, NUM_KV_SPLITS)
kv_group_num = q.shape[1] // k_buffer.shape[-2]

num_warps = 4 if kv_group_num == 1 else 2
num_warps = 4
if kv_group_num != 1:
num_warps = 1 if is_hip_ else 2

BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DV = triton.next_power_of_2(Lv)
Expand Down Expand Up @@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd(
)

extra_kargs = {}
num_stages = 2
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {
"waves_per_eu": 4,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
num_stages = 1

_fwd_grouped_kernel_stage1[grid](
q,
Expand Down Expand Up @@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd(
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=4,
num_stages=2,
num_stages=num_stages,
Lk=Lk,
Lv=Lv,
**extra_kargs,
Expand Down