[Typing][A-64] Add type annotations for paddle/nn/functional/sparse_attention.py #65064

Merged (3 commits) on Jun 12, 2024
29 changes: 16 additions & 13 deletions python/paddle/nn/functional/sparse_attention.py
Contributor

Could you also update the `Returns` section in the docstring? Change

    Returns:
        4-D tensor with shape:

to

    Returns:
        Tensor: 4-D tensor with shape:

That is, use the `return type: description` form ~

Member
@SigureMo SigureMo Jun 12, 2024

We can't use the `<return type>: <description>` form: Sphinx would parse out the type and the rendering would look odd. Use `<return type>, <description>` instead.

cc @sunzhongkai588
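The distinction the reviewers are discussing can be sketched with a short, hypothetical stub (not from the PR): with Google-style docstrings, Sphinx's napoleon extension treats text before the first colon on a `Returns:` entry as the return type, so the comma form keeps the whole line as one description.

```python
def sparse_attention_stub():
    """Hypothetical stub showing the recommended Returns style.

    Returns:
        Tensor, 4-D tensor with shape
        [batch_size, num_heads, seq_len, head_dim].
    """


# Written as "Tensor, ..." (comma, not colon), the line stays a plain
# description; "Tensor: ..." would make napoleon split out "Tensor" as a
# separately rendered return type.
assert "Tensor, 4-D tensor" in sparse_attention_stub.__doc__
```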

@@ -12,20 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import paddle
 from paddle import _legacy_C_ops, in_dynamic_mode
 from paddle.base.layer_helper import LayerHelper
 
 
 def sparse_attention(
-    query,
-    key,
-    value,
-    sparse_csr_offset,
-    sparse_csr_columns,
-    key_padding_mask=None,
-    attn_mask=None,
-    name=None,
-):
+    query: paddle.Tensor,
+    key: paddle.Tensor,
+    value: paddle.Tensor,
+    sparse_csr_offset: paddle.Tensor,
+    sparse_csr_columns: paddle.Tensor,
+    key_padding_mask: paddle.Tensor | None = None,
+    attn_mask: paddle.Tensor | None = None,
+    name: str | None = None,
+) -> paddle.Tensor:
     r"""
     This operator sparsify the Attention matrix in Transformer module
     to achieve the effect of reducing memory consumption and computation.
@@ -68,20 +71,20 @@ def sparse_attention(
             3-D tensor with shape:
             [batch_size, num_heads, sparse_nnz].
             The dtype should be int32.
-        key_padding_mask(Tensor, optional):The key padding mask tensor in the Attention module.
+        key_padding_mask(Tensor|None, optional):The key padding mask tensor in the Attention module.
             2-D tensor with shape: [batch_size, seq_len].
             The dtype can be float32 and float64.
             A value of 0 means that the position is masked.
-        attn_mask(Tensor, optional):The attention mask tensor in the Attention module.
+        attn_mask(Tensor|None, optional):The attention mask tensor in the Attention module.
             2-D tensor with shape: [seq_len, seq_len].
             The dtype can be float32 and float64.
             A value of 0 means that the position is masked.
-        name(str, optional): The default value is None. Normally there is no need for user
+        name(str|None, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
 
     Returns:
-        4-D tensor with shape:
+        Tensor, 4-D tensor with shape:
         [batch_size, num_heads, seq_len, head_dim].
         The dtype can be float32 or float64.