
polish
liuzhenhai93 committed Jul 9, 2023
1 parent 3362289 commit d419bc7
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions python/paddle/nn/functional/flash_attention.py
@@ -410,6 +410,53 @@ def flash_attn_unpadded(
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
):
r"""
The equation is:
.. math::
result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
The dimensions of the three parameters are the same.
``d`` represents the size of the last dimension of the three parameters.
Warning:
This API is only support inputs with dtype float16 and bfloat16.
    Args:
        query(Tensor): The query tensor in the attention module.
            4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.
        key(Tensor): The key tensor in the attention module.
            4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.
        value(Tensor): The value tensor in the attention module.
            4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.
        attn_mask(Tensor, optional): A float mask of the same dtype as query, key,
            and value that is added to the attention score. Not supported yet.
        dropout_p(float): The dropout ratio.
        is_causal(bool): Whether to enable causal mode.
    Returns:
        out(Tensor): The attention tensor.
            4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
            The dtype can be float16 or bfloat16.

    Examples:
        .. code-block:: python

            # required: skiptest
            import paddle
            q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
            output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
            print(output)
    """
    assert attn_mask is None, "attn_mask is not supported yet"
    out, _ = flash_attention(query, key, value, dropout_p, is_causal)
    return out
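
For reference, here is a minimal sketch of the softmax attention formula from the docstring, written with plain Paddle ops rather than the fused flash_attention kernel this commit wraps. The helper name naive_scaled_dot_product_attention is illustrative only, and dropout is omitted for simplicity.

import paddle
import paddle.nn.functional as F

def naive_scaled_dot_product_attention(query, key, value, is_causal=False):
    # Inputs are [batch_size, seq_len, num_heads, head_dim]; move heads next to
    # batch so the matmuls run per head: [batch_size, num_heads, seq_len, head_dim].
    q = query.transpose([0, 2, 1, 3])
    k = key.transpose([0, 2, 1, 3])
    v = value.transpose([0, 2, 1, 3])

    head_dim = q.shape[-1]
    # softmax(Q * K^T / sqrt(d))
    scores = paddle.matmul(q, k, transpose_y=True) / (head_dim ** 0.5)
    if is_causal:
        # Mask out positions after the current one with -inf before the softmax.
        seq_len = q.shape[-2]
        neg_inf = paddle.full([seq_len, seq_len], float("-inf"), dtype=scores.dtype)
        scores = scores + paddle.triu(neg_inf, diagonal=1)
    weights = F.softmax(scores, axis=-1)

    # Weighted sum over the values, then restore [batch_size, seq_len, num_heads, head_dim].
    return paddle.matmul(weights, v).transpose([0, 2, 1, 3])

# The fused kernel requires float16/bfloat16, but this naive version also runs in float32.
q = paddle.rand((1, 128, 2, 16), dtype=paddle.float32)
reference = naive_scaled_dot_product_attention(q, q, q, is_causal=False)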
