@@ -366,23 +366,23 @@ def flash_bwd(
366366 T .copy (dsT_cast , dsT_shared )
367367 T .clear (dq )
368368 T .gemm (dsT_shared , K_shared , dq , transpose_A = True )
369- for i , d in T .Parallel (block_N , dim_qk ):
370- T .atomic_add (
371- dQ [q_start_idx + k_base * block_N + i , bx , d ],
372- dq [i , d ],
373- memory_order = "release" )
374-
375- for i , d in T .Parallel (block_M , dim_v ):
376369 T .atomic_add (
377- dV [k_start_idx + by * block_M + i , bx // groups , d ],
378- dv [i , d ],
379- memory_order = "release" )
380- for i , d in T .Parallel (block_M , dim_qk ):
381- T .atomic_add (
382- dK [k_start_idx + by * block_M + i , bx // groups , d ],
383- dk [i , d ],
370+ dQ [q_start_idx + k_base * block_N :q_start_idx + k_base * block_N + block_N ,
371+ bx , :],
372+ dq ,
384373 memory_order = "release" )
385374
375+ T .atomic_add (
376+ dV [k_start_idx + by * block_M :k_start_idx + by * block_M + block_M ,
377+ bx // groups , :],
378+ dv ,
379+ memory_order = "release" )
380+ T .atomic_add (
381+ dK [k_start_idx + by * block_M :k_start_idx + by * block_M + block_M ,
382+ bx // groups , :],
383+ dk ,
384+ memory_order = "release" )
385+
386386 return flash_bwd
387387
388388
0 commit comments