
Commit 4ee7b24

Merge remote-tracking branch 'origin/main' into fix-int64
2 parents: dbebd7e + 716dbef

35 files changed: +2774 -628 lines

3rdparty/composable_kernel

Submodule composable_kernel updated 3710 files

examples/deepseek_nsa/example_tilelang_nsa_bwd.py

Lines changed: 12 additions & 12 deletions
@@ -106,8 +106,8 @@ def native_sparse_attention(
             T.copy(K[i_b, i_s:i_s + BS, i_h, :], K_shared)
 
             if is_causal:
-                for i, j in T.Parallel(G, BS):
-                    acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0,
+                for k, j in T.Parallel(G, BS):
+                    acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0,
                                                  -T.infinity(acc_s.dtype))
             else:
                 T.clear(acc_s)
@@ -124,18 +124,18 @@ def native_sparse_attention(
             T.copy(scores_max, scores_max_prev)
             T.fill(scores_max, -T.infinity(accum_dtype))
             T.reduce_max(acc_s, scores_max, dim=1, clear=True)
-            for i in T.Parallel(G):
-                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
-            for i, j in T.Parallel(G, BS):
-                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
+            for k in T.Parallel(G):
+                scores_scale[k] = T.exp2(scores_max_prev[k] * scale - scores_max[k] * scale)
+            for k, j in T.Parallel(G, BS):
+                acc_s[k, j] = T.exp2(acc_s[k, j] * scale - scores_max[k] * scale)
             T.reduce_sum(acc_s, scores_sum, dim=1)
-            for i in T.Parallel(G):
-                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+            for k in T.Parallel(G):
+                logsum[k] = logsum[k] * scores_scale[k] + scores_sum[k]
             T.copy(acc_s, acc_s_cast)
 
             # Rescale
-            for i, j in T.Parallel(G, BV):
-                acc_o[i, j] *= scores_scale[i]
+            for k, j in T.Parallel(G, BV):
+                acc_o[k, j] *= scores_scale[k]
 
             # V * softmax(Q * K)
             T.copy(V[i_b, i_s:i_s + BS, i_h, i_v * BV:(i_v + 1) * BV], V_shared)
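
The rescaling above is a standard online-softmax update step. As a reference, a minimal NumPy sketch of the same update (hypothetical helper; scale is assumed to already fold in sm_scale * log2(e), so exp2 stands in for exp):

    import numpy as np

    # One online-softmax step, mirroring the hunk above (hypothetical helper).
    def online_softmax_step(acc_s, scores_max_prev, logsum, acc_o, scale):
        scores_max = acc_s.max(axis=1)                          # T.reduce_max
        scores_scale = np.exp2(scores_max_prev * scale - scores_max * scale)
        p = np.exp2(acc_s * scale - scores_max[:, None] * scale)
        logsum = logsum * scores_scale + p.sum(axis=1)          # running denominator
        acc_o = acc_o * scores_scale[:, None]                   # rescale partial output
        return p, scores_max, logsum, acc_o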
@@ -465,8 +465,8 @@ def flash_bwd_dqkv(
             T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
             # [G]
             T.copy(Delta_slc[i_b, i, i_h * G:(i_h + 1) * G], delta)
-            for i, j in T.Parallel(BS, G):
-                dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
+            for _i, _j in T.Parallel(BS, G):
+                dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale
 
             # [BS, G] @ [G, BK] -> [BS, BK]
             T.gemm(dsT_cast, Q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
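
The index renames in this file (i to k, and i, j to _i, _j) look cosmetic, but the hunk above shows the likely motivation: the enclosing i is still live (Delta_slc[i_b, i, ...]), so reusing it as a T.Parallel index would shadow it. A minimal plain-Python illustration of the hazard (not TileLang):

    # Reusing `i` as an inner loop index clobbers the outer `i`
    # that later statements still rely on.
    outer = []
    for i in range(2):        # outer index, needed after the inner loop
        for i in range(3):    # shadows the outer i
            pass
        outer.append(i)       # appends 2, 2 instead of 0, 1
    assert outer == [2, 2]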

examples/flash_attention/example_gqa_bwd.py

Lines changed: 3 additions & 1 deletion
@@ -54,7 +54,9 @@ def flash_fwd(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                             -T.infinity(acc_s.dtype), 0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
         T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
         T.copy(scores_max, scores_max_prev)
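
This same change recurs across the flash-attention examples below, so the rationale is sketched once here. When seq_len is not a multiple of block_N, the final K-tile has out-of-bounds columns; initialising them to 0 via T.clear lets them leak exp(0) weight into the softmax denominator, while initialising them to -inf removes them. A NumPy sketch of the effect (hypothetical sizes):

    import numpy as np

    block_N, seq_len, k = 4, 6, 1     # final K-tile holds 2 out-of-bounds keys
    scores = np.random.default_rng(0).normal(size=(2, block_N)).astype(np.float32)
    oob = k * block_N + np.arange(block_N) >= seq_len

    def softmax(x):
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    bad = scores.copy();  bad[:, oob] = 0.0       # old init: zeros
    good = scores.copy(); good[:, oob] = -np.inf  # new init: -inf

    print(softmax(bad)[:, oob].sum())   # > 0: padded keys receive weight
    print(softmax(good)[:, oob].sum())  # 0.0: padded keys are ignored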

examples/flash_attention/example_gqa_bwd_tma_reduce.py

Lines changed: 3 additions & 1 deletion
@@ -59,7 +59,9 @@ def flash_fwd(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              T.Cast(accum_dtype, -1e30))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                             -T.infinity(acc_s.dtype), 0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
         T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
         T.copy(scores_max, scores_max_prev)

examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py

Lines changed: 3 additions & 1 deletion
@@ -54,7 +54,9 @@ def flash_fwd(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                             -T.infinity(acc_s.dtype), 0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
         T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
         T.copy(scores_max, scores_max_prev)

examples/flash_attention/example_gqa_fwd_bshd.py

Lines changed: 3 additions & 1 deletion
@@ -96,7 +96,9 @@ def MMA0(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
+                                             0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
     @T.macro

examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py

Lines changed: 3 additions & 1 deletion
@@ -63,7 +63,9 @@ def MMA0(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
+                                             0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
     @T.macro

examples/flash_attention/example_mha_bwd_bhsd.py

Lines changed: 5 additions & 1 deletion
@@ -56,7 +56,9 @@ def flash_fwd(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                             -T.infinity(acc_s.dtype), 0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
         T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
         T.copy(scores_max, scores_max_prev)
@@ -213,6 +215,8 @@ def flash_bwd(
             for i, j in T.Parallel(block_M, block_N):
                 qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                            0)
+        # We don't need to handle OOB positions for non-causal cases,
+        # since OOB values won't affect other positions here.
         T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do)
         T.clear(dsT)
         T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

examples/flash_attention/example_mha_bwd.py renamed to examples/flash_attention/example_mha_bwd_bshd.py

Lines changed: 6 additions & 2 deletions
@@ -52,7 +52,9 @@ def flash_fwd(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                             -T.infinity(acc_s.dtype), 0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
         T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
         T.copy(scores_max, scores_max_prev)
@@ -206,6 +208,8 @@ def flash_bwd(
             for i, j in T.Parallel(block_M, block_N):
                 qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                            0)
+        # We don't need to handle OOB positions for non-causal cases,
+        # since OOB values won't affect other positions here.
         T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
         T.clear(dsT)
         T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -340,7 +344,7 @@ def run1():
     parser = argparse.ArgumentParser()
     parser.add_argument('--batch', type=int, default=8, help='Batch size')
     parser.add_argument('--h', type=int, default=32, help='Number of heads')
-    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
+    parser.add_argument('--n_ctx', type=int, default=1048, help='Context size')
     parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
     parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
     args = parser.parse_args()
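
The new n_ctx default (1048) is presumably chosen because, unlike 1024, it is not a multiple of the usual block sizes, so the default run now exercises the out-of-bounds masking path added above. A quick check under that assumption (block_N = 64 is hypothetical):

    # 1048 leaves a partial final K-tile; 1024 does not.
    block_N = 64
    for n_ctx in (1024, 1048):
        print(n_ctx, n_ctx % block_N)   # 1024 -> 0, 1048 -> 24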

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py renamed to examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py

Lines changed: 5 additions & 1 deletion
@@ -53,7 +53,9 @@ def flash_fwd(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                             -T.infinity(acc_s.dtype), 0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
         T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
         T.copy(scores_max, scores_max_prev)
@@ -193,6 +195,8 @@ def flash_bwd(
             for i, j in T.Parallel(block_M, block_N):
                 qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                            0)
+        # We don't need to handle OOB positions for non-causal cases,
+        # since OOB values won't affect other positions here.
         T.wait_wgmma(0)
         T.copy(qkT, qkT_cast)
         T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
