diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py
index 907a121d2..dd9c8f7c1 100644
--- a/examples/flash_attention/example_gqa_bwd.py
+++ b/examples/flash_attention/example_gqa_bwd.py
@@ -54,7 +54,9 @@ def flash_fwd(
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                      -T.infinity(acc_s.dtype))
                 else:
-                    T.clear(acc_s)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                     -T.infinity(acc_s.dtype), 0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                 T.copy(scores_max, scores_max_prev)
diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py
index 615c2e191..2af06e4bc 100644
--- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py
+++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py
@@ -59,7 +59,9 @@ def flash_fwd(
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                      T.Cast(accum_dtype, -1e30))
                 else:
-                    T.clear(acc_s)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                     -T.infinity(acc_s.dtype), 0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                 T.copy(scores_max, scores_max_prev)
diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
index ed07e7d9d..024212499 100644
--- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
+++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
@@ -54,7 +54,9 @@ def flash_fwd(
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                      -T.infinity(acc_s.dtype))
                 else:
-                    T.clear(acc_s)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                     -T.infinity(acc_s.dtype), 0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
                 T.copy(scores_max, scores_max_prev)
diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py
index 4d9d06a4f..3d4bfe455 100644
--- a/examples/flash_attention/example_gqa_fwd_bshd.py
+++ b/examples/flash_attention/example_gqa_fwd_bshd.py
@@ -96,7 +96,9 @@ def MMA0(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
+                                             0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
     @T.macro
diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
index 1c1fc12d2..21f5e9a9d 100644
--- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
+++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
@@ -63,7 +63,9 @@ def MMA0(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
+                                             0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
     @T.macro
diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py
index 1595ae764..8247b2654 100644
--- a/examples/flash_attention/example_mha_bwd_bhsd.py
+++ b/examples/flash_attention/example_mha_bwd_bhsd.py
@@ -56,7 +56,9 @@ def flash_fwd(
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                      -T.infinity(acc_s.dtype))
                 else:
-                    T.clear(acc_s)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                     -T.infinity(acc_s.dtype), 0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
                 T.copy(scores_max, scores_max_prev)
@@ -213,6 +215,8 @@ def flash_bwd(
                     for i, j in T.Parallel(block_M, block_N):
                         qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                                    0)
+                # We don't need to handle OOB positions for non-causal cases,
+                # since OOB values won't affect other positions here.
                 T.copy(dO[bz, bx, k * block_N:(k + 1) * block_N, :], do)
                 T.clear(dsT)
                 T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
diff --git a/examples/flash_attention/example_mha_bwd.py b/examples/flash_attention/example_mha_bwd_bshd.py
similarity index 97%
rename from examples/flash_attention/example_mha_bwd.py
rename to examples/flash_attention/example_mha_bwd_bshd.py
index 543c2c0e7..414061ffb 100644
--- a/examples/flash_attention/example_mha_bwd.py
+++ b/examples/flash_attention/example_mha_bwd_bshd.py
@@ -52,7 +52,9 @@ def flash_fwd(
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                      -T.infinity(acc_s.dtype))
                 else:
-                    T.clear(acc_s)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                     -T.infinity(acc_s.dtype), 0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                 T.copy(scores_max, scores_max_prev)
@@ -206,6 +208,8 @@ def flash_bwd(
                     for i, j in T.Parallel(block_M, block_N):
                         qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                                    0)
+                # We don't need to handle OOB positions for non-causal cases,
+                # since OOB values won't affect other positions here.
                 T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
                 T.clear(dsT)
                 T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@@ -340,7 +344,7 @@ def run1():
     parser = argparse.ArgumentParser()
     parser.add_argument('--batch', type=int, default=8, help='Batch size')
     parser.add_argument('--h', type=int, default=32, help='Number of heads')
-    parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
+    parser.add_argument('--n_ctx', type=int, default=1048, help='Context size')
    parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
     parser.add_argument('--causal', type=bool, default=False, help='Causal flag')
     args = parser.parse_args()
diff --git a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py
similarity index 97%
rename from examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
rename to examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py
index 7ad417ef5..e10ef5816 100644
--- a/examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
+++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py
@@ -53,7 +53,9 @@ def flash_fwd(
                         acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                      -T.infinity(acc_s.dtype))
                 else:
-                    T.clear(acc_s)
+                    for i, j in T.Parallel(block_M, block_N):
+                        acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
+                                                     -T.infinity(acc_s.dtype), 0)
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                 T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
                 T.copy(scores_max, scores_max_prev)
@@ -193,6 +195,8 @@ def flash_bwd(
                     for i, j in T.Parallel(block_M, block_N):
                         qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
                                                    0)
+                # We don't need to handle OOB positions for non-causal cases,
+                # since OOB values won't affect other positions here.
                 T.wait_wgmma(0)
                 T.copy(qkT, qkT_cast)
                 T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1)
diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py
index f07f7a618..e936cee33 100644
--- a/examples/flash_attention/example_mha_fwd_bhsd.py
+++ b/examples/flash_attention/example_mha_fwd_bhsd.py
@@ -55,7 +55,9 @@ def MMA0(
                 k_idx = k * block_N + j
                 acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            # We shall fill -inf for OOB positions
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
     @T.macro
@@ -226,7 +228,7 @@ def main(
     parser.add_argument('--seq_q', type=int, default=256, help='query sequence length')
     parser.add_argument('--seq_kv', type=int, default=256, help='key/value sequence length')
     parser.add_argument('--dim', type=int, default=64, help='dim')
-    parser.add_argument('--is_causal', action='store_true', help='causal')
+    parser.add_argument('--is_causal', action='store_true', help='causal', default=False)
     parser.add_argument('--tune', action='store_true', help='tune configs')
     args = parser.parse_args()
     main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
index 26167b34b..e1d0130a5 100644
--- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
+++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
@@ -55,7 +55,9 @@ def MMA0(
                 k_idx = k * block_N + j
                 acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            # We shall fill -inf for OOB positions
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_kv, -T.infinity(acc_s.dtype), 0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
     @T.macro
diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py
index 6a1f707e5..a9268019a 100644
--- a/examples/flash_attention/example_mha_fwd_bshd.py
+++ b/examples/flash_attention/example_mha_fwd_bshd.py
@@ -49,7 +49,10 @@ def MMA0(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            # We shall fill -inf for OOB positions
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
+                                             0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
     @T.macro
diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
index 3928db4c3..d7023a203 100644
--- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
+++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
@@ -49,7 +49,10 @@ def MMA0(
                 acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                              -T.infinity(acc_s.dtype))
         else:
-            T.clear(acc_s)
+            # We shall fill -inf for OOB positions
+            for i, j in T.Parallel(block_M, block_N):
+                acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype),
+                                             0)
         T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
     @T.macro
diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py
index f4932aee9..b184fc601 100644
--- a/examples/flash_attention/test_example_flash_attention.py
+++ b/examples/flash_attention/test_example_flash_attention.py
@@ -2,7 +2,7 @@
 import example_gqa_bwd
 import example_gqa_bwd_wgmma_pipelined
-import example_mha_bwd
+import example_mha_bwd_bshd
 import example_mha_bwd_bhsd
 import example_mha_fwd_bhsd_wgmma_pipelined
 import example_gqa_fwd_bshd
@@ -10,7 +10,7 @@
 import example_gqa_fwd_bshd_wgmma_pipelined
 import example_mha_fwd_bshd_wgmma_pipelined
 import example_mha_fwd_varlen
-import example_mha_bwd_wgmma_pipelined
+import example_mha_bwd_bshd_wgmma_pipelined
 import example_mha_fwd_bhsd
 import example_gqa_bwd_tma_reduce_varlen
 
@@ -33,7 +33,7 @@ def test_example_gqa_bwd_wgmma_pipelined():
 
 @tilelang.testing.requires_cuda
 def test_example_mha_bwd():
-    example_mha_bwd.main(
+    example_mha_bwd_bshd.main(
         BATCH=1,
         H=16,
         N_CTX=512,
@@ -56,7 +56,7 @@ def test_example_mha_bwd_bhsd():
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_mha_bwd_wgmma_pipelined():
-    example_mha_bwd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
+    example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False)
 
 
 @tilelang.testing.requires_cuda
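
The recurring change in the hunks above replaces T.clear(acc_s) with a masked fill so that, when the sequence length is not a multiple of block_N, key columns past the end of the sequence (OOB positions) enter the softmax as -inf rather than 0. The sketch below is illustrative only and not part of the patch; the names it uses (block_N, seq_len, scores, softmax) are local to the sketch rather than taken from the kernels, and it models a single query row of one score tile with plain NumPy.

# Illustrative sketch (not part of the patch): why OOB key columns must be
# filled with -inf rather than left at 0 when seq_len % block_N != 0.
import numpy as np

block_N, seq_len = 64, 40          # last tile has 24 out-of-bounds columns
scores = np.random.randn(block_N)  # one query row of Q @ K^T for this tile
scores[seq_len:] = 0.0             # what T.clear(acc_s) would leave behind

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

wrong = softmax(scores)            # OOB columns keep weight exp(0 - max) > 0
masked = scores.copy()
masked[seq_len:] = -np.inf         # what the patched kernels now do
right = softmax(masked)            # exp(-inf) = 0: OOB columns contribute nothing

print(wrong[seq_len:].sum())       # > 0: probability mass leaks to padding
print(right[seq_len:].sum())       # == 0

With T.clear, padded columns carry a score of 0 and therefore receive nonzero weight after normalization, pulling padded V rows into the output; filling them with -inf drives their weight to exactly zero, which is also why the backward-pass comments note that non-causal OOB positions need no extra handling there.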