[KVCache][Test] Fix TIR attn kernels for uncommon group size #17074

Merged
@@ -1181,8 +1181,8 @@ def batch_prefill_paged_kv(

if T.tvm_thread_invariant(batch_idx[0] < batch_size):
b_idx: T.int32 = batch_idx[0]
L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
H_qo_start: T.int32 = by * group_size
LH_start: T.int32 = tile_id[0] * tile_x
q_indptr_val: T.int32 = q_indptr[b_idx]

cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
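
For context, the paged variant resolves each key/value position through a page table: sequence `b_idx` owns the physical pages listed between `page_indptr[b_idx]` and `page_indptr[b_idx + 1]`. The sketch below is plain Python with made-up page sizes and page ids (not code from this PR); only the `page_indptr[b]:page_indptr[b+1]` slicing pattern mirrors the kernel above.

```python
# Hedged sketch of paged-KV addressing; page_size, page_indptr and page_values
# are hypothetical values chosen for illustration.
page_size = 16
page_indptr = [0, 3, 7]               # sequence 0 owns 3 pages, sequence 1 owns 4
page_values = [5, 2, 9, 0, 4, 8, 1]   # physical page ids, in logical order

def locate(b_idx, pos):
    """Map a logical KV position of sequence b_idx to (physical page, in-page offset)."""
    pages = page_values[page_indptr[b_idx]:page_indptr[b_idx + 1]]
    return pages[pos // page_size], pos % page_size

print(locate(0, 20))  # -> (2, 4): logical position 20 lives in physical page 2, slot 4
```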
@@ -1212,8 +1212,8 @@ def batch_prefill_paged_kv(
i, j = T.axis.remap("SS", [li, lj])
T.reads()
T.writes()
cur_L = L_start + i // group_size
cur_H_qo = H_qo_start + i % group_size
cur_L = q_indptr_val + (LH_start + i) // group_size
cur_H_qo = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
Q_smem[i, j] = T.if_then_else(
rotary_mode == 1,
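
The heart of the fix is visible in the two hunks above: the old code derived the query position from `L_start = q_indptr[b_idx] + tile_id[0] * L_per_cta` with `L_per_cta = tile_x // group_size`, which can leave (position, head) pairs uncovered when `group_size` does not divide `tile_x` and degenerates entirely when `group_size > tile_x` (then `L_per_cta == 0`). The new code flattens the (position, head) pair into a single index `LH_start + i` and splits it with `//` and `%`. The standalone sketch below is plain Python, works in sequence-local coordinates (the `q_indptr[b_idx]` base is dropped), and assumes each sequence is given `ceil(qo_len * group_size / tile_x)` tiles; under that assumption the old mapping misses some pairs while the new one covers them all.

```python
# Compare the old and new ways of mapping a tile-local index i in [0, tile_x)
# to a (query position, head) pair, for a group size that does not divide tile_x.
tile_x, group_size, qo_len = 32, 7, 13
num_tiles = (qo_len * group_size + tile_x - 1) // tile_x  # ceildiv over flattened work

def old_pairs(tile_id):
    # old: cur_L = L_start + i // group_size, cur_H_qo = H_qo_start + i % group_size
    L_per_cta = tile_x // group_size
    return {(tile_id * L_per_cta + i // group_size, i % group_size)
            for i in range(tile_x)}

def new_pairs(tile_id):
    # new: flatten (L, H) as L * group_size + H, then split with // and %
    LH_start = tile_id * tile_x
    return {((LH_start + i) // group_size, (LH_start + i) % group_size)
            for i in range(tile_x)}

want = {(l, h) for l in range(qo_len) for h in range(group_size)}
old_cover = set().union(*(old_pairs(t) for t in range(num_tiles)))
new_cover = set().union(*(new_pairs(t) for t in range(num_tiles)))
print(sorted(want - old_cover))  # [(12, 4), (12, 5), (12, 6)] never get computed
print(sorted(want - new_cover))  # [] -- every pair is produced by some tile
```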
@@ -1282,9 +1282,10 @@ def batch_prefill_paged_kv(
m_prev[i] = m_smem[row]
m_new[i] = m_smem[row]
# mask out of kv_chunk_len S
row_: T.int32 = (LH_start + row) // group_size
for j in T.serial(tile_z):
if _causal_mask(causal,
row=tile_id[0] * L_per_cta + row // group_size,
row=row_,
col=L_kv_start + j,
kv_len=kv_chunk_len[0],
qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1297,8 +1298,9 @@
for j in T.serial(tile_z):
# this is to avoid sync inside condition branch
if row < tile_x:
row_: T.int32 = (LH_start + row) // group_size
if _causal_mask(causal,
row=tile_id[0] * L_per_cta + row // group_size,
row=row_,
col=L_kv_start + j,
kv_len=kv_chunk_len[0],
qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
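
Both masking passes now feed `_causal_mask` with `row_ = (LH_start + row) // group_size`, i.e. the 0-based query position inside the current sequence, instead of an offset built from `L_per_cta`. The helper itself is not part of this diff; the plain-Python analogue below assumes the usual prefill convention that query position `row` may attend to keys `col < kv_len - qo_len + row + 1` when causal masking is on.

```python
# Plain-Python analogue of the causal check; the mask formula is an assumption
# about _causal_mask (not shown in this diff), the row_ computation mirrors the kernel.
def causal_mask(causal, row, col, kv_len, qo_len):
    return col < kv_len - qo_len + row + 1 if causal else col < kv_len

LH_start, group_size, kv_len, qo_len = 32, 7, 40, 13   # hypothetical tile/sequence sizes
for row in (0, 31):                                    # first and last row of one tile
    row_ = (LH_start + row) // group_size              # query position for this tile row
    visible = sum(causal_mask(True, row_, col, kv_len, qo_len) for col in range(kv_len))
    print(row_, visible)  # position 4 sees 32 keys, position 9 sees 37
```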
@@ -1330,15 +1332,19 @@ def batch_prefill_paged_kv(
for li, lj in T.grid(tile_x, tile_y):
with T.block("O_store"):
i, j = T.axis.remap("SS", [li, lj])
if L_start + i // group_size < q_indptr[b_idx + 1]:
output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]

# Store LSE to gmem
for li in T.grid(tile_x):
with T.block("lse_store"):
i = T.axis.remap("S", [li])
if L_start + i // group_size < q_indptr[b_idx + 1]:
lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])

# move to next tile
tile_id[0] += NUM_BLKS
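
The store blocks keep the same epilogue as before, just with the recomputed indices and an explicit bound check: `output` gets the accumulator divided by the running denominator `d_smem`, and `lse` gets `m_smem + log2(d_smem)`. Since the kernel works in base 2 (`T.exp2`), that quantity is the base-2 log-sum-exp of the row's scores. The small numeric check below (plain Python, made-up scores) illustrates the identity being relied on; it is not code from the PR.

```python
# Base-2 online-softmax identity behind the store blocks: with m the running
# maximum and d the running sum of 2**(s - m), lse = m + log2(d) and the
# normalized weights 2**(s - lse) sum to 1.
import math

scores = [1.5, -0.25, 3.0, 0.75]          # hypothetical attention scores
m = max(scores)
d = sum(2.0 ** (s - m) for s in scores)
lse = m + math.log2(d)
weights = [2.0 ** (s - lse) for s in scores]
print(round(sum(weights), 12))            # 1.0
```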
@@ -1688,7 +1694,6 @@ def _attention_prefill_ragged(
bdx = 32
num_warps = 4
tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
L_per_cta = tile_x // group_size

# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
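
With the flattened indexing, `L_per_cta = tile_x // group_size` has no remaining users in the rewritten kernels and is deleted here. The tile shape itself is unchanged: `tile_x` depends only on the element width and the head dimension, so whether `group_size` divides it is essentially accidental. A quick evaluation of the formula above (plain Python, a few hypothetical configurations) shows where the "uncommon group size" cases come from.

```python
# Evaluate tile_x = 64 // ((dtype_bits + 7) // 8) // max(d // 128, 1) for a few
# hypothetical (dtype_bits, head_dim, group_size) combinations and check divisibility.
for dtype_bits, d, group_size in [(16, 128, 4), (16, 128, 7), (32, 128, 6)]:
    tile_x = 64 // ((dtype_bits + 7) // 8) // max(d // 128, 1)
    print(dtype_bits, d, group_size, tile_x, tile_x % group_size == 0)
# fp16, d=128 -> tile_x = 32: group size 4 divides it, 7 does not;
# fp32, d=128 -> tile_x = 16: group size 6 does not divide it.
```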
@@ -1784,8 +1789,8 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches

if T.tvm_thread_invariant(batch_idx[0] < batch_size):
b_idx: T.int32 = batch_idx[0]
L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
H_qo_start: T.int32 = by * group_size
LH_start: T.int32 = tile_id[0] * tile_x
q_indptr_val: T.int32 = q_indptr[b_idx]

kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
T.tvm_storage_sync("shared")
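
The ragged variant receives the same index rewrite. Here both queries and keys live in packed, variable-length tensors delimited by CSR-style `indptr` arrays, which is where expressions such as `q_indptr[b_idx + 1] - q_indptr[b_idx]` and `kv_indptr[b_idx + 1] - kv_indptr[b_idx]` come from. A tiny illustration with hypothetical lengths (not data from the PR):

```python
# CSR-style indptr construction; the lengths are made up, the slicing pattern
# matches how the ragged kernel derives qo_len and kv_chunk_len above.
def indptr(lengths):
    out = [0]
    for n in lengths:
        out.append(out[-1] + n)
    return out

q_indptr = indptr([3, 13, 1])        # per-sequence query lengths
kv_indptr = indptr([8, 40, 5])       # per-sequence KV lengths
b_idx = 1
qo_len = q_indptr[b_idx + 1] - q_indptr[b_idx]          # 13
kv_chunk_len = kv_indptr[b_idx + 1] - kv_indptr[b_idx]  # 40
print(q_indptr, kv_indptr, qo_len, kv_chunk_len)
```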
@@ -1809,8 +1814,8 @@
i, j = T.axis.remap("SS", [li, lj])
T.reads()
T.writes()
cur_L = L_start + i // group_size
cur_H_qo = H_qo_start + i % group_size
cur_L = q_indptr_val + (LH_start + i) // group_size
cur_H_qo = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
Q_smem[i, j] = T.if_then_else(
rotary_mode == 1,
@@ -1874,9 +1879,10 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches
m_prev[i] = m_smem[row]
m_new[i] = m_smem[row]
# mask out of kv_chunk_len S
row_: T.int32 = (LH_start + row) // group_size
for j in T.serial(tile_z):
if _causal_mask(causal,
row=tile_id[0] * L_per_cta + row // group_size,
row=row_,
col=L_kv_start + j,
kv_len=kv_chunk_len[0],
qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1889,8 +1895,9 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches
for j in T.serial(tile_z):
# this is to avoid sync inside condition branch
if row < tile_x:
row_: T.int32 = (LH_start + row) // group_size
if _causal_mask(causal,
row=tile_id[0] * L_per_cta + row // group_size,
row=row_,
col=L_kv_start + j,
kv_len=kv_chunk_len[0],
qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1922,15 +1929,19 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches
for li, lj in T.grid(tile_x, tile_y):
with T.block("O_store"):
i, j = T.axis.remap("SS", [li, lj])
if L_start + i // group_size < q_indptr[b_idx + 1]:
output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]

# Store LSE to gmem
for li in T.grid(tile_x):
with T.block("lse_store"):
i = T.axis.remap("S", [li])
if L_start + i // group_size < q_indptr[b_idx + 1]:
lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])

# move to next tile
tile_id[0] += NUM_BLKS
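
The `tile_id[0] += NUM_BLKS` step at the end of each kernel is a grid-stride loop over the flattened tile space; the surrounding logic that maps a global tile id back to a `(batch_idx, tile_id)` pair is not part of this diff. The minimal sketch below assumes each block starts at its own block index and shows that the per-block tile sets are disjoint and jointly exhaustive.

```python
# Grid-stride sketch (plain Python, hypothetical totals): block bx handles
# tiles bx, bx + NUM_BLKS, bx + 2 * NUM_BLKS, ...
NUM_BLKS, total_tiles = 4, 11
per_block = [list(range(bx, total_tiles, NUM_BLKS)) for bx in range(NUM_BLKS)]
assert sorted(t for tiles in per_block for t in tiles) == list(range(total_tiles))
print(per_block)  # [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7]]
```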
@@ -2122,8 +2133,8 @@ def batch_tree_attn( # pylint: disable=too-many-branches

if T.tvm_thread_invariant(batch_idx[0] < batch_size):
b_idx: T.int32 = batch_idx[0]
L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
H_qo_start: T.int32 = by * group_size
LH_start: T.int32 = tile_id[0] * tile_x
q_indptr_val: T.int32 = q_indptr[b_idx]

kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
T.tvm_storage_sync("shared")
@@ -2147,8 +2158,8 @@ def batch_tree_attn( # pylint: disable=too-many-branches
i, j = T.axis.remap("SS", [li, lj])
T.reads()
T.writes()
cur_L = L_start + i // group_size
cur_H_qo = H_qo_start + i % group_size
cur_L = q_indptr_val + (LH_start + i) // group_size
cur_H_qo = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
Q_smem[i, j] = T.if_then_else(
rotary_mode == 1,
@@ -2203,13 +2214,15 @@ def batch_tree_attn( # pylint: disable=too-many-branches
m_prev[i] = m_smem[row]
m_new[i] = m_smem[row]
# mask out of kv_chunk_len S
row_: T.int32 = (LH_start + row) // group_size
for j in T.serial(tile_z):
if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
col=L_kv_start + j,
mask_ptr=mask,
offset=mn_indptr[b_idx],
stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
kv_len=kv_chunk_len[0]):
if _tree_mask(
row=row_,
col=L_kv_start + j,
mask_ptr=mask,
offset=mn_indptr[b_idx],
stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
kv_len=kv_chunk_len[0]):
m_new[i] = T.max(m_new[i], S_smem[row, j])
d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
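
The tree-attention kernel gets the same `row_` fix, but the mask is data-driven: `_tree_mask` (not shown in this diff) reads a packed per-sequence mask at `offset=mn_indptr[b_idx]` with `stride` equal to the sequence's query length. The exact packing is defined by that helper; the sketch below only illustrates one plausible layout, a row-major `qo_len x qo_len` boolean block per sequence, purely as an assumption for illustration.

```python
# Hypothetical packing of per-sequence tree masks (an assumption, not the
# kernel's actual layout): sequence b stores a row-major qo_len x qo_len
# boolean block starting at mn_indptr[b], indexed with stride = qo_len.
qo_lens = [2, 3]                              # made-up per-sequence query lengths
mn_indptr = [0]
for n in qo_lens:
    mn_indptr.append(mn_indptr[-1] + n * n)   # cumulative mask sizes

mask = [0] * mn_indptr[-1]

def set_visible(b, row, col, value=1):
    stride = qo_lens[b]
    mask[mn_indptr[b] + row * stride + col] = value

set_visible(1, 2, 0)                          # token 2 of sequence 1 may see token 0
print(mn_indptr, mask)                        # [0, 4, 13] and the packed bits
```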

@@ -2219,12 +2232,14 @@
for j in T.serial(tile_z):
# this is to avoid sync inside condition branch
if row < tile_x:
if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
col=L_kv_start + j,
mask_ptr=mask,
offset=mn_indptr[b_idx],
stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
kv_len=kv_chunk_len[0]):
row_: T.int32 = (LH_start + row) // group_size
if _tree_mask(
row=row_,
col=L_kv_start + j,
mask_ptr=mask,
offset=mn_indptr[b_idx],
stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
kv_len=kv_chunk_len[0]):
S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
else:
S_smem[row, j] = T.exp2(-5e4 - m_new[i])
@@ -2253,15 +2268,19 @@ def batch_tree_attn( # pylint: disable=too-many-branches
for li, lj in T.grid(tile_x, tile_y):
with T.block("O_store"):
i, j = T.axis.remap("SS", [li, lj])
if L_start + i // group_size < q_indptr[b_idx + 1]:
output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]

# Store LSE to gmem
for li in T.grid(tile_x):
with T.block("lse_store"):
i = T.axis.remap("S", [li])
if L_start + i // group_size < q_indptr[b_idx + 1]:
lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
if cur_L < q_indptr[b_idx + 1]:
lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])

# move to next tile
tile_id[0] += NUM_BLKS