
Commit c0032d7

supports fa3_varlen api (PaddlePaddle#72805)
* supports fa3_varlen api
1 parent e7007d1 commit c0032d7

File tree

6 files changed: +351 −19 lines

paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu

Lines changed: 128 additions & 0 deletions
@@ -1071,6 +1071,127 @@ void FlashAttnV3Kernel(const Context &ctx,
 #endif
 }
 
+template <typename T, typename Context>
+void FlashAttnV3VarLenKernel(const Context &ctx,
+                             const DenseTensor &q,
+                             const DenseTensor &k,
+                             const DenseTensor &v,
+                             const DenseTensor &cu_seqlens_q,
+                             const DenseTensor &cu_seqlens_k,
+                             const paddle::optional<DenseTensor> &q_v_,
+                             const paddle::optional<DenseTensor> &q_descale_,
+                             const paddle::optional<DenseTensor> &k_descale_,
+                             const paddle::optional<DenseTensor> &v_descale_,
+                             const float softmax_scale,
+                             bool is_causal,
+                             int window_size_left,
+                             int window_size_right,
+                             const float softcap,
+                             int num_splits,
+                             const bool manual_set_pack_gqa,
+                             const bool pack_gqa_,
+                             const int sm_margin,
+                             const int max_seqlen_q,
+                             const int max_seqlen_k,
+                             DenseTensor *out,
+                             DenseTensor *softmax_lse) {
+#ifdef PADDLE_WITH_FLASHATTN_V3
+  // umiswing: the following options have not been fully tested
+  PADDLE_ENFORCE_EQ(q_v_.is_initialized(),
+                    false,
+                    common::errors::InvalidArgument("q_v_ is not supported"));
+  PADDLE_ENFORCE_EQ(
+      q_descale_.is_initialized(),
+      false,
+      common::errors::InvalidArgument("q_descale_ is not supported"));
+  PADDLE_ENFORCE_EQ(
+      k_descale_.is_initialized(),
+      false,
+      common::errors::InvalidArgument("k_descale_ is not supported"));
+  PADDLE_ENFORCE_EQ(
+      v_descale_.is_initialized(),
+      false,
+      common::errors::InvalidArgument("v_descale_ is not supported"));
+  PADDLE_ENFORCE_EQ(
+      window_size_left,
+      -1,
+      common::errors::InvalidArgument("window_size is not supported, please "
+                                      "set window_size_left/right to -1"));
+  PADDLE_ENFORCE_EQ(
+      window_size_right,
+      -1,
+      common::errors::InvalidArgument("window_size is not supported, please "
+                                      "set window_size_left/right to -1"));
+  PADDLE_ENFORCE_EQ(softcap,
+                    0,
+                    common::errors::InvalidArgument(
+                        "softcap is not supported, please set softcap to 0"));
+  PADDLE_ENFORCE_EQ(
+      num_splits,
+      1,
+      common::errors::InvalidArgument(
+          "num_splits is not supported, please set num_splits to 1"));
+  PADDLE_ENFORCE_EQ(manual_set_pack_gqa,
+                    false,
+                    common::errors::InvalidArgument(
+                        "manual_set_pack_gqa is not supported, please set "
+                        "manual_set_pack_gqa to false"));
+  PADDLE_ENFORCE_EQ(
+      pack_gqa_,
+      false,
+      common::errors::InvalidArgument(
+          "pack_gqa_ is not supported, please set pack_gqa_ to false"));
+  PADDLE_ENFORCE_EQ(
+      sm_margin,
+      0,
+      common::errors::InvalidArgument(
+          "sm_margin is not supported, please set sm_margin to 0"));
+
+  DenseTensor out_accum;
+  DenseTensor softmax_lse_accum;
+  FlashAttnV3BaseKernel<T, Context>(ctx,
+                                    q,
+                                    k,
+                                    v,
+                                    paddle::none,  // k_new_
+                                    paddle::none,  // v_new_
+                                    q_v_,
+                                    paddle::none,  // out_
+                                    cu_seqlens_q,  // cu_seqlens_q_
+                                    cu_seqlens_k,  // cu_seqlens_k_
+                                    paddle::none,  // cu_seqlens_k_new_
+                                    paddle::none,  // seqused_q_
+                                    paddle::none,  // seqused_k_
+                                    paddle::none,  // page_table_
+                                    paddle::none,  // kv_batch_idx_
+                                    paddle::none,  // leftpad_k_
+                                    paddle::none,  // rotary_cos_
+                                    paddle::none,  // rotary_sin_
+                                    q_descale_,
+                                    k_descale_,
+                                    v_descale_,
+                                    paddle::none,  // scheduler_metadata
+                                    max_seqlen_q,  // max_seqlen_q_
+                                    max_seqlen_k,  // max_seqlen_k_
+                                    softmax_scale,
+                                    is_causal,
+                                    window_size_left,
+                                    window_size_right,
+                                    softcap,
+                                    true,  // is_rotary_interleaved
+                                    num_splits,
+                                    manual_set_pack_gqa,
+                                    pack_gqa_,
+                                    sm_margin,
+                                    out,
+                                    softmax_lse,
+                                    &out_accum,
+                                    &softmax_lse_accum);
+#else
+  RaiseNotSupportedError();
+#endif
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(flash_attn_v3,
@@ -1079,3 +1200,10 @@ PD_REGISTER_KERNEL(flash_attn_v3,
                    phi::FlashAttnV3Kernel,
                    phi::dtype::float16,
                    phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(flash_attn_v3_varlen,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FlashAttnV3VarLenKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}

paddle/phi/kernels/gpu/flash_attn_v3_kernel.h

Lines changed: 25 additions & 0 deletions
@@ -34,4 +34,29 @@ void FlashAttnV3Kernel(const Context &ctx,
                        const int sm_margin,
                        DenseTensor *out,
                        DenseTensor *softmax_lse);
+
+template <typename T, typename Context>
+void FlashAttnV3VarLenKernel(const Context &ctx,
+                             const DenseTensor &q,
+                             const DenseTensor &k,
+                             const DenseTensor &v,
+                             const DenseTensor &cu_seqlens_q,
+                             const DenseTensor &cu_seqlens_k,
+                             const paddle::optional<DenseTensor> &q_v_,
+                             const paddle::optional<DenseTensor> &q_descale_,
+                             const paddle::optional<DenseTensor> &k_descale_,
+                             const paddle::optional<DenseTensor> &v_descale_,
+                             const float softmax_scale,
+                             bool is_causal,
+                             int window_size_left,
+                             int window_size_right,
+                             const float softcap,
+                             int num_splits,
+                             const bool manual_set_pack_gqa,
+                             const bool pack_gqa_,
+                             const int sm_margin,
+                             const int max_seqlen_q,
+                             const int max_seqlen_k,
+                             DenseTensor *out,
+                             DenseTensor *softmax_lse);
 }  // namespace phi

paddle/phi/ops/yaml/ops.yaml

Lines changed: 11 additions & 0 deletions
@@ -2055,6 +2055,17 @@
     data_type : q
   backward : flash_attn_v3_grad
 
+- op : flash_attn_v3_varlen
+  args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor q_v_, Tensor q_descale_, Tensor k_descale_, Tensor v_descale_, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa_, int sm_margin, int max_seqlen_q, int max_seqlen_k)
+  output : Tensor(out), Tensor(softmax_lse)
+  optional : q_v_, q_descale_, k_descale_, v_descale_
+  infer_meta :
+    func : FlashAttnV3InferMeta
+    param : [q, k, v]
+  kernel :
+    func : flash_attn_v3_varlen
+    data_type : q
+
 - op : flash_attn_varlen_qkvpacked
   args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, Scalar max_seqlen_q, Scalar max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true)
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)

python/paddle/nn/functional/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -87,6 +87,7 @@
     temporal_shift,
 )
 from .flash_attention import (
+    flash_attention_v3_varlen,
     flash_attn_qkvpacked,
     flash_attn_varlen_qkvpacked,
     flashmask_attention,
@@ -294,6 +295,7 @@
     'scaled_dot_product_attention',
     'flashmask_attention',
     'flash_attn_qkvpacked',
+    "flash_attention_v3_varlen",
     'flash_attn_varlen_qkvpacked',
     'group_norm',
 ]

python/paddle/nn/functional/flash_attention.py

Lines changed: 157 additions & 6 deletions
@@ -434,15 +434,15 @@ def flash_attention(
         query(Tensor): The query tensor in the Attention module.
                        4-D tensor with shape:
                        [batch_size, seq_len, num_heads, head_dim].
-                       The dtype can be float61 or bfloat16.
+                       The dtype can be float16 or bfloat16.
         key(Tensor): The key tensor in the Attention module.
                      4-D tensor with shape:
                      [batch_size, seq_len, num_heads, head_dim].
-                     The dtype can be float61 or bfloat16.
+                     The dtype can be float16 or bfloat16.
         value(Tensor): The value tensor in the Attention module.
                        4-D tensor with shape:
                        [batch_size, seq_len, num_heads, head_dim].
-                       The dtype can be float61 or bfloat16.
+                       The dtype can be float16 or bfloat16.
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
@@ -623,6 +623,157 @@ def flash_attention(
     )
 
 
+@overload
+def flash_attention_v3_varlen(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    cu_seqlens_q: Tensor,
+    cu_seqlens_k: Tensor,
+    dropout: float = ...,
+    causal: bool = ...,
+    return_softmax: Literal[False] = ...,
+    *,
+    fixed_seed_offset: Tensor | None = ...,
+    rng_name: str = ...,
+    training: bool = ...,
+    softmax_scale: float | None = ...,
+    max_seqlen_q: int = ...,
+    max_seqlen_k: int = ...,
+    name: str | None = ...,
+) -> tuple[Tensor, None]: ...
+
+
+@overload
+def flash_attention_v3_varlen(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    cu_seqlens_q: Tensor,
+    cu_seqlens_k: Tensor,
+    dropout: float = ...,
+    causal: bool = ...,
+    return_softmax: Literal[True] = ...,
+    *,
+    fixed_seed_offset: Tensor | None = ...,
+    rng_name: str = ...,
+    training: bool = ...,
+    softmax_scale: float | None = ...,
+    max_seqlen_q: int = ...,
+    max_seqlen_k: int = ...,
+    name: str | None = ...,
+) -> tuple[Tensor, Tensor]: ...
+
+
+def flash_attention_v3_varlen(
+    query,
+    key,
+    value,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    dropout=0.0,
+    causal=False,
+    return_softmax=False,
+    *,
+    fixed_seed_offset=None,
+    rng_name="",
+    training=True,
+    softmax_scale=None,
+    max_seqlen_q=0,
+    max_seqlen_k=0,
+    name=None,
+):
+    r"""
+    The equation is:
+
+    .. math::
+
+        result = softmax(\frac{Q * K^T}{\sqrt{d}}) * V
+
+    where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
+    The dimensions of the three parameters are the same.
+    ``d`` represents the size of the last dimension of the three parameters.
+    This is the varlen version of flash attention.
+
+    Warning:
+        This API only supports inputs with dtype float16 and bfloat16.
+
+    Args:
+        query(Tensor): The query tensor in the Attention module.
+                       3-D tensor with shape:
+                       [token_num, num_heads, head_dim].
+                       The dtype can be float16 or bfloat16.
+        key(Tensor): The key tensor in the Attention module.
+                     3-D tensor with shape:
+                     [token_num, num_heads, head_dim].
+                     The dtype can be float16 or bfloat16.
+        value(Tensor): The value tensor in the Attention module.
+                       3-D tensor with shape:
+                       [token_num, num_heads, head_dim].
+                       The dtype can be float16 or bfloat16.
+        cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
+                       used to index query.
+        cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,
+                       used to index key and value.
+        dropout(float): The dropout ratio.
+        causal(bool): Whether to enable causal mode.
+        return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor|None, optional): With fixed seed, offset for dropout mask.
+        rng_name(str): The name to select Generator.
+        training(bool): Whether it is in the training phase.
+        softmax_scale(float): The softmax scale of the attention.
+        max_seqlen_q(int): Maximum sequence length of query in the batch. Note it's the padding length, not the max actual seqlen.
+        max_seqlen_k(int): Maximum sequence length of key/value in the batch.
+        name(str|None, optional): The default value is None. Normally there is no need for user
+                        to set this property. For more information, please refer to
+                        :ref:`api_guide_Name`.
+
+    Returns:
+        out(Tensor): The attention tensor. 3-D tensor with shape: [token_num, num_heads, head_dim]. The dtype can be float16 or bfloat16.
+        softmax(Tensor): The softmax tensor. None if return_softmax is False.
+
+    Examples:
+        .. code-block:: python
+
+            >>> # doctest: +SKIP('flash_attn_v3 need H100 compile')
+            >>> import paddle
+
+            >>> paddle.seed(2023)
+            >>> q = paddle.rand((10, 2, 128), dtype="bfloat16")
+            >>> cu_seqlens_q = paddle.to_tensor([0, 10], dtype="int32")
+            >>> max_seq_len_q = 10
+
+            >>> output = paddle.nn.functional.flash_attention.flash_attention_v3_varlen(q, q, q, cu_seqlens_q, cu_seqlens_q, max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_q, causal=True)
+            >>> # doctest: -SKIP
+
+    """
+    if softmax_scale is None:
+        softmax_scale = query.shape[-1] ** (-0.5)
+    out, softmax_lse = _C_ops.flash_attn_v3_varlen(
+        query,
+        key,
+        value,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        None,  # q_v_
+        None,  # q_descale_
+        None,  # k_descale_
+        None,  # v_descale_
+        softmax_scale,
+        causal,
+        -1,  # window_size_left
+        -1,  # window_size_right
+        0.0,  # softcap
+        1,  # num_splits
+        False,  # manual_set_pack_gqa
+        False,  # pack_gqa_
+        0,  # sm_margin
+        max_seqlen_q,
+        max_seqlen_k,
+    )
+    return out, softmax_lse  # return_softmax
+
+
 @overload
 def flash_attn_qkvpacked(
     qkv: Tensor,
@@ -912,15 +1063,15 @@ def flash_attn_unpadded(
         query(Tensor): The query tensor in the Attention module.
                        3-D tensor with shape:
                        [total_seq_len, num_heads, head_dim].
-                       The dtype can be float61 or bfloat16.
+                       The dtype can be float16 or bfloat16.
         key(Tensor): The key tensor in the Attention module.
                      3-D tensor with shape:
                      [total_seq_len, num_heads, head_dim].
-                     The dtype can be float61 or bfloat16.
+                     The dtype can be float16 or bfloat16.
         value(Tensor): The value tensor in the Attention module.
                        3-D tensor with shape:
                        [total_seq_len, num_heads, head_dim].
-                       The dtype can be float61 or bfloat16.
+                       The dtype can be float16 or bfloat16.
         cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
                               used to index query.
         cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,