@@ -154,6 +154,8 @@ def flashattn_bwd(batch, heads, seq_len, dim_qk, dim_v, is_causal, block_M, bloc
     q_shape = [batch, seq_len, heads, dim_qk]
     k_shape = [batch, seq_len, head_kv, dim_qk]
     v_shape = [batch, seq_len, head_kv, dim_v]
+    dk_shape = [groups, batch, seq_len, head_kv, dim_qk]  # sum after kernel
+    dv_shape = [groups, batch, seq_len, head_kv, dim_v]  # sum after kernel
     dtype = "float16"
     accum_dtype = "float"
 
@@ -166,8 +168,8 @@ def flash_bwd(
             lse: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
             Delta: T.Tensor([batch, heads, seq_len], accum_dtype),  # type: ignore
             dQ: T.Tensor(q_shape, accum_dtype),  # type: ignore
-            dK: T.Tensor(k_shape, dtype),  # type: ignore
-            dV: T.Tensor(v_shape, dtype),  # type: ignore
+            dK: T.Tensor(dk_shape, dtype),  # type: ignore
+            dV: T.Tensor(dv_shape, dtype),  # type: ignore
     ):
         with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=128) as (bx, by, bz):
             K_shared = T.alloc_shared([block_M, dim_qk], dtype)
@@ -184,8 +186,8 @@ def flash_bwd(
             dv = T.alloc_fragment([block_M, dim_v], accum_dtype)
             dk = T.alloc_fragment([block_M, dim_qk], accum_dtype)
             dq = T.alloc_fragment([block_N, dim_qk], accum_dtype)
-            dv_shared = T.alloc_shared([block_N, dim_v], dtype)
-            dk_shared = T.alloc_shared([block_N, dim_qk], dtype)
+            dv_shared = T.alloc_shared([block_M, dim_v], dtype)
+            dk_shared = T.alloc_shared([block_M, dim_qk], dtype)
 
             T.annotate_layout({
                 dQ: make_dq_layout(dQ),
@@ -230,10 +232,10 @@ def flash_bwd(
                     if k * block_N + i < seq_len:
                         T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
 
-            for i, j in T.Parallel(block_M, dim_v):
-                T.atomic_add(dV[bz, by * block_M + i, bx // groups, j], dv[i, j])
-            for i, j in T.Parallel(block_M, dim_qk):
-                T.atomic_add(dK[bz, by * block_M + i, bx // groups, j], dk[i, j])
+            T.copy(dv, dv_shared)
+            T.copy(dv_shared, dV[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
+            T.copy(dk, dk_shared)
+            T.copy(dk_shared, dK[bx % groups, bz, by * block_M:(by + 1) * block_M, bx // groups, :])
 
     return flash_bwd
 
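Why plain copies can replace the atomics here: with heads = head_kv * groups, the mapping bx -> (bx % groups, bx // groups) is a bijection onto the (group, kv_head) slots of the enlarged dK/dV buffers, so each thread block writes a disjoint slice and no cross-block accumulation is needed inside the kernel. A minimal standalone sketch of that indexing argument (plain Python with made-up sizes, not the TileLang kernel itself):

groups, head_kv = 4, 2    # hypothetical GQA configuration
heads = head_kv * groups  # total query heads

# Each query head bx writes slot (bx % groups, bx // groups) of the grouped buffer.
slots = [(bx % groups, bx // groups) for bx in range(heads)]

# Every (group, kv_head) pair is hit exactly once, so per-block writes never
# overlap and atomic_add is unnecessary; a plain copy per block suffices.
assert len(set(slots)) == heads == groups * head_kv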
@@ -274,13 +276,14 @@ def maybe_contiguous(x):
         kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N,
                                groups)
         shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
-        shape_k = [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
-        shape_v = [BATCH, N_CTX, HEAD_KV, D_HEAD_V]
+        shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
+        shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
         dq = torch.zeros(shape_q, dtype=torch.float32, device=q.device)
-        dk = torch.zeros(shape_k, dtype=torch.float16, device=q.device)
-        dv = torch.zeros(shape_v, dtype=torch.float16, device=q.device)
+        dk = torch.empty(shape_k, dtype=torch.float16, device=q.device)
+        dv = torch.empty(shape_v, dtype=torch.float16, device=q.device)
         kernel(q, k, v, do, lse, delta, dq, dk, dv)
         dq = mod_post(dq)
+        dk, dv = dk.sum(0), dv.sum(0)
         return dq, dk, dv, None, None
 
 
@@ -354,6 +357,7 @@ def main(BATCH: int = 1,
     torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
+    print('All checks passed.✅')
 
     def run():
         O_ref.backward(dO, retain_graph=True)
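The host-side flow implied by the hunks above: the grouped dk/dv buffers are switched to torch.empty because the kernel now writes every group slice directly instead of accumulating into it, and the per-group partials are reduced with a single .sum(0) after the launch. A hedged stand-in below mimics those writes with a hypothetical fake_dk_tile placeholder; it only illustrates the reduction, in float32 so the sketch runs anywhere:

import torch

BATCH, N_CTX, HEAD_KV, D_HEAD_QK, groups = 1, 8, 2, 16, 4  # made-up sizes
H = HEAD_KV * groups

shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
dk = torch.empty(shape_k, dtype=torch.float32)        # no zero-init needed

# Stand-in for what flash_bwd does per query head bx: write that head's dK
# contribution into its own group slot, no atomics involved.
for bx in range(H):
    fake_dk_tile = torch.randn(BATCH, N_CTX, D_HEAD_QK)  # placeholder for the kernel's dk
    dk[bx % groups, :, :, bx // groups, :] = fake_dk_tile

dk = dk.sum(0)  # reduce the group axis -> [BATCH, N_CTX, HEAD_KV, D_HEAD_QK]
assert dk.shape == (BATCH, N_CTX, HEAD_KV, D_HEAD_QK)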