Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/flash_attention/example_gqa_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def flash_fwd(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
Expand Down
4 changes: 3 additions & 1 deletion examples/flash_attention/example_gqa_bwd_tma_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def flash_fwd(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
T.Cast(accum_dtype, -1e30))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
Expand Down
4 changes: 3 additions & 1 deletion examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def flash_fwd(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
Expand Down
4 changes: 3 additions & 1 deletion examples/flash_attention/example_gqa_fwd_bshd.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def MMA0(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@T.macro
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def MMA0(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@T.macro
Expand Down
6 changes: 5 additions & 1 deletion examples/flash_attention/example_mha_bwd_bhsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def flash_fwd(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
Expand Down Expand Up @@ -213,6 +215,8 @@ def flash_bwd(
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
# No explicit out-of-bounds (OOB) handling is needed in the non-causal case:
# values at OOB positions do not affect any other position here.
T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def flash_fwd(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
Expand Down Expand Up @@ -206,6 +208,8 @@ def flash_bwd(
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
# No explicit out-of-bounds (OOB) handling is needed in the non-causal case:
# values at OOB positions do not affect any other position here.
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Expand Down Expand Up @@ -340,7 +344,7 @@ def run1():
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='Batch size')
parser.add_argument('--h', type=int, default=32, help='Number of heads')
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
parser.add_argument('--n_ctx', type=int, default=1048, help='Context size')
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
args = parser.parse_args()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def flash_fwd(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
-T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
Expand Down Expand Up @@ -193,6 +195,8 @@ def flash_bwd(
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
# No explicit out-of-bounds (OOB) handling is needed in the non-causal case:
# values at OOB positions do not affect any other position here.
T.wait_wgmma(0)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
Expand Down
6 changes: 4 additions & 2 deletions examples/flash_attention/example_mha_fwd_bhsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def MMA0(
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
# Initialize acc_s to -inf at out-of-bounds (OOB) positions (index >= seq_kv) and to 0 elsewhere.
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@T.macro
Expand Down Expand Up @@ -226,7 +228,7 @@ def main(
parser.add_argument('--seq_q', type=int, default=256, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=64, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--is_causal', action='store_true', help='causal', default=False)
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def MMA0(
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
# Initialize acc_s to -inf at out-of-bounds (OOB) positions (index >= seq_kv) and to 0 elsewhere.
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@T.macro
Expand Down
5 changes: 4 additions & 1 deletion examples/flash_attention/example_mha_fwd_bshd.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def MMA0(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
# Initialize acc_s to -inf at out-of-bounds (OOB) positions (index >= seq_len) and to 0 elsewhere.
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@T.macro
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def MMA0(
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
# Initialize acc_s to -inf at out-of-bounds (OOB) positions (index >= seq_len) and to 0 elsewhere.
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
0)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@T.macro
Expand Down
8 changes: 4 additions & 4 deletions examples/flash_attention/test_example_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import example_gqa_bwd
import example_gqa_bwd_wgmma_pipelined
import example_mha_bwd
import example_mha_bwd_bshd
import example_mha_bwd_bhsd
import example_mha_fwd_bhsd_wgmma_pipelined
import example_gqa_fwd_bshd
import example_mha_fwd_bshd
import example_gqa_fwd_bshd_wgmma_pipelined
import example_mha_fwd_bshd_wgmma_pipelined
import example_mha_fwd_varlen
import example_mha_bwd_wgmma_pipelined
import example_mha_bwd_bshd_wgmma_pipelined
import example_mha_fwd_bhsd
import example_gqa_bwd_tma_reduce_varlen

Expand All @@ -33,7 +33,7 @@ def test_example_gqa_bwd_wgmma_pipelined():

@tilelang.testing.requires_cuda
def test_example_mha_bwd():
example_mha_bwd.main(
example_mha_bwd_bshd.main(
BATCH=1,
H=16,
N_CTX=512,
Expand All @@ -56,7 +56,7 @@ def test_example_mha_bwd_bhsd():
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_mha_bwd_wgmma_pipelined():
example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)


@tilelang.testing.requires_cuda
Expand Down
Loading