Skip to content

Commit ae10dd4

Browse files
committed
pre-commit
1 parent 55cc841 commit ae10dd4

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

examples/flash_attention/example_mha_bwd_bshd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def flash_fwd(
53  53                                           -T.infinity(acc_s.dtype))
54  54      else:
55  55          for i, j in T.Parallel(block_M, block_N):
56      -           acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
    56  +           acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
    57  +                                        -T.infinity(acc_s.dtype), 0)
57  58      T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
58  59      T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
59  60      T.copy(scores_max, scores_max_prev)

examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def flash_fwd(
54  54                                           -T.infinity(acc_s.dtype))
55  55      else:
56  56          for i, j in T.Parallel(block_M, block_N):
57      -           acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0)
    57  +           acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len,
    58  +                                        -T.infinity(acc_s.dtype), 0)
58  59      T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
59  60      T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
60  61      T.copy(scores_max, scores_max_prev)

0 commit comments

Comments
 (0)