@@ -621,24 +621,32 @@ def _wrapped_flash_attn_3(
621621) -> Tuple [torch .Tensor , torch .Tensor ]:
622622 # Hardcoded for now because pytorch does not support tuple/int type hints
623623 window_size = (- 1 , - 1 )
624- out , lse , * _ = flash_attn_3_func (
625- q = q ,
626- k = k ,
627- v = v ,
628- softmax_scale = softmax_scale ,
629- causal = causal ,
630- qv = qv ,
631- q_descale = q_descale ,
632- k_descale = k_descale ,
633- v_descale = v_descale ,
634- window_size = window_size ,
635- attention_chunk = attention_chunk ,
636- softcap = softcap ,
637- num_splits = num_splits ,
638- pack_gqa = pack_gqa ,
639- deterministic = deterministic ,
640- sm_margin = sm_margin ,
641- )
624+
625+ kwargs = {
626+ "q" : q ,
627+ "k" : k ,
628+ "v" : v ,
629+ "softmax_scale" : softmax_scale ,
630+ "causal" : causal ,
631+ "qv" : qv ,
632+ "q_descale" : q_descale ,
633+ "k_descale" : k_descale ,
634+ "v_descale" : v_descale ,
635+ "window_size" : window_size ,
636+ "attention_chunk" : attention_chunk ,
637+ "softcap" : softcap ,
638+ "num_splits" : num_splits ,
639+ "pack_gqa" : pack_gqa ,
640+ "deterministic" : deterministic ,
641+ "sm_margin" : sm_margin ,
642+ }
643+
644+ # Newer flash-attn-3 APIs only return the LSE on request; early ones lack this parameter.
645+ if "return_attn_probs" in inspect .signature (flash_attn_3_func ).parameters :
646+ kwargs ["return_attn_probs" ] = True
647+
648+ out , lse , * _ = flash_attn_3_func (** kwargs )
649+
642650 lse = lse .permute (0 , 2 , 1 )
643651 return out , lse
644652
@@ -1504,17 +1512,29 @@ def _flash_varlen_attention_3(
15041512 key_packed = torch .cat (key_valid , dim = 0 )
15051513 value_packed = torch .cat (value_valid , dim = 0 )
15061514
1507- out , lse , * _ = flash_attn_3_varlen_func (
1508- q = query_packed ,
1509- k = key_packed ,
1510- v = value_packed ,
1511- cu_seqlens_q = cu_seqlens_q ,
1512- cu_seqlens_k = cu_seqlens_k ,
1513- max_seqlen_q = max_seqlen_q ,
1514- max_seqlen_k = max_seqlen_k ,
1515- softmax_scale = scale ,
1516- causal = is_causal ,
1517- )
1515+ kwargs = {
1516+ "q" : query_packed ,
1517+ "k" : key_packed ,
1518+ "v" : value_packed ,
1519+ "cu_seqlens_q" : cu_seqlens_q ,
1520+ "cu_seqlens_k" : cu_seqlens_k ,
1521+ "max_seqlen_q" : max_seqlen_q ,
1522+ "max_seqlen_k" : max_seqlen_k ,
1523+ "softmax_scale" : scale ,
1524+ "causal" : is_causal ,
1525+ }
1526+
1527+ if "return_attn_probs" in inspect .signature (flash_attn_3_varlen_func ).parameters :
1528+ kwargs ["return_attn_probs" ] = return_lse
1529+ out = flash_attn_3_varlen_func (** kwargs )
1530+ if return_lse :
1531+ out , lse = out [0 ], out [1 ]
1532+ else :
1533+ lse = None
1534+ else :
1535+ # For backward compatibility with early flash-attn-3 APIs.
1536+ out , lse , * _ = flash_attn_3_varlen_func (** kwargs )
1537+
15181538 out = out .unflatten (0 , (batch_size , - 1 ))
15191539
15201540 return (out , lse ) if return_lse else out
0 commit comments