Skip to content

Commit f003f37

Browse files
authored
[GQA] Add regional atomic add to slightly boost performance (#1093)
* [Lint]
* [BugFix] Freeze the memory order of all atomic_add operations
* [Lint]
* [Atomic] Move on to regional atomic add
* [Lint]
1 parent 5cb5c06 commit f003f37

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -366,23 +366,23 @@ def flash_bwd(
366366
T.copy(dsT_cast, dsT_shared)
367367
T.clear(dq)
368368
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
369-
for i, d in T.Parallel(block_N, dim_qk):
370-
T.atomic_add(
371-
dQ[q_start_idx + k_base * block_N + i, bx, d],
372-
dq[i, d],
373-
memory_order="release")
374-
375-
for i, d in T.Parallel(block_M, dim_v):
376369
T.atomic_add(
377-
dV[k_start_idx + by * block_M + i, bx // groups, d],
378-
dv[i, d],
379-
memory_order="release")
380-
for i, d in T.Parallel(block_M, dim_qk):
381-
T.atomic_add(
382-
dK[k_start_idx + by * block_M + i, bx // groups, d],
383-
dk[i, d],
370+
dQ[q_start_idx + k_base * block_N:q_start_idx + k_base * block_N + block_N,
371+
bx, :],
372+
dq,
384373
memory_order="release")
385374

375+
T.atomic_add(
376+
dV[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
377+
bx // groups, :],
378+
dv,
379+
memory_order="release")
380+
T.atomic_add(
381+
dK[k_start_idx + by * block_M:k_start_idx + by * block_M + block_M,
382+
bx // groups, :],
383+
dk,
384+
memory_order="release")
385+
386386
return flash_bwd
387387

388388

0 commit comments

Comments
 (0)