@@ -51,7 +51,8 @@ def _flash_attn_forward(
     sm_margin=0,
     s_aux=None,
     cp_world_size=1,
-    cp_rank=0):
+    cp_rank=0,
+    cp_tot_seqused_k=None):
     q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
     v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
     cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -99,7 +100,8 @@ def _flash_attn_forward(
         sm_margin,
         s_aux,
         cp_world_size,
-        cp_rank
+        cp_rank,
+        cp_tot_seqused_k,
     )
     return out, softmax_lse, *rest
 
@@ -266,6 +268,7 @@ def forward(
         s_aux=None,
         cp_world_size=1,
         cp_rank=0,
+        cp_tot_seqused_k=None,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -293,6 +296,7 @@ def forward(
             s_aux=s_aux,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqused_k=cp_tot_seqused_k,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -361,6 +365,7 @@ def forward(
         s_aux=None,
         cp_world_size=1,
         cp_rank=0,
+        cp_tot_seqused_k=None,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -392,6 +397,7 @@ def forward(
             s_aux=s_aux,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqused_k=cp_tot_seqused_k,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -511,6 +517,7 @@ def flash_attn_func(
     s_aux=None,
     cp_world_size=1,
     cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -574,6 +581,7 @@ def flash_attn_func(
         s_aux,
         cp_world_size,
         cp_rank,
+        cp_tot_seqused_k,
     )
 
 
@@ -600,6 +608,7 @@ def flash_attn_varlen_func(
     s_aux=None,
     cp_world_size=1,
     cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     return FlashAttnVarlenFunc.apply(
         q,
@@ -624,6 +633,7 @@ def flash_attn_varlen_func(
         s_aux,
         cp_world_size,
         cp_rank,
+        cp_tot_seqused_k,
     )
 
 
@@ -664,6 +674,7 @@ def flash_attn_with_kvcache(
     s_aux=None,
     cp_world_size=1,
     cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -793,6 +804,7 @@ def flash_attn_with_kvcache(
         s_aux=s_aux,
         cp_world_size=cp_world_size,
         cp_rank=cp_rank,
+        cp_tot_seqused_k=cp_tot_seqused_k,
     )
     # return (out, softmax_lse) if return_softmax_lse else out
     return (out, softmax_lse, *rest) if return_softmax_lse else out
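
A minimal usage sketch of the new argument, not part of the diff: it assumes the FA3 interface is importable as `flash_attn_interface` and that `cp_tot_seqused_k` is a per-batch int32 tensor giving the total used K length across all context-parallel ranks. Both assumptions follow the surrounding signatures rather than anything stated in this change.

```python
# Hypothetical example only; the semantics of cp_tot_seqused_k are assumed, not documented here.
import torch
from flash_attn_interface import flash_attn_with_kvcache  # assumed FA3 import path

batch, seqlen_q, nheads, headdim = 2, 1, 16, 128
cp_world_size, cp_rank = 4, 0  # this process holds shard 0 of 4

q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(batch, 1024, nheads, headdim, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn_like(k_cache)

# K/V tokens actually filled in this rank's local cache shard.
cache_seqlens = torch.full((batch,), 1024, dtype=torch.int32, device="cuda")
# Assumed meaning: total used K tokens per batch element across all cp ranks.
cp_tot_seqused_k = torch.full((batch,), 4096, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    causal=True,
    cp_world_size=cp_world_size,
    cp_rank=cp_rank,
    cp_tot_seqused_k=cp_tot_seqused_k,
)
```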