@@ -434,15 +434,15 @@ def flash_attention(
         query(Tensor): The query tensor in the Attention module.
             4-D tensor with shape:
             [batch_size, seq_len, num_heads, head_dim].
-            The dtype can be float61 or bfloat16.
+            The dtype can be float16 or bfloat16.
         key(Tensor): The key tensor in the Attention module.
             4-D tensor with shape:
             [batch_size, seq_len, num_heads, head_dim].
-            The dtype can be float61 or bfloat16.
+            The dtype can be float16 or bfloat16.
         value(Tensor): The value tensor in the Attention module.
             4-D tensor with shape:
             [batch_size, seq_len, num_heads, head_dim].
-            The dtype can be float61 or bfloat16.
+            The dtype can be float16 or bfloat16.
         dropout(float): The dropout ratio.
         causal(bool): Whether to enable causal mode.
         return_softmax(bool): Whether to return softmax.
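Note: the hunk above only corrects the documented dtypes. For context, a minimal sketch of calling the dense API with the 4-D layout documented there; the module path `paddle.nn.functional.flash_attention.flash_attention` and its `(out, softmax)` return value follow the surrounding file, and a GPU build with flash-attention support is assumed:

    import paddle

    # [batch_size, seq_len, num_heads, head_dim], float16 per the docstring above
    q = paddle.rand((2, 128, 8, 64), dtype="float16")
    k = paddle.rand((2, 128, 8, 64), dtype="float16")
    v = paddle.rand((2, 128, 8, 64), dtype="float16")

    # causal=True applies a lower-triangular attention mask;
    # softmax is None because return_softmax=False
    out, softmax = paddle.nn.functional.flash_attention.flash_attention(
        q, k, v, dropout=0.0, causal=True, return_softmax=False
    )
    print(out.shape)  # [2, 128, 8, 64]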
@@ -651,11 +651,94 @@ def flash_attention_v3_varlen(
     where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
     The dimensions of the three parameters are the same.
     ``d`` represents the size of the last dimension of the three parameters.
+    This is the varlen version of flash attention.
 
     Warning:
         This API only supports inputs with dtype float16 and bfloat16.
 
-    This is the varlen version of flash attention.
+    Args:
+        query(Tensor): The query tensor in the Attention module.
+            3-D tensor with shape:
+            [token_num, num_heads, head_dim].
+            The dtype can be float16 or bfloat16.
+        key(Tensor): The key tensor in the Attention module.
+            3-D tensor with shape:
+            [token_num, num_heads, head_dim].
+            The dtype can be float16 or bfloat16.
+        value(Tensor): The value tensor in the Attention module.
+            3-D tensor with shape:
+            [token_num, num_heads, head_dim].
+            The dtype can be float16 or bfloat16.
+        cu_seqlens_q(Tensor): The cumulative sequence lengths of the query sequences in the batch.
+            1-D tensor with shape: [batch_size + 1].
+            The dtype is int32.
+        cu_seqlens_k(Tensor): The cumulative sequence lengths of the key/value sequences in the batch.
+            1-D tensor with shape: [batch_size + 1].
+            The dtype is int32.
+        dropout(float): The dropout ratio.
+        causal(bool): Whether to enable causal mode.
+        return_softmax(bool): Whether to return softmax.
+        fixed_seed_offset(Tensor|None, optional): With fixed seed, offset for dropout mask.
+        rng_name(str): The name to select Generator.
+        training(bool): Whether it is in the training phase.
+        name(str|None, optional): The default value is None. Normally there is no need for user
+            to set this property. For more information, please refer to
+            :ref:`api_guide_Name`.
+        softmax_scale(float): The softmax scale of the attention.
+        max_seqlen_q(int): The maximum sequence length of query.
+        max_seqlen_k(int): The maximum sequence length of key/value.
+
+    Returns:
+        out(Tensor): The attention tensor.
+            3-D tensor with shape: [token_num, num_heads, head_dim].
+            The dtype can be float16 or bfloat16.
+        softmax(Tensor): The softmax tensor. None if return_softmax is False.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+
+            >>> paddle.seed(2023)
+            >>> q = paddle.rand((10, 2, 128), dtype="bfloat16")
+            >>> cu_seqlens_q = paddle.to_tensor([0, 10], dtype="int32")
+            >>> max_seq_len_q = 10
+
+            >>> output = paddle.nn.functional.flash_attention.flash_attention_v3_varlen(
+            ...     q, q, q, cu_seqlens_q, cu_seqlens_q,
+            ...     max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_q, causal=True)
+            >>> print(output)
+            (Tensor(shape=[10, 2, 128], dtype=bfloat16, place=Place(gpu:0), stop_gradient=True,
+            [[[0.71875000, 0.47265625, 0.15722656, ..., 0.01062012,
+               0.27148438, 0.68750000],
+              [0.46289062, 0.57421875, 0.94921875, ..., 0.26171875,
+               0.91015625, 0.61718750]],
+
+             [[0.55078125, 0.20898438, 0.69921875, ..., 0.06298828,
+               0.26367188, 0.32031250],
+              [0.27148438, 0.75781250, 0.26367188, ..., 0.37890625,
+               0.83984375, 0.74609375]],
+
+             [[0.42968750, 0.23144531, 0.51562500, ..., 0.33007812,
+               0.51562500, 0.44531250],
+              [0.46093750, 0.85156250, 0.51953125, ..., 0.64843750,
+               0.82812500, 0.62890625]],
+
+             ...,
+
+             [[0.36132812, 0.61718750, 0.53906250, ..., 0.45312500,
+               0.41015625, 0.52343750],
+              [0.57421875, 0.70703125, 0.44531250, ..., 0.38867188,
+               0.68359375, 0.41015625]],
+
+             [[0.37304688, 0.68359375, 0.59375000, ..., 0.56640625,
+               0.36718750, 0.45898438],
+              [0.37695312, 0.64453125, 0.51171875, ..., 0.53906250,
+               0.75390625, 0.35546875]],
+
+             [[0.46484375, 0.54296875, 0.47656250, ..., 0.51171875,
+               0.31640625, 0.50781250],
+              [0.52734375, 0.58984375, 0.53515625, ..., 0.60156250,
+               0.74218750, 0.32617188]]]), None)
+
     """
     if softmax_scale is None:
         softmax_scale = query.shape[-1] ** (-0.5)
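Note: the `cu_seqlens_q`/`cu_seqlens_k` tensors documented above are prefix sums of the per-sequence token counts with a leading zero, which is how the packed [token_num, num_heads, head_dim] layout is indexed. A minimal sketch, assuming a hypothetical batch with sequence lengths [3, 5, 2]:

    import itertools

    import paddle

    seq_lens = [3, 5, 2]  # hypothetical per-sequence token counts

    # Prefix sum with a leading 0: [0, 3, 8, 10]. Token row i of the packed
    # q/k/v tensors belongs to sequence j iff cu_seqlens[j] <= i < cu_seqlens[j + 1].
    cu_seqlens = paddle.to_tensor(
        [0] + list(itertools.accumulate(seq_lens)), dtype="int32"
    )

    max_seqlen = max(seq_lens)  # passed as max_seqlen_q / max_seqlen_k
    token_num = sum(seq_lens)   # first dim of the packed tensors, 10 here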
@@ -973,15 +1056,15 @@ def flash_attn_unpadded(
         query(Tensor): The query tensor in the Attention module.
             3-D tensor with shape:
             [total_seq_len, num_heads, head_dim].
-            The dtype can be float61 or bfloat16.
+            The dtype can be float16 or bfloat16.
         key(Tensor): The key tensor in the Attention module.
             3-D tensor with shape:
             [total_seq_len, num_heads, head_dim].
-            The dtype can be float61 or bfloat16.
+            The dtype can be float16 or bfloat16.
         value(Tensor): The value tensor in the Attention module.
             3-D tensor with shape:
             [total_seq_len, num_heads, head_dim].
-            The dtype can be float61 or bfloat16.
+            The dtype can be float16 or bfloat16.
         cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
             used to index query.
         cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,
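For the unpadded API in this last hunk, a minimal call sketch; the keyword names mirror the docstring above, the exact signature of `paddle.nn.functional.flash_attention.flash_attn_unpadded` is assumed from this diff's context, and `scale` is set to the same `head_dim ** -0.5` value used as the `softmax_scale` default earlier in the file:

    import paddle

    # Two packed sequences of lengths 4 and 6 (total_seq_len = 10), 2 heads, head_dim 128
    q = paddle.rand((10, 2, 128), dtype="float16")
    cu_seqlens = paddle.to_tensor([0, 4, 10], dtype="int32")

    out, softmax = paddle.nn.functional.flash_attention.flash_attn_unpadded(
        q, q, q, cu_seqlens, cu_seqlens,
        max_seqlen_q=6, max_seqlen_k=6,
        scale=128 ** -0.5,  # 1 / sqrt(head_dim), matching the softmax_scale default
        causal=True,
    )
    print(out.shape)  # [10, 2, 128]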