 from einops import rearrange, repeat
 from bert_padding import pad_input, unpad_input
 
-torch.manual_seed(1)
-
 
 def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
     assert mode in ["full", "random", "third"]
@@ -525,7 +523,10 @@ def flash_bwd( |
                 T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
                 for i, j in T.Parallel(block_N, dim_qk):
                     if k_base * block_N + i < q_current_seqlen:
-                        T.atomic_add(dQ[q_start_idx + k_base * block_N + i, bx, j], dq[i, j])
+                        T.atomic_add(
+                            dQ[q_start_idx + k_base * block_N + i, bx, j],
+                            dq[i, j],
+                            memory_order="release")
 
                 T.copy(dv, dv_shared)
                 for i, d in T.Parallel(block_M, dim_v):
@@ -739,9 +740,9 @@ def main(BATCH: int = 1, |
     dV_ref, V.grad = V.grad.clone(), None
 
     torch.testing.assert_close(O, O_ref.half(), rtol=1e-2, atol=1e-2)
-    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
+    torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
     print('All checks passed.✅')
 
 def run():
@@ -784,8 +785,8 @@ def run1(): |
     elif args.use_atomic:
         use_atomic = True
     else:
-        # Default: use atomic
-        use_atomic = True
+        # Default: use split
+        use_atomic = False
 
     main(args.batch, args.h, args.n_ctx, args.d_head_qk, args.d_head_v, args.groups, args.causal,
          use_atomic)