From 3f0c5205f874bf94b7b93b2b3f215b7bcad9b6dd Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Fri, 7 Jul 2023 08:23:34 +0000
Subject: [PATCH 1/9] scaled_dot_product_attention api

---
 python/paddle/nn/functional/__init__.py        | 1 +
 python/paddle/nn/functional/flash_attention.py | 6 +++++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index ec5ee96e3cc91..233eb62708c02 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -255,4 +255,5 @@
     'multi_margin_loss',
     'soft_margin_loss',
     'gaussian_nll_loss',
+    'scaled_dot_product_attention',
 ]
diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 178e3ebc90fa7..207607e4a4cb1 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -407,4 +407,8 @@ def flash_attn_unpadded(
     return out, softmax if return_softmax else None
 
 
-scaled_dot_product_attention = flash_attention
+def scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+):
+    assert attn_mask is None, "attn_mask is not supported yet"
+    return flash_attention(query, key, value, dropout_p, is_causal)

From 3362289ba5d09e30ad1756da0b5f2fc7cf5c748a Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Fri, 7 Jul 2023 08:59:12 +0000
Subject: [PATCH 2/9] add test

---
 .../paddle/nn/functional/flash_attention.py |  3 +-
 test/legacy_test/test_flash_attention.py    | 51 +++++++++++++++----
 2 files changed, 42 insertions(+), 12 deletions(-)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 207607e4a4cb1..277c633c8f3d3 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -411,4 +411,5 @@ def scaled_dot_product_attention(
     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
 ):
     assert attn_mask is None, "attn_mask is not supported yet"
-    return flash_attention(query, key, value, dropout_p, is_causal)
+    out, _ = flash_attention(query, key, value, dropout_p, is_causal)
+    return out
diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index da5febf4de3e6..4b290d80aecc4 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -25,6 +25,7 @@
 from paddle.nn.functional.flash_attention import (
     flash_attention,
     flash_attn_unpadded,
+    scaled_dot_product_attention,
 )
 
 
@@ -85,6 +86,7 @@ def setUp(self):
         self.causal = False
         self.return_softmax = False
         self.use_sdp_kernel = False
+        self.use_sdp_api = False
 
     def test_unpadded(self):
         print(
@@ -212,9 +214,15 @@ def test_all(self):
                 enable_flash=self.enable_flash,
                 enable_mem_efficient=self.enable_mem_efficient,
             ):
-                out, _ = flash_attention(
-                    q, k, v, self.dropout, self.causal, self.return_softmax
-                )
+                if self.use_sdp_api:
+                    out = scaled_dot_product_attention(
+                        q, k, v, None, self.dropout, self.causal
+                    )
+                else:
+                    out, _ = flash_attention(
+                        q, k, v, self.dropout, self.causal, self.return_softmax
+                    )
+
         else:
             out, _ = flash_attention(
                 q, k, v, self.dropout, self.causal, self.return_softmax
@@ -253,14 +261,19 @@ def test_all(self):
                     enable_flash=self.enable_flash,
                     enable_mem_efficient=self.enable_mem_efficient,
                 ):
-                    outs, softmax = flash_attention(
-                        qs,
-                        ks,
-                        vs,
-                        self.dropout,
-                        self.causal,
-                        self.return_softmax,
-                    )
+                    if self.use_sdp_api:
+                        out = scaled_dot_product_attention(
+                            q, k, v, None, self.dropout, self.causal
+                        )
+                    else:
+                        outs, softmax = flash_attention(
+                            qs,
+                            ks,
+                            vs,
+                            self.dropout,
+                            self.causal,
+                            self.return_softmax,
+                        )
             else:
                 outs, softmax = flash_attention(
                     qs, ks, vs, self.dropout, self.causal, self.return_softmax
@@ -334,6 +347,22 @@ def setUp(self):
         self.causal = False
         self.return_softmax = False
         self.use_sdp_kernel = True
+        self.use_sdp_api = False
+        self.enable_math = True
+        self.enable_flash = False
+        self.enable_mem_efficient = False
+
+
+class TestSDPAttentionAPITest(TestFlashAttentionAPI):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.shape = (8, 1024, 16, 128)
+        self.dtype = paddle.float16
+        self.dropout = 0.0
+        self.causal = False
+        self.return_softmax = False
+        self.use_sdp_kernel = True
+        self.use_sdp_api = True
         self.enable_math = True
         self.enable_flash = False
         self.enable_mem_efficient = False

From d419bc701cc842282802ac643051796bd8b57cab Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Sun, 9 Jul 2023 09:38:44 +0000
Subject: [PATCH 3/9] polish

---
 .../paddle/nn/functional/flash_attention.py | 47 +++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 277c633c8f3d3..6d6f85655268c 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -410,6 +410,53 @@ def flash_attn_unpadded(
 def scaled_dot_product_attention(
     query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
 ):
+    r"""
+    The equation is:
+
+    .. math::
+
+        result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
+
+    where: ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
+    The dimensions of the three parameters are the same.
+    ``d`` represents the size of the last dimension of the three parameters.
+
+    Warning:
+        This API is only support inputs with dtype float16 and bfloat16.
+
+    Args:
+        query(Tensor): The query tensor in the Attention module.
+                        4-D tensor with shape:
+                        [batch_size, seq_len, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        key(Tensor): The key tensor in the Attention module.
+                        4-D tensor with shape:
+                        [batch_size, seq_len, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        value(Tensor): The value tensor in the Attention module.
+                        4-D tensor with shape:
+                        [batch_size, seq_len, num_heads, head_dim].
+                        The dtype can be float16 or bfloat16.
+        attn_mask(Tensor, optional): A float mask of the same type as query,
+                        key, value that is added to the attention score.
+                        Not supported yet.
+        dropout_p(float): The dropout ratio.
+        is_causal(bool): Whether to enable causal mode.
+
+    Returns:
+        out(Tensor): The attention tensor.
+                    4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
+                    The dtype can be float16 or bfloat16.
+
+    Examples:
+        .. code-block:: python
+
+            # required: skiptest
+            import paddle
+            q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
+            output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
+            print(output)
+    """
     assert attn_mask is None, "attn_mask is not supported yet"
     out, _ = flash_attention(query, key, value, dropout_p, is_causal)
     return out

From d161b82ac254d52cb738d2fae08c30a09bc4430f Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Tue, 18 Jul 2023 12:43:11 +0000
Subject: [PATCH 4/9] test

---
 test/legacy_test/test_flash_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/legacy_test/test_flash_attention.py b/test/legacy_test/test_flash_attention.py
index 4b290d80aecc4..6bde691bd2f95 100644
--- a/test/legacy_test/test_flash_attention.py
+++ b/test/legacy_test/test_flash_attention.py
@@ -262,8 +262,8 @@ def test_all(self):
                     enable_mem_efficient=self.enable_mem_efficient,
                 ):
                     if self.use_sdp_api:
-                        out = scaled_dot_product_attention(
-                            q, k, v, None, self.dropout, self.causal
+                        outs = scaled_dot_product_attention(
+                            qs, ks, vs, None, self.dropout, self.causal
                         )
                     else:
                         outs, softmax = flash_attention(

From bb9aecd2276e00efa496e511524bdbeea4ee6af8 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Thu, 20 Jul 2023 11:45:49 +0000
Subject: [PATCH 5/9] polish doc

---
 python/paddle/nn/functional/flash_attention.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 6d6f85655268c..e11dd4a2d3a7d 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -422,7 +422,7 @@ def scaled_dot_product_attention(
     ``d`` represents the size of the last dimension of the three parameters.
 
     Warning:
-        This API is only support inputs with dtype float16 and bfloat16.
+        This API only supports inputs with dtype float16 and bfloat16.
 
     Args:
         query(Tensor): The query tensor in the Attention module.
@@ -452,10 +452,10 @@ def scaled_dot_product_attention(
         .. code-block:: python
 
             # required: skiptest
-            import paddle
-            q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
-            output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
-            print(output)
+            >>> import paddle
+            >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
+            >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
+            >>> print(output)
     """
     assert attn_mask is None, "attn_mask is not supported yet"
     out, _ = flash_attention(query, key, value, dropout_p, is_causal)

From bea4ee35af76732365e115877174433dd5c6daf9 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Sat, 22 Jul 2023 14:08:52 +0000
Subject: [PATCH 6/9] polish

---
 python/paddle/nn/functional/flash_attention.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index e11dd4a2d3a7d..b5af8cac8ab49 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -451,7 +451,6 @@ def scaled_dot_product_attention(
     Examples:
         .. code-block:: python
 
-            # required: skiptest
             >>> import paddle
             >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
             >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)

From 1d2a94ee159c6cca50567f9c4fa0cff7ca9f5254 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Tue, 25 Jul 2023 10:38:00 +0000
Subject: [PATCH 7/9] polish

---
 python/paddle/nn/functional/flash_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index b5af8cac8ab49..4d737b870d0f7 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -452,7 +452,7 @@ def scaled_dot_product_attention(
         .. code-block:: python
 
             >>> import paddle
-            >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.float16)
+            >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
             >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
             >>> print(output)
     """

From e68519e7baf7273e3d88396fd95d952775be0b36 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Tue, 25 Jul 2023 11:13:49 +0000
Subject: [PATCH 8/9] polish

---
 python/paddle/nn/functional/flash_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 4d737b870d0f7..46d8a6ea8d0eb 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -450,7 +450,7 @@ def scaled_dot_product_attention(
 
     Examples:
         .. code-block:: python
-
+            # required: skiptest
             >>> import paddle
             >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
             >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)

From 29a118d3ccffcdab5936d45325483b026e049e57 Mon Sep 17 00:00:00 2001
From: liuzhenhai93
Date: Tue, 1 Aug 2023 02:32:42 +0000
Subject: [PATCH 9/9] skip xtest

---
 python/paddle/nn/functional/flash_attention.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/paddle/nn/functional/flash_attention.py b/python/paddle/nn/functional/flash_attention.py
index 46d8a6ea8d0eb..b36bd5d74ec7b 100644
--- a/python/paddle/nn/functional/flash_attention.py
+++ b/python/paddle/nn/functional/flash_attention.py
@@ -451,10 +451,12 @@ def scaled_dot_product_attention(
     Examples:
         .. code-block:: python
             # required: skiptest
+            >>> # xdoctest: +SKIP()
             >>> import paddle
             >>> q = paddle.rand((1, 128, 2, 16), dtype=paddle.bfloat16)
             >>> output = paddle.nn.functional.scaled_dot_product_attention(q, q, q, None, 0.9, False)
             >>> print(output)
+            >>> # xdoctest: -SKIP
     """
     assert attn_mask is None, "attn_mask is not supported yet"
     out, _ = flash_attention(query, key, value, dropout_p, is_causal)
     return out
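
Taken together, the series adds scaled_dot_product_attention on top of flash_attention. Below is a minimal usage sketch, not part of the patches themselves: it assumes a CUDA build of Paddle with the flash-attention kernels available, and the shapes, dtype, and import path simply mirror TestSDPAttentionAPITest and the docstring example above.

    import paddle
    from paddle.nn.functional.flash_attention import scaled_dot_product_attention

    paddle.seed(2023)

    # [batch_size, seq_len, num_heads, head_dim]; float16 per the docstring's dtype requirement.
    q = paddle.rand((8, 1024, 16, 128), dtype=paddle.float16)
    k = paddle.rand((8, 1024, 16, 128), dtype=paddle.float16)
    v = paddle.rand((8, 1024, 16, 128), dtype=paddle.float16)

    # attn_mask must stay None for now (the function asserts this); with dropout_p=0.0
    # and is_causal=False the call computes softmax(Q * K^T / sqrt(d)) * V.
    out = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    print(out.shape)  # [8, 1024, 16, 128]

Because PATCH 1/9 also registers the name in python/paddle/nn/functional/__init__.py, the same function is reachable as paddle.nn.functional.scaled_dot_product_attention, which is the spelling the docstring example uses.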