Commit b81a1d2

Revert "supports fa3_varlen api (#72805)"
This reverts commit c0032d7.
1 parent 3cd2b19 commit b81a1d2

6 files changed: +19 -351 lines changed


paddle/phi/kernels/gpu/flash_attn_v3_kernel.cu

Lines changed: 0 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,127 +1071,6 @@ void FlashAttnV3Kernel(const Context &dev_ctx,
10711071
#endif
10721072
}
10731073

1074-
template <typename T, typename Context>
1075-
void FlashAttnV3VarLenKernel(const Context &dev_ctx,
1076-
const DenseTensor &q,
1077-
const DenseTensor &k,
1078-
const DenseTensor &v,
1079-
const DenseTensor &cu_seqlens_q,
1080-
const DenseTensor &cu_seqlens_k,
1081-
const paddle::optional<DenseTensor> &q_v_,
1082-
const paddle::optional<DenseTensor> &q_descale_,
1083-
const paddle::optional<DenseTensor> &k_descale_,
1084-
const paddle::optional<DenseTensor> &v_descale_,
1085-
const float softmax_scale,
1086-
bool is_causal,
1087-
int window_size_left,
1088-
int window_size_right,
1089-
const float softcap,
1090-
int num_splits,
1091-
const bool manual_set_pack_gqa,
1092-
const bool pack_gqa_,
1093-
const int sm_margin,
1094-
const int max_seqlen_q,
1095-
const int max_seqlen_k,
1096-
DenseTensor *out,
1097-
DenseTensor *softmax_lse) {
1098-
#ifdef PADDLE_WITH_FLASHATTN_V3
1099-
// umiswing: the following options have not been fully tested
1100-
PADDLE_ENFORCE_EQ(q_v_.is_initialized(),
1101-
false,
1102-
common::errors::InvalidArgument("q_v_ is not supported"));
1103-
PADDLE_ENFORCE_EQ(
1104-
q_descale_.is_initialized(),
1105-
false,
1106-
common::errors::InvalidArgument("q_descale_ is not supported"));
1107-
PADDLE_ENFORCE_EQ(
1108-
k_descale_.is_initialized(),
1109-
false,
1110-
common::errors::InvalidArgument("k_descale_ is not supported"));
1111-
PADDLE_ENFORCE_EQ(
1112-
v_descale_.is_initialized(),
1113-
false,
1114-
common::errors::InvalidArgument("v_descale_ is not supported"));
1115-
PADDLE_ENFORCE_EQ(
1116-
window_size_left,
1117-
-1,
1118-
common::errors::InvalidArgument("window_size is not supported, please "
1119-
"set window_size_left/right to -1"));
1120-
PADDLE_ENFORCE_EQ(
1121-
window_size_right,
1122-
-1,
1123-
common::errors::InvalidArgument("window_size is not supported, please "
1124-
"set window_size_left/right to -1"));
1125-
PADDLE_ENFORCE_EQ(softcap,
1126-
0,
1127-
common::errors::InvalidArgument(
1128-
"softcap is not supported, please set softcap to 0"));
1129-
PADDLE_ENFORCE_EQ(
1130-
num_splits,
1131-
1,
1132-
common::errors::InvalidArgument(
1133-
"num_splits is not supported, please set num_splits to 1"));
1134-
PADDLE_ENFORCE_EQ(manual_set_pack_gqa,
1135-
false,
1136-
common::errors::InvalidArgument(
1137-
"manual_set_pack_gqa is not supported, please set "
1138-
"manual_set_pack_gqa to false"));
1139-
PADDLE_ENFORCE_EQ(
1140-
pack_gqa_,
1141-
false,
1142-
common::errors::InvalidArgument(
1143-
"pack_gqa_ is not supported, please set pack_gqa_ to false"));
1144-
PADDLE_ENFORCE_EQ(
1145-
sm_margin,
1146-
0,
1147-
common::errors::InvalidArgument(
1148-
"sm_margin is not supported, please set sm_margin to 0"));
1149-
1150-
DenseTensor out_accum;
1151-
DenseTensor softmax_lse_accum;
1152-
FlashAttnV3BaseKernel<T, Context>(dev_ctx,
1153-
q,
1154-
k,
1155-
v,
1156-
paddle::none, // k_new_
1157-
paddle::none, // v_new_
1158-
q_v_,
1159-
paddle::none, // out_
1160-
cu_seqlens_q, // cu_seqlens_q_
1161-
cu_seqlens_k, // cu_seqlens_k_
1162-
paddle::none, // cu_seqlens_k_new_
1163-
paddle::none, // seqused_q_
1164-
paddle::none, // seqused_k_
1165-
paddle::none, // page_table_
1166-
paddle::none, // kv_batch_idx_
1167-
paddle::none, // leftpad_k_
1168-
paddle::none, // rotary_cos_
1169-
paddle::none, // rotary_sin_
1170-
q_descale_,
1171-
k_descale_,
1172-
v_descale_,
1173-
paddle::none, // scheduler_metadata
1174-
max_seqlen_q, // max_seqlen_q_
1175-
max_seqlen_k, // max_seqlen_k_
1176-
softmax_scale,
1177-
is_causal,
1178-
window_size_left,
1179-
window_size_right,
1180-
softcap,
1181-
true, // is_rotary_interleaved
1182-
num_splits,
1183-
manual_set_pack_gqa,
1184-
pack_gqa_,
1185-
sm_margin,
1186-
out,
1187-
softmax_lse,
1188-
&out_accum,
1189-
&softmax_lse_accum);
1190-
#else
1191-
RaiseNotSupportedError();
1192-
#endif
1193-
}
1194-
11951074
} // namespace phi
11961075

11971076
PD_REGISTER_KERNEL(flash_attn_v3,
@@ -1200,10 +1079,3 @@ PD_REGISTER_KERNEL(flash_attn_v3,
12001079
phi::FlashAttnV3Kernel,
12011080
phi::dtype::float16,
12021081
phi::dtype::bfloat16) {}
1203-
1204-
PD_REGISTER_KERNEL(flash_attn_v3_varlen,
1205-
GPU,
1206-
ALL_LAYOUT,
1207-
phi::FlashAttnV3VarLenKernel,
1208-
phi::dtype::float16,
1209-
phi::dtype::bfloat16) {}

paddle/phi/kernels/gpu/flash_attn_v3_kernel.h

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,29 +34,4 @@ void FlashAttnV3Kernel(const Context &ctx,
3434
const int sm_margin,
3535
DenseTensor *out,
3636
DenseTensor *softmax_lse);
37-
38-
template <typename T, typename Context>
39-
void FlashAttnV3VarLenKernel(const Context &ctx,
40-
const DenseTensor &q,
41-
const DenseTensor &k,
42-
const DenseTensor &v,
43-
const DenseTensor &cu_seqlens_q,
44-
const DenseTensor &cu_seqlens_k,
45-
const paddle::optional<DenseTensor> &q_v_,
46-
const paddle::optional<DenseTensor> &q_descale_,
47-
const paddle::optional<DenseTensor> &k_descale_,
48-
const paddle::optional<DenseTensor> &v_descale_,
49-
const float softmax_scale,
50-
bool is_causal,
51-
int window_size_left,
52-
int window_size_right,
53-
const float softcap,
54-
int num_splits,
55-
const bool manual_set_pack_gqa,
56-
const bool pack_gqa_,
57-
const int sm_margin,
58-
const int max_seqlen_q,
59-
const int max_seqlen_k,
60-
DenseTensor *out,
61-
DenseTensor *softmax_lse);
6237
} // namespace phi

paddle/phi/ops/yaml/ops.yaml

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2088,17 +2088,6 @@
20882088
data_type : q
20892089
backward : flash_attn_v3_grad
20902090

2091-
- op : flash_attn_v3_varlen
2092-
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor q_v_, Tensor q_descale_, Tensor k_descale_, Tensor v_descale_, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, float softcap, int num_splits, bool manual_set_pack_gqa, bool pack_gqa_, int sm_margin, int max_seqlen_q, int max_seqlen_k)
2093-
output : Tensor(out), Tensor(softmax_lse)
2094-
optional : q_v_, q_descale_, k_descale_, v_descale_
2095-
infer_meta :
2096-
func : FlashAttnV3InferMeta
2097-
param : [q, k, v]
2098-
kernel :
2099-
func : flash_attn_v3_varlen
2100-
data_type : q
2101-
21022091
- op : flash_attn_varlen_qkvpacked
21032092
args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, Scalar max_seqlen_q, Scalar max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true)
21042093
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)

python/paddle/nn/functional/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@
8787
temporal_shift,
8888
)
8989
from .flash_attention import (
90-
flash_attention_v3_varlen,
9190
flash_attn_qkvpacked,
9291
flash_attn_varlen_qkvpacked,
9392
flashmask_attention,
@@ -297,7 +296,6 @@
297296
'scaled_dot_product_attention',
298297
'flashmask_attention',
299298
'flash_attn_qkvpacked',
300-
"flash_attention_v3_varlen",
301299
'flash_attn_varlen_qkvpacked',
302300
'group_norm',
303301
'moe_permute',
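
Note: with the import and __all__ entry above removed, flash_attention_v3_varlen is no longer re-exported from paddle.nn.functional. A minimal check, as a sketch (assumes a Paddle build that includes this revert):

    import paddle.nn.functional as F

    # The varlen FA3 wrapper was only reachable through this re-export; once the
    # import and __all__ entry above are gone, so is the public attribute.
    print(hasattr(F, "flash_attention_v3_varlen"))  # expected: False after the revert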

python/paddle/nn/functional/flash_attention.py

Lines changed: 6 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -446,15 +446,15 @@ def flash_attention(
446446
query(Tensor): The query tensor in the Attention module.
447447
4-D tensor with shape:
448448
[batch_size, seq_len, num_heads, head_dim].
449-
The dtype can be float16 or bfloat16.
449+
The dtype can be float61 or bfloat16.
450450
key(Tensor): The key tensor in the Attention module.
451451
4-D tensor with shape:
452452
[batch_size, seq_len, num_heads, head_dim].
453-
The dtype can be float16 or bfloat16.
453+
The dtype can be float61 or bfloat16.
454454
value(Tensor): The value tensor in the Attention module.
455455
4-D tensor with shape:
456456
[batch_size, seq_len, num_heads, head_dim].
457-
The dtype can be float16 or bfloat16.
457+
The dtype can be float61 or bfloat16.
458458
dropout(float): The dropout ratio.
459459
causal(bool): Whether enable causal mode.
460460
return_softmax(bool): Whether to return softmax.
@@ -635,157 +635,6 @@ def flash_attention(
635635
)
636636

637637

638-
@overload
639-
def flash_attention_v3_varlen(
640-
query: Tensor,
641-
key: Tensor,
642-
value: Tensor,
643-
cu_seqlens_q: Tensor,
644-
cu_seqlens_k: Tensor,
645-
dropout: float = ...,
646-
causal: bool = ...,
647-
return_softmax: Literal[False] = ...,
648-
*,
649-
fixed_seed_offset: Tensor | None = ...,
650-
rng_name: str = ...,
651-
training: bool = ...,
652-
softmax_scale: float | None = ...,
653-
max_seqlen_q: int = ...,
654-
max_seqlen_k: int = ...,
655-
name: str | None = ...,
656-
) -> tuple[Tensor, None]: ...
657-
658-
659-
@overload
660-
def flash_attention_v3_varlen(
661-
query: Tensor,
662-
key: Tensor,
663-
value: Tensor,
664-
cu_seqlens_q: Tensor,
665-
cu_seqlens_k: Tensor,
666-
dropout: float = ...,
667-
causal: bool = ...,
668-
return_softmax: Literal[True] = ...,
669-
*,
670-
fixed_seed_offset: Tensor | None = ...,
671-
rng_name: str = ...,
672-
training: bool = ...,
673-
softmax_scale: float | None = ...,
674-
max_seqlen_q: int = ...,
675-
max_seqlen_k: int = ...,
676-
name: str | None = ...,
677-
) -> tuple[Tensor, Tensor]: ...
678-
679-
680-
def flash_attention_v3_varlen(
681-
query,
682-
key,
683-
value,
684-
cu_seqlens_q,
685-
cu_seqlens_k,
686-
dropout=0.0,
687-
causal=False,
688-
return_softmax=False,
689-
*,
690-
fixed_seed_offset=None,
691-
rng_name="",
692-
training=True,
693-
softmax_scale=None,
694-
max_seqlen_q=0,
695-
max_seqlen_k=0,
696-
name=None,
697-
):
698-
r"""
699-
The equation is:
700-
701-
.. math::
702-
703-
result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
704-
705-
where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
706-
The dimensions of the three parameters are the same.
707-
``d`` represents the size of the last dimension of the three parameters.
708-
This is the varlen version of flash attention.
709-
710-
Warning:
711-
This API is only support inputs with dtype float16 and bfloat16.
712-
713-
Args:
714-
query(Tensor): The query tensor in the Attention module.
715-
3-D tensor with shape:
716-
[token_num, num_heads, head_dim].
717-
The dtype can be float16 or bfloat16.
718-
key(Tensor): The key tensor in the Attention module.
719-
3-D tensor with shape:
720-
[token_num, num_heads, head_dim].
721-
The dtype can be float16 or bfloat16.
722-
value(Tensor): The value tensor in the Attention module.
723-
3-D tensor with shape:
724-
[token_num, num_heads, head_dim].
725-
The dtype can be float16 or bfloat16.
726-
cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
727-
used to index query.
728-
cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,
729-
used to index key and value.
730-
dropout(float): The dropout ratio.
731-
causal(bool): Whether enable causal mode.
732-
return_softmax(bool): Whether to return softmax.
733-
fixed_seed_offset(Tensor|None, optional): With fixed seed, offset for dropout mask.
734-
rng_name(str): The name to select Generator.
735-
training(bool): Whether it is in the training phase.
736-
softmax_scale(float): The softmax scale of the attention.
737-
max_seqlen_q(int): Maximum sequence length of query in the batch. Note it's the padding length, not the max actual seqlen.
738-
max_seqlen_k(int): Maximum sequence length of key/value in the batch.
739-
name(str|None, optional): The default value is None. Normally there is no need for user
740-
to set this property. For more information, please refer to
741-
:ref:`api_guide_Name`.
742-
743-
Returns:
744-
out(Tensor): The attention tensor. 3-D tensor with shape: [token_num, num_heads, head_dim]. The dtype can be float16 or bfloat16.
745-
softmax(Tensor): The softmax tensor. None if return_softmax is False.
746-
747-
Examples:
748-
.. code-block:: python
749-
750-
>>> # doctest: +SKIP('flash_attn_v3 need H100 compile')
751-
>>> import paddle
752-
753-
>>> paddle.seed(2023)
754-
>>> q = paddle.rand((10, 2, 128), dtype="bfloat16")
755-
>>> cu_seqlens_q = paddle.to_tensor([0, 10], dtype="int32")
756-
>>> max_seq_len_q = 10
757-
758-
>>> output = paddle.nn.functional.flash_attention.flash_attention_v3_varlen(q, q, q, cu_seqlens_q, cu_seqlens_q, max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_q, causal=True)
759-
>>> # doctest: -SKIP
760-
761-
"""
762-
if softmax_scale is None:
763-
softmax_scale = query.shape[-1] ** (-0.5)
764-
out, softmax_lse = _C_ops.flash_attn_v3_varlen(
765-
query,
766-
key,
767-
value,
768-
cu_seqlens_q,
769-
cu_seqlens_k,
770-
None, # q_v_
771-
None, # q_descale_
772-
None, # k_descale_
773-
None, # v_descale_
774-
softmax_scale,
775-
causal,
776-
-1, # window_size_left
777-
-1, # window_size_right
778-
0.0, # softcap
779-
1, # num_splits
780-
False, # manual_set_pack_gqa
781-
False, # pack_gqa_
782-
0, # sm_margin,
783-
max_seqlen_q,
784-
max_seqlen_k,
785-
)
786-
return out, softmax_lse # return_softmax
787-
788-
789638
@overload
790639
def flash_attn_qkvpacked(
791640
qkv: Tensor,
@@ -1075,15 +924,15 @@ def flash_attn_unpadded(
1075924
query(Tensor): The query tensor in the Attention module.
1076925
3-D tensor with shape:
1077926
[total_seq_len, num_heads, head_dim].
1078-
The dtype can be float16 or bfloat16.
927+
The dtype can be float61 or bfloat16.
1079928
key(Tensor): The key tensor in the Attention module.
1080929
3-D tensor with shape:
1081930
[total_seq_len, num_heads, head_dim].
1082-
The dtype can be float16 or bfloat16.
931+
The dtype can be float61 or bfloat16.
1083932
value(Tensor): The value tensor in the Attention module.
1084933
3-D tensor with shape:
1085934
[total_seq_len, num_heads, head_dim].
1086-
The dtype can be float16 or bfloat16.
935+
The dtype can be float61 or bfloat16.
1087936
cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
1088937
used to index query.
1089938
cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,
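
Note: after this revert, variable-length (unpadded) attention in Paddle still goes through the existing FA2 path, e.g. flash_attn_unpadded, whose docstring is touched in the last hunk above. A minimal sketch adapted from the removed example; the argument order after the cu_seqlens tensors (max_seqlen_q, max_seqlen_k, scale) is an assumption, since the full signature is not shown in this diff, and it needs a GPU build with FlashAttention enabled:

    import paddle
    from paddle.nn.functional.flash_attention import flash_attn_unpadded

    paddle.seed(2023)
    # Unpadded layout: [total_seq_len, num_heads, head_dim], float16 or bfloat16.
    q = paddle.rand((10, 2, 128), dtype="bfloat16")
    cu_seqlens = paddle.to_tensor([0, 10], dtype="int32")  # one sequence of length 10
    max_seqlen = 10

    out, _ = flash_attn_unpadded(
        q, q, q,                  # query, key, value (self-attention here)
        cu_seqlens, cu_seqlens,   # cumulative sequence lengths for q and k/v
        max_seqlen, max_seqlen,   # maximum (padded) sequence lengths
        q.shape[-1] ** -0.5,      # softmax scale, 1/sqrt(head_dim)
        causal=True,
    )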

0 commit comments
