Commit c2f5511 (1 parent: f980cfa)

supports fa3_varlen api

File tree: 1 file changed (+90, −7 lines)

python/paddle/nn/functional/flash_attention.py

Lines changed: 90 additions & 7 deletions
@@ -434,15 +434,15 @@ def flash_attention(
         query(Tensor): The query tensor in the Attention module.
                         4-D tensor with shape:
                         [batch_size, seq_len, num_heads, head_dim].
-                        The dtype can be float61 or bfloat16.
+                        The dtype can be float16 or bfloat16.
         key(Tensor): The key tensor in the Attention module.
                         4-D tensor with shape:
                         [batch_size, seq_len, num_heads, head_dim].
-                        The dtype can be float61 or bfloat16.
+                        The dtype can be float16 or bfloat16.
         value(Tensor): The value tensor in the Attention module.
                         4-D tensor with shape:
                         [batch_size, seq_len, num_heads, head_dim].
-                        The dtype can be float61 or bfloat16.
+                        The dtype can be float16 or bfloat16.
         dropout(float): The dropout ratio.
         causal(bool): Whether enable causal mode.
         return_softmax(bool): Whether to return softmax.
@@ -651,11 +651,94 @@ def flash_attention_v3_varlen(
     where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
     The dimensions of the three parameters are the same.
     ``d`` represents the size of the last dimension of the three parameters.
+    This is the varlen version of flash attention.

     Warning:
         This API is only support inputs with dtype float16 and bfloat16.

-    This is the varlen version of flash attention.
+    Args:
+        query(Tensor): The query tensor in the Attention module.
+                        3-D tensor with shape:
+                        [token_num, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        key(Tensor): The key tensor in the Attention module.
+                        3-D tensor with shape:
+                        [token_num, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        value(Tensor): The value tensor in the Attention module.
+                        3-D tensor with shape:
+                        [token_num, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        cu_seqlens_q(Tensor): The cumulative sequence lengths of the query sequences in the batch.
+                        1-D tensor with shape: [batch_size + 1].
+                        The dtype is int32.
+        cu_seqlens_k(Tensor): The cumulative sequence lengths of the key/value sequences in the batch.
+                        1-D tensor with shape: [batch_size + 1].
+                        The dtype is int32.
+        dropout(float): The dropout ratio.
+        causal(bool): Whether to enable causal mode.
+        return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor|None, optional): With fixed seed, offset for dropout mask.
+        rng_name(str): The name to select Generator.
+        training(bool): Whether it is in the training phase.
+        name(str|None, optional): The default value is None. Normally there is no need for user
+                        to set this property. For more information, please refer to
+                        :ref:`api_guide_Name`.
+        softmax_scale(float): The softmax scale of the attention.
+        max_seqlen_q(int): The max seq len of query.
+        max_seqlen_k(int): The max seq len of key/value.
+
+    Returns:
+        out(Tensor): The attention tensor.
+                        3-D tensor with shape: [token_num, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        softmax(Tensor): The softmax tensor. None if return_softmax is False.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+
+            >>> paddle.seed(2023)
+            >>> q = paddle.rand((10, 2, 128), dtype="bfloat16")
+            >>> cu_seqlens_q = paddle.to_tensor([0, 10], dtype="int32")
+            >>> max_seq_len_q = 10
+
+            >>> output = paddle.nn.functional.flash_attention.flash_attention_v3_varlen(q, q, q, cu_seqlens_q, cu_seqlens_q, max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_q, causal=True)
+            >>> print(output)
+            (Tensor(shape=[10, 2, 128], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True,
+            [[[0.71875000, 0.47265625, 0.15722656, ..., 0.01062012,
+               0.27148438, 0.68750000],
+              [0.46289062, 0.57421875, 0.94921875, ..., 0.26171875,
+               0.91015625, 0.61718750]],
+
+             [[0.55078125, 0.20898438, 0.69921875, ..., 0.06298828,
+               0.26367188, 0.32031250],
+              [0.27148438, 0.75781250, 0.26367188, ..., 0.37890625,
+               0.83984375, 0.74609375]],
+
+             [[0.42968750, 0.23144531, 0.51562500, ..., 0.33007812,
+               0.51562500, 0.44531250],
+              [0.46093750, 0.85156250, 0.51953125, ..., 0.64843750,
+               0.82812500, 0.62890625]],
+
+             ...,
+
+             [[0.36132812, 0.61718750, 0.53906250, ..., 0.45312500,
+               0.41015625, 0.52343750],
+              [0.57421875, 0.70703125, 0.44531250, ..., 0.38867188,
+               0.68359375, 0.41015625]],
+
+             [[0.37304688, 0.68359375, 0.59375000, ..., 0.56640625,
+               0.36718750, 0.45898438],
+              [0.37695312, 0.64453125, 0.51171875, ..., 0.53906250,
+               0.75390625, 0.35546875]],
+
+             [[0.46484375, 0.54296875, 0.47656250, ..., 0.51171875,
+               0.31640625, 0.50781250],
+              [0.52734375, 0.58984375, 0.53515625, ..., 0.60156250,
+               0.74218750, 0.32617188]]]), None)
     """
     if softmax_scale is None:
         softmax_scale = query.shape[-1] ** (-0.5)
@@ -973,15 +1056,15 @@ def flash_attn_unpadded(
         query(Tensor): The query tensor in the Attention module.
                         3-D tensor with shape:
                         [total_seq_len, num_heads, head_dim].
-                        The dtype can be float61 or bfloat16.
+                        The dtype can be float16 or bfloat16.
         key(Tensor): The key tensor in the Attention module.
                         3-D tensor with shape:
                         [total_seq_len, num_heads, head_dim].
-                        The dtype can be float61 or bfloat16.
+                        The dtype can be float16 or bfloat16.
         value(Tensor): The value tensor in the Attention module.
                         3-D tensor with shape:
                         [total_seq_len, num_heads, head_dim].
-                        The dtype can be float61 or bfloat16.
+                        The dtype can be float16 or bfloat16.
         cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
                         used to index query.
         cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,