@@ -446,15 +446,15 @@ def flash_attention(
446446 query(Tensor): The query tensor in the Attention module.
447447 4-D tensor with shape:
448448 [batch_size, seq_len, num_heads, head_dim].
449- The dtype can be float16 or bfloat16.
449+ The dtype can be float16 or bfloat16.
450450 key(Tensor): The key tensor in the Attention module.
451451 4-D tensor with shape:
452452 [batch_size, seq_len, num_heads, head_dim].
453- The dtype can be float16 or bfloat16.
453+ The dtype can be float61 or bfloat16.
454454 value(Tensor): The value tensor in the Attention module.
455455 4-D tensor with shape:
456456 [batch_size, seq_len, num_heads, head_dim].
457- The dtype can be float16 or bfloat16.
457+ The dtype can be float61 or bfloat16.
458458 dropout(float): The dropout ratio.
459459 causal(bool): Whether to enable causal mode.
460460 return_softmax(bool): Whether to return softmax.
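
For orientation, here is a minimal call matching the argument list documented above. This is a sketch only, assuming a GPU build with float16 support; the module path follows the doctest examples elsewhere in this file, and `dropout=0.9` is illustrative:

    >>> # doctest: +SKIP('flash_attention needs a compatible GPU')
    >>> import paddle
    >>> paddle.seed(2023)
    >>> q = paddle.rand((1, 128, 2, 16)).astype("float16")  # [batch_size, seq_len, num_heads, head_dim]
    >>> out, _ = paddle.nn.functional.flash_attention.flash_attention(
    ...     q, q, q, dropout=0.9, causal=False, return_softmax=False
    ... )
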
@@ -635,157 +635,6 @@ def flash_attention(
635635 )
636636
637637
638- @overload
639- def flash_attention_v3_varlen(
640-     query: Tensor,
641-     key: Tensor,
642-     value: Tensor,
643-     cu_seqlens_q: Tensor,
644-     cu_seqlens_k: Tensor,
645-     dropout: float = ...,
646-     causal: bool = ...,
647-     return_softmax: Literal[False] = ...,
648-     *,
649-     fixed_seed_offset: Tensor | None = ...,
650-     rng_name: str = ...,
651-     training: bool = ...,
652-     softmax_scale: float | None = ...,
653-     max_seqlen_q: int = ...,
654-     max_seqlen_k: int = ...,
655-     name: str | None = ...,
656- ) -> tuple[Tensor, None]: ...
657-
658-
659- @overload
660- def flash_attention_v3_varlen(
661-     query: Tensor,
662-     key: Tensor,
663-     value: Tensor,
664-     cu_seqlens_q: Tensor,
665-     cu_seqlens_k: Tensor,
666-     dropout: float = ...,
667-     causal: bool = ...,
668-     return_softmax: Literal[True] = ...,
669-     *,
670-     fixed_seed_offset: Tensor | None = ...,
671-     rng_name: str = ...,
672-     training: bool = ...,
673-     softmax_scale: float | None = ...,
674-     max_seqlen_q: int = ...,
675-     max_seqlen_k: int = ...,
676-     name: str | None = ...,
677- ) -> tuple[Tensor, Tensor]: ...
678-
679-
680- def flash_attention_v3_varlen(
681-     query,
682-     key,
683-     value,
684-     cu_seqlens_q,
685-     cu_seqlens_k,
686-     dropout=0.0,
687-     causal=False,
688-     return_softmax=False,
689-     *,
690-     fixed_seed_offset=None,
691-     rng_name="",
692-     training=True,
693-     softmax_scale=None,
694-     max_seqlen_q=0,
695-     max_seqlen_k=0,
696-     name=None,
697- ):
698-     r"""
699-     The equation is:
700-
701-     .. math::
702-
703-         result = softmax(\frac{Q * K^T}{\sqrt{d}}) * V
704-
705-     where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
706-     The dimensions of the three parameters are the same.
707-     ``d`` represents the size of the last dimension of the three parameters.
708-     This is the variable-length (varlen) version of flash attention.
709-
710-     Warning:
711-         This API only supports inputs with dtype float16 and bfloat16.
712-
713-     Args:
714-         query(Tensor): The query tensor in the Attention module.
715-             3-D tensor with shape:
716-             [token_num, num_heads, head_dim].
717-             The dtype can be float16 or bfloat16.
718-         key(Tensor): The key tensor in the Attention module.
719-             3-D tensor with shape:
720-             [token_num, num_heads, head_dim].
721-             The dtype can be float16 or bfloat16.
722-         value(Tensor): The value tensor in the Attention module.
723-             3-D tensor with shape:
724-             [token_num, num_heads, head_dim].
725-             The dtype can be float16 or bfloat16.
726-         cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
727-             used to index query.
728-         cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,
729-             used to index key and value.
730-         dropout(float): The dropout ratio.
731-         causal(bool): Whether to enable causal mode.
732-         return_softmax(bool): Whether to return softmax.
733-         fixed_seed_offset(Tensor|None, optional): With fixed seed, offset for the dropout mask.
734-         rng_name(str): The name to select the Generator.
735-         training(bool): Whether it is in the training phase.
736-         softmax_scale(float|None): The softmax scale of the attention; defaults to ``head_dim ** -0.5`` when None.
737-         max_seqlen_q(int): Maximum sequence length of query in the batch. Note that this is the padded length, not the maximum actual sequence length.
738-         max_seqlen_k(int): Maximum sequence length of key/value in the batch.
739-         name(str|None, optional): The default value is None. Normally there is no need for the user
740-             to set this property. For more information, please refer to
741-             :ref:`api_guide_Name`.
742-
743-     Returns:
744-         out(Tensor): The attention tensor. 3-D tensor with shape: [token_num, num_heads, head_dim]. The dtype can be float16 or bfloat16.
745-         softmax(Tensor): The softmax tensor. None if return_softmax is False.
746-
747-     Examples:
748-         .. code-block:: python
749-
750-             >>> # doctest: +SKIP('flash_attn_v3 needs H100 compile')
751-             >>> import paddle
752-
753-             >>> paddle.seed(2023)
754-             >>> q = paddle.rand((10, 2, 128), dtype="bfloat16")
755-             >>> cu_seqlens_q = paddle.to_tensor([0, 10], dtype="int32")
756-             >>> max_seq_len_q = 10
757-
758-             >>> output = paddle.nn.functional.flash_attention.flash_attention_v3_varlen(q, q, q, cu_seqlens_q, cu_seqlens_q, max_seqlen_q=max_seq_len_q, max_seqlen_k=max_seq_len_q, causal=True)
759-             >>> # doctest: -SKIP
760-
761-     """
762-     if softmax_scale is None:
763-         softmax_scale = query.shape[-1] ** (-0.5)
764-     out, softmax_lse = _C_ops.flash_attn_v3_varlen(
765-         query,
766-         key,
767-         value,
768-         cu_seqlens_q,
769-         cu_seqlens_k,
770-         None,  # q_v_
771-         None,  # q_descale_
772-         None,  # k_descale_
773-         None,  # v_descale_
774-         softmax_scale,
775-         causal,
776-         -1,  # window_size_left
777-         -1,  # window_size_right
778-         0.0,  # softcap
779-         1,  # num_splits
780-         False,  # manual_set_pack_gqa
781-         False,  # pack_gqa_
782-         0,  # sm_margin
783-         max_seqlen_q,
784-         max_seqlen_k,
785-     )
786-     return out, softmax_lse  # return_softmax
787-
788-
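
Worth noting for readers of the removed varlen path: it flattens every sequence in the batch into one token axis and addresses each sequence through cumulative offsets. A minimal sketch of building ``cu_seqlens`` from per-sequence lengths (the ``seq_lens`` values are hypothetical; only standard Paddle ops are used):

    >>> import paddle
    >>> seq_lens = paddle.to_tensor([3, 5, 2], dtype="int32")  # lengths of three packed sequences
    >>> cu_seqlens = paddle.concat(
    ...     [paddle.zeros([1], dtype="int32"), paddle.cumsum(seq_lens)]
    ... )  # [0, 3, 8, 10]: tokens of sequence i occupy [cu_seqlens[i], cu_seqlens[i+1])
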
789638@overload
790639def flash_attn_qkvpacked(
791640    qkv: Tensor,
@@ -1075,15 +924,15 @@ def flash_attn_unpadded(
1075924 query(Tensor): The query tensor in the Attention module.
1076925 3-D tensor with shape:
1077926 [total_seq_len, num_heads, head_dim].
1078- The dtype can be float16 or bfloat16.
927+ The dtype can be float16 or bfloat16.
1079928 key(Tensor): The key tensor in the Attention module.
1080929 3-D tensor with shape:
1081930 [total_seq_len, num_heads, head_dim].
1082- The dtype can be float16 or bfloat16.
931+ The dtype can be float16 or bfloat16.
1083932 value(Tensor): The value tensor in the Attention module.
1084933 3-D tensor with shape:
1085934 [total_seq_len, num_heads, head_dim].
1086- The dtype can be float16 or bfloat16.
935+ The dtype can be float16 or bfloat16.
1087936 cu_seqlens_q(Tensor): The cumulative sequence lengths of the sequences in the batch,
1088937 used to index query.
1089938 cu_seqlens_k(Tensor): The cumulative sequence lengths of the sequences in the batch,
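
To make the unpadded calling convention concrete, here is a sketch that packs two length-128 sequences into one token axis. Hedged: the parameter names past ``cu_seqlens_k`` (``max_seqlen_q``, ``max_seqlen_k``, ``scale``) are assumed from the continuation of this docstring, and a GPU build with float16 support is required:

    >>> # doctest: +SKIP('flash_attn_unpadded needs a compatible GPU')
    >>> import paddle
    >>> q = paddle.rand((256, 8, 16)).astype("float16")  # [total_seq_len, num_heads, head_dim]
    >>> cu_seqlens = paddle.to_tensor([0, 128, 256], dtype="int32")
    >>> out, _ = paddle.nn.functional.flash_attention.flash_attn_unpadded(
    ...     q, q, q, cu_seqlens, cu_seqlens,
    ...     max_seqlen_q=128, max_seqlen_k=128, scale=16 ** -0.5,
    ... )
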