diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index d4514479ca343..b3e17f5fbd34a 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -261,4 +261,5 @@
     'multi_margin_loss',
     'soft_margin_loss',
     'gaussian_nll_loss',
+    'scaled_dot_product_attention',
 ]
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 178e3ebc90fa7..b36bd5d74ec7b 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -407,4 +407,57 @@ def flash_attn_unpadded(
     return out, softmax if return_softmax else None


-scaled_dot_product_attention = flash_attention
+def scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+):
+    r"""
+    The equation is:
+
+    .. math::
+
+        result = softmax(\frac{Q * K^T}{\sqrt{d}}) * V
+
+    where ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
+    The dimensions of the three parameters are the same.
+    ``d`` represents the size of the last dimension of the three parameters.
+
+    Warning:
+        This API only supports inputs with dtype float16 and bfloat16.
+
+    Args:
+        query(Tensor): The query tensor in the Attention module.
+            4-D tensor with shape:
+            [batch_size, seq_len, num_heads, head_dim].
+            The dtype can be float16 or bfloat16.
+        key(Tensor): The key tensor in the Attention module.
+            4-D tensor with shape:
+            [batch_size, seq_len, num_heads, head_dim].
+            The dtype can be float16 or bfloat16.
+        value(Tensor): The value tensor in the Attention module.
+            4-D tensor with shape:
+            [batch_size, seq_len, num_heads, head_dim].
+            The dtype can be float16 or bfloat16.
+        attn_mask(Tensor, optional): A float mask of the same type as query,
+            key, value that is added to the attention score.
+            Not supported yet.
+        dropout_p(float): The dropout ratio.
+        is_causal(bool): Whether to enable causal mode.
+
+    Returns:
+        out(Tensor): The attention tensor.
+            4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
+            The dtype can be float16 or bfloat16.
+
+    Examples:
+        .. code-block:: python
+
+            # required: skiptest
+            >>> # xdoctest: +SKIP()
+            >>> import paddle
+            >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
+            >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
+            >>> print(output)
+            >>> # xdoctest: -SKIP
+    """
+    assert attn_mask is None, "attn_mask is not supported yet"
+    out, _ = flash_attention(query, key, value, dropout_p, is_causal)
+    return out
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index da5febf4de3e6..6bde691bd2f95 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -25,6 +25,7 @@
 from paddle.nn.functional.flash_attention import (
     flash_attention,
     flash_attn_unpadded,
+    scaled_dot_product_attention,
 )


@@ -85,6 +86,7 @@ def setUp(self):
         self.causal = False
         self.return_softmax = False
         self.use_sdp_kernel = False
+        self.use_sdp_api = False

     def test_unpadded(self):
         print(
@@ -212,9 +214,15 @@ def test_all(self):
                 enable_flash=self.enable_flash,
                 enable_mem_efficient=self.enable_mem_efficient,
             ):
-                out, _ = flash_attention(
-                    q, k, v, self.dropout, self.causal, self.return_softmax
-                )
+                if self.use_sdp_api:
+                    out = scaled_dot_product_attention(
+                        q, k, v, None, self.dropout, self.causal
+                    )
+                else:
+                    out, _ = flash_attention(
+                        q, k, v, self.dropout, self.causal, self.return_softmax
+                    )
+
         else:
             out, _ = flash_attention(
                 q, k, v, self.dropout, self.causal, self.return_softmax
@@ -253,14 +261,19 @@ def test_all(self):
                 enable_flash=self.enable_flash,
                 enable_mem_efficient=self.enable_mem_efficient,
             ):
-                outs, softmax = flash_attention(
-                    qs,
-                    ks,
-                    vs,
-                    self.dropout,
-                    self.causal,
-                    self.return_softmax,
-                )
+                if self.use_sdp_api:
+                    outs = scaled_dot_product_attention(
+                        qs, ks, vs, None, self.dropout, self.causal
+                    )
+                else:
+                    outs, softmax = flash_attention(
+                        qs,
+                        ks,
+                        vs,
+                        self.dropout,
+                        self.causal,
+                        self.return_softmax,
+                    )
         else:
             outs, softmax = flash_attention(
                 qs, ks, vs, self.dropout, self.causal, self.return_softmax
@@ -334,6 +347,22 @@ def setUp(self):
         self.causal = False
         self.return_softmax = False
         self.use_sdp_kernel = True
+        self.use_sdp_api = False
+        self.enable_math = True
+        self.enable_flash = False
+        self.enable_mem_efficient = False
+
+
+class TestSDPAttentionAPITest(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 128)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = True
+        self.use_sdp_api = True
         self.enable_math = True
         self.enable_flash = False
         self.enable_mem_efficient = False
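
For reviewers, a minimal usage sketch of the new scaled_dot_product_attention entry point (not part of the patch). It assumes a CUDA build of Paddle with flash-attention support and float16 inputs; the float32 reference below only illustrates the softmax(Q * K^T / sqrt(d)) * V formula from the docstring and is not how the underlying kernel is implemented.

    import paddle
    import paddle.nn.functional as F

    paddle.seed(2023)
    # [batch_size, seq_len, num_heads, head_dim], float16 as required by the API
    q = paddle.rand((1, 128, 2, 16), dtype='float16')
    k = paddle.rand((1, 128, 2, 16), dtype='float16')
    v = paddle.rand((1, 128, 2, 16), dtype='float16')

    # attn_mask must be None for now; dropout_p=0.0 so the result is deterministic
    out = F.scaled_dot_product_attention(q, k, v, None, 0.0, False)

    # Naive float32 reference: move heads forward -> [batch, num_heads, seq_len, head_dim]
    qt, kt, vt = [x.astype('float32').transpose([0, 2, 1, 3]) for x in (q, k, v)]
    scores = paddle.matmul(qt, kt, transpose_y=True) / (q.shape[-1] ** 0.5)
    ref = paddle.matmul(F.softmax(scores, axis=-1), vt).transpose([0, 2, 1, 3])

    # Max absolute difference should be small (float16 rounding only)
    print(float(paddle.abs(out.astype('float32') - ref).max()))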