Commit b579e6c (parent d8713ed)

add flash_attn_varlen_func

1 file changed: python/paddle/nn/functional/flash_attention.py (51 additions, 5 deletions)
```diff
@@ -1049,6 +1049,52 @@ def flash_attention_v3_varlen(
     num_splits=1,
     pack_gqa=None,
     sm_margin=0,
+):
+    return flash_attn_varlen_func(
+        query,
+        key,
+        value,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        seqused_q,
+        seqused_k,
+        softmax_scale,
+        causal,
+        qv,
+        q_descale,
+        k_descale,
+        v_descale,
+        window_size,
+        softcap,
+        num_splits,
+        pack_gqa,
+        sm_margin,
+    )
+
+
+def flash_attn_varlen_func(
+    query,
+    key,
+    value,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    seqused_q=None,
+    seqused_k=None,
+    softmax_scale=None,
+    causal=False,
+    qv=None,
+    q_descale=None,
+    k_descale=None,
+    v_descale=None,
+    window_size=(-1, -1),
+    softcap=0.0,
+    num_splits=1,
+    pack_gqa=None,
+    sm_margin=0,
 ):
     r"""
     The equation is:
@@ -1097,24 +1143,24 @@ def flash_attention_v3_varlen(
     """
     assert (
         "xpu" not in paddle.get_device()
-    ), "flash_attention_v3_varlen is not supported on xpu"
+    ), "flash_attn_varlen_func is not supported on xpu"
 
     assert not paddle.get_flags(["FLAGS_cudnn_deterministic"])[
         "FLAGS_cudnn_deterministic"
-    ], "flash_attention_v3_varlen does not support deterministic"
+    ], "flash_attn_varlen_func does not support deterministic"
 
     assert (
         paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])[
             "FLAGS_flash_attn_version"
         ]
         == 3
-    ), "FLAGS_flash_attn_version is 2, conflicts with flash_attn_varlen_v3"
+    ), "FLAGS_flash_attn_version is 2, conflicts with flash_attn_varlen_func"
 
     assert (
         in_dynamic_or_pir_mode()
-    ), "flash_attention_v3_varlen only supports dynamic or pir mode"
+    ), "flash_attn_varlen_func only supports dynamic or pir mode"
 
-    assert qv is None, "flash_attention_v3_varlen does not support setting qv"
+    assert qv is None, "flash_attn_varlen_func does not support setting qv"
 
     if softmax_scale is None:
         softmax_scale = (
```
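The new `flash_attn_varlen_func` exposes the usual variable-length FlashAttention convention: sequences of different lengths are packed back-to-back into one tensor along the token dimension, and `cu_seqlens_q`/`cu_seqlens_k` hold the cumulative offsets marking where each sequence starts and ends. Below is a minimal usage sketch. It assumes the packed `[total_tokens, num_heads, head_dim]` layout and an attention-output return value matching the FlashAttention varlen convention (neither is shown in this diff), and a CUDA device with FlashAttention-3 support:

```python
import paddle
from paddle.nn.functional.flash_attention import flash_attn_varlen_func

# The assertions in this diff require FLAGS_flash_attn_version == 3
# and a non-XPU, non-deterministic setup.
paddle.set_flags({'FLAGS_flash_attn_version': 3})

# Two sequences of lengths 3 and 5, packed along the token dimension.
seqlens = paddle.to_tensor([3, 5], dtype='int32')
cu_seqlens = paddle.concat(
    [paddle.zeros([1], dtype='int32'), paddle.cumsum(seqlens)]
)  # -> [0, 3, 8]: cumulative offsets into the packed tensors

# Assumed packed layout: [total_tokens, num_heads, head_dim].
total_tokens, num_heads, head_dim = 8, 8, 64
q = paddle.randn([total_tokens, num_heads, head_dim]).astype('float16')
k = paddle.randn([total_tokens, num_heads, head_dim]).astype('float16')
v = paddle.randn([total_tokens, num_heads, head_dim]).astype('float16')

out = flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=5,
    max_seqlen_k=5,
    causal=True,  # mask future tokens within each packed sequence
)
```

Per the assertions in the diff, `qv` must stay `None`, `FLAGS_cudnn_deterministic` must be off, and the call only works in dynamic or PIR mode.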
