From dcb70b9ddbee1563e60d5e34d0a82ee12e32fdfd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com>
Date: Wed, 12 Jun 2024 13:14:09 +0800
Subject: [PATCH 1/3] Update sparse_attention.py

---
 .../paddle/nn/functional/sparse_attention.py | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/python/paddle/nn/functional/sparse_attention.py b/python/paddle/nn/functional/sparse_attention.py
index 1d5f5013435bb..d9527e0cb6751 100644
--- a/python/paddle/nn/functional/sparse_attention.py
+++ b/python/paddle/nn/functional/sparse_attention.py
@@ -12,20 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+import paddle
 from paddle import _legacy_C_ops, in_dynamic_mode
 from paddle.base.layer_helper import LayerHelper
 
 
 def sparse_attention(
-    query,
-    key,
-    value,
-    sparse_csr_offset,
-    sparse_csr_columns,
-    key_padding_mask=None,
-    attn_mask=None,
-    name=None,
-):
+    query: paddle.Tensor,
+    key: paddle.Tensor,
+    value: paddle.Tensor,
+    sparse_csr_offset: paddle.Tensor,
+    sparse_csr_columns: paddle.Tensor,
+    key_padding_mask: paddle.Tensor = None,
+    attn_mask: paddle.Tensor = None,
+    name: str | None = None,
+) -> paddle.Tensor:
     r"""
     This operator sparsify the Attention matrix in Transformer module
     to achieve the effect of reducing memory consumption and computation.
@@ -76,7 +79,7 @@
                         2-D tensor with shape: [seq_len, seq_len].
                         The dtype can be float32 and float64.
                         A value of 0 means that the position is masked.
-        name(str, optional): The default value is None. Normally there is no need for user
+        name(str|None, optional): The default value is None. Normally there is no need for user
                 to set this property. For more information, please refer to
                 :ref:`api_guide_Name`.
 
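Context for the `from __future__ import annotations` line this patch adds: the series uses PEP 604 unions such as `str | None` (and, after PATCH 2/3, `paddle.Tensor | None`) in the signature. Without the future import, Python 3.8/3.9 evaluates annotations eagerly and `str | None` raises a TypeError at import time; with it, annotations are stored as plain strings (PEP 563) and never evaluated. A minimal standalone sketch of the behavior, not part of the patch itself:

from __future__ import annotations


def f(name: str | None = None) -> str:
    # With the future import, the annotation is kept as the literal string
    # "str | None", so the `|` union is safe even on Python 3.8/3.9.
    return name or "default"


# Without the first line, Python 3.8/3.9 would fail at definition time:
#   TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
print(f.__annotations__)  # {'name': 'str | None', 'return': 'str'}
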
From 7d223ce28596524a94bef68ba327d74c22f9fd1a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com>
Date: Wed, 12 Jun 2024 21:32:38 +0800
Subject: [PATCH 2/3] Update python/paddle/nn/functional/sparse_attention.py

Co-authored-by: Nyakku Shigure
---
 python/paddle/nn/functional/sparse_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/nn/functional/sparse_attention.py b/python/paddle/nn/functional/sparse_attention.py
index d9527e0cb6751..485a1d83f3fb2 100644
--- a/python/paddle/nn/functional/sparse_attention.py
+++ b/python/paddle/nn/functional/sparse_attention.py
@@ -25,8 +25,8 @@ def sparse_attention(
     value: paddle.Tensor,
     sparse_csr_offset: paddle.Tensor,
     sparse_csr_columns: paddle.Tensor,
-    key_padding_mask: paddle.Tensor = None,
-    attn_mask: paddle.Tensor = None,
+    key_padding_mask: paddle.Tensor | None = None,
+    attn_mask: paddle.Tensor | None = None,
     name: str | None = None,
 ) -> paddle.Tensor:
     r"""

From ae9c296af828d37ba94241bf27ddea98a0adfd44 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=98=A5=E4=B9=94?= <83450930+Liyulingyue@users.noreply.github.com>
Date: Wed, 12 Jun 2024 21:36:20 +0800
Subject: [PATCH 3/3] Update sparse_attention.py

---
 python/paddle/nn/functional/sparse_attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/paddle/nn/functional/sparse_attention.py b/python/paddle/nn/functional/sparse_attention.py
index 485a1d83f3fb2..2243d6204300d 100644
--- a/python/paddle/nn/functional/sparse_attention.py
+++ b/python/paddle/nn/functional/sparse_attention.py
@@ -71,11 +71,11 @@ def sparse_attention(
                         3-D tensor with shape: [batch_size, num_heads, sparse_nnz].
                         The dtype should be int32.
-        key_padding_mask(Tensor, optional):The key padding mask tensor in the Attention module.
+        key_padding_mask(Tensor|None, optional):The key padding mask tensor in the Attention module.
                         2-D tensor with shape: [batch_size, seq_len].
                         The dtype can be float32 and float64.
                         A value of 0 means that the position is masked.
-        attn_mask(Tensor, optional):The attention mask tensor in the Attention module.
+        attn_mask(Tensor|None, optional):The attention mask tensor in the Attention module.
                         2-D tensor with shape: [seq_len, seq_len].
                         The dtype can be float32 and float64.
                         A value of 0 means that the position is masked.
@@ -84,7 +84,7 @@ def sparse_attention(
                 :ref:`api_guide_Name`.
 
     Returns:
-        4-D tensor with shape:
+        Tensor, 4-D tensor with shape:
         [batch_size, num_heads, seq_len, head_dim].
         The dtype can be float32 or float64.
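For reference, a usage sketch of the annotated signature after all three patches. The shapes follow the docstring: query/key/value are [batch_size, num_heads, seq_len, head_dim]; the CSR offset/columns pair describes which key positions each query row attends to. The 2x2 block-diagonal pattern below is illustrative only, and running it assumes a CUDA build of PaddlePaddle, since the op is documented for GPU use:

import paddle
import paddle.nn.functional as F

# Toy sizes; query/key/value are [batch_size, num_heads, seq_len, head_dim].
batch_size, num_heads, seq_len, head_dim = 1, 1, 4, 2
query = paddle.rand([batch_size, num_heads, seq_len, head_dim])
key = paddle.rand([batch_size, num_heads, seq_len, head_dim])
value = paddle.rand([batch_size, num_heads, seq_len, head_dim])

# CSR sparsity pattern: row i attends to columns[offset[i]:offset[i+1]].
# Per the docstring, offset is [batch_size, num_heads, seq_len + 1] and
# columns is [batch_size, num_heads, sparse_nnz], both int32.
# Here rows 0-1 attend to columns {0, 1} and rows 2-3 to columns {2, 3}.
offset = paddle.to_tensor([[[0, 2, 4, 6, 8]]], dtype="int32")
columns = paddle.to_tensor([[[0, 1, 0, 1, 2, 3, 2, 3]]], dtype="int32")

out = F.sparse_attention(query, key, value, offset, columns)
print(out.shape)  # [1, 1, 4, 2], i.e. [batch_size, num_heads, seq_len, head_dim]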