@@ -1049,6 +1049,52 @@ def flash_attention_v3_varlen(
     num_splits=1,
     pack_gqa=None,
     sm_margin=0,
+):
+    return flash_attn_varlen_func(
+        query,
+        key,
+        value,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqused_q,
+        seqused_k,
+        softmax_scale,
+        causal,
+        qv,
+        q_descale,
+        k_descale,
+        v_descale,
+        window_size,
+        softcap,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+    )
+
+
+def flash_attn_varlen_func(
+    query,
+    key,
+    value,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    seqused_q=None,
+    seqused_k=None,
+    softmax_scale=None,
+    causal=False,
+    qv=None,
+    q_descale=None,
+    k_descale=None,
+    v_descale=None,
+    window_size=(-1, -1),
+    softcap=0.0,
+    num_splits=1,
+    pack_gqa=None,
+    sm_margin=0,
 ):
     r"""
     The equation is:
@@ -1097,24 +1143,24 @@ def flash_attention_v3_varlen(
10971143 """
10981144 assert (
10991145 "xpu" not in paddle .get_device ()
1100- ), "flash_attention_v3_varlen is not supported on xpu"
1146+ ), "flash_attn_varlen_func is not supported on xpu"
11011147
11021148 assert not paddle .get_flags (["FLAGS_cudnn_deterministic" ])[
11031149 "FLAGS_cudnn_deterministic"
1104- ], "flash_attention_v3_varlen does not support deterministic"
1150+ ], "flash_attn_varlen_func does not support deterministic"
11051151
11061152 assert (
11071153 paddle .base .framework .get_flags (["FLAGS_flash_attn_version" ])[
11081154 "FLAGS_flash_attn_version"
11091155 ]
11101156 == 3
1111- ), "FLAGS_flash_attn_version is 2, conflits with flash_attn_varlen_v3 "
1157+ ), "FLAGS_flash_attn_version is 2, conflits with flash_attn_varlen_func "
 
     assert (
         in_dynamic_or_pir_mode()
-    ), "flash_attention_v3_varlen only support dynamic or pir mode"
1161+ ), "flash_attn_varlen_func only support dynamic or pir mode"
 
-    assert qv is None, "flash_attention_v3_varlen does not support setting qv"
+    assert qv is None, "flash_attn_varlen_func does not support setting qv"
 
     if softmax_scale is None:
         softmax_scale = (
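For context, here is a minimal usage sketch of the renamed flash_attn_varlen_func, outside the diff. The import path, tensor shapes, dtype, and the paddle.set_flags call are assumptions inferred from the asserts above, not taken from the patch; FA3-style varlen kernels expect query/key/value packed along the token dimension, with cu_seqlens holding exclusive prefix sums of the per-sequence lengths.

import paddle
# Assumed import path; the patch does not show the module location.
from paddle.nn.functional.flash_attention import flash_attn_varlen_func

# The function asserts FLAGS_flash_attn_version == 3 (see above).
paddle.set_flags({"FLAGS_flash_attn_version": 3})

num_heads, head_dim = 8, 64
seqlens = [3, 5]      # two sequences packed into one batch
total = sum(seqlens)  # 8 tokens along the packed dimension

# Packed (total_tokens, num_heads, head_dim) tensors, one row per token.
query = paddle.randn([total, num_heads, head_dim]).astype("bfloat16")
key = paddle.randn([total, num_heads, head_dim]).astype("bfloat16")
value = paddle.randn([total, num_heads, head_dim]).astype("bfloat16")

# Exclusive prefix sums of seqlens: sequence i spans rows [cu[i], cu[i+1]).
cu_seqlens = paddle.to_tensor([0, 3, 8], dtype="int32")

# softmax_scale is left as None so the default (conventionally
# 1/sqrt(head_dim), as the truncated branch above suggests) is applied.
# The return convention is not shown in the patch; a packed output of the
# same (total_tokens, num_heads, head_dim) layout is assumed here.
out = flash_attn_varlen_func(
    query,
    key,
    value,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens),
    max_seqlen_k=max(seqlens),
    causal=True,
)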