diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 277c633c8f3d3..6d6f85655268c 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -410,6 +410,53 @@ def flash_attn_unpadded(
 def scaled_dot_product_attention(
     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
 ):
+    r"""
+    The equation is:
+
+    .. math::
+
+        result = softmax(\frac{Q K^T}{\sqrt{d}}) V
+
+    where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
+    The dimensions of the three parameters are the same.
+    ``d`` represents the size of the last dimension of the three parameters.
+
+    Warning:
+        This API only supports inputs with dtype float16 and bfloat16.
+
+    Args:
+        query(Tensor): The query tensor in the Attention module.
+                        4-D tensor with shape:
+                        [batch_size, seq_len, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        key(Tensor): The key tensor in the Attention module.
+                        4-D tensor with shape:
+                        [batch_size, seq_len, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        value(Tensor): The value tensor in the Attention module.
+                        4-D tensor with shape:
+                        [batch_size, seq_len, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        attn_mask(Tensor, optional): A float mask of the same dtype as query,
+                        key, and value that is added to the attention score.
+                        Not supported yet.
+        dropout_p(float): The dropout ratio.
+        is_causal(bool): Whether to enable causal mode.
+
+    Returns:
+        out(Tensor): The attention tensor.
+                    4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
+                    The dtype can be float16 or bfloat16.
+
+    Examples:
+        .. code-block:: python
+
+            # required: skiptest
+            import paddle
+            q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
+            output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
+            print(output)
+    """
     assert attn_mask is None, "attn_mask is not supported yet"
     out, _ = flash_attention(query, key, value, dropout_p, is_causal)
     return out
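
For reference, a minimal sketch of the documented equation using plain paddle ops, not part of this change: it recomputes softmax(Q K^T / sqrt(d)) V on float32 CPU tensors and omits dropout and causal masking, whereas the fused API above requires float16/bfloat16 inputs. The helper name `naive_sdpa` is illustrative only and not part of the Paddle API.

    import paddle

    def naive_sdpa(q, k, v):
        # q, k, v: [batch_size, seq_len, num_heads, head_dim]
        d = q.shape[-1]
        # Put heads before the sequence axis so matmul contracts over head_dim.
        q, k, v = (paddle.transpose(t, [0, 2, 1, 3]) for t in (q, k, v))
        scores = paddle.matmul(q, k, transpose_y=True) / (d ** 0.5)
        weights = paddle.nn.functional.softmax(scores, axis=-1)
        out = paddle.matmul(weights, v)
        # Restore [batch_size, seq_len, num_heads, head_dim].
        return paddle.transpose(out, [0, 2, 1, 3])

    q = paddle.rand((1, 128, 2, 16), dtype=paddle.float32)
    print(naive_sdpa(q, q, q).shape)  # [1, 128, 2, 16]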