Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][C-68,C-77][BUAA] Add type annotations for python/paddle/nn/* #67186

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 81 additions & 74 deletions python/paddle/incubate/nn/functional/block_multihead_attention.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI 报错,需要编译选项,参考 https://github.com/PaddlePaddle/Paddle/pull/67178/files 在示例中添加

            >>> # doctest: +SKIP('Need compile flash attention')
            >>> # doctest: +REQUIRES(env:GPU)

Original file line number Diff line number Diff line change
Expand Up @@ -12,47 +12,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode

if TYPE_CHECKING:
from paddle import Tensor


def block_multihead_attention(
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
cum_offsets,
cu_seqlens_q,
cu_seqlens_k,
block_tables,
pre_key_cache=None,
pre_value_cache=None,
cache_k_quant_scales=None,
cache_v_quant_scales=None,
cache_k_dequant_scales=None,
cache_v_dequant_scales=None,
qkv_out_scale=None,
qkv_bias=None,
out_shift=None,
out_smooth=None,
max_enc_len_this_time=None,
max_dec_len_this_time=None,
rope_emb=None,
mask=None,
tgt_mask=None,
max_seq_len=-1,
block_size=64,
use_neox_style=False,
use_dynamic_cachekv_quant=False,
quant_round_type=1,
quant_max_bound=127.0,
quant_min_bound=-127.0,
out_scale=-1,
compute_dtype="default",
):
qkv: Tensor,
key_cache: Tensor,
value_cache: Tensor,
seq_lens_encoder: Tensor,
seq_lens_decoder: Tensor,
seq_lens_this_time: Tensor,
padding_offsets: Tensor,
cum_offsets: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
block_tables: Tensor,
pre_key_cache: Tensor | None = None,
pre_value_cache: Tensor | None = None,
cache_k_quant_scales: Tensor | None = None,
cache_v_quant_scales: Tensor | None = None,
cache_k_dequant_scales: Tensor | None = None,
cache_v_dequant_scales: Tensor | None = None,
qkv_out_scale: Tensor | None = None,
qkv_bias: Tensor | None = None,
out_shift: Tensor | None = None,
out_smooth: Tensor | None = None,
max_enc_len_this_time: Tensor | None = None,
max_dec_len_this_time: Tensor | None = None,
rope_emb: Tensor | None = None,
mask: Tensor | None = None,
tgt_mask: Tensor | None = None,
max_seq_len: int = -1,
block_size: int = 64,
use_neox_style: bool = False,
use_dynamic_cachekv_quant: bool = False,
quant_round_type: int = 1,
quant_max_bound: float = 127.0,
quant_min_bound: float = -127.0,
out_scale: Tensor = -1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
out_scale: Tensor = -1,
out_scale: float = -1,

compute_dtype: str = "default",
) -> Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
) -> Tensor:
) -> tuple[Tensor, Tensor, Tensor, Tensor]:

"""
Block Multi-head attention for text summarization.

Expand Down Expand Up @@ -392,44 +399,44 @@ def block_multihead_attention(


def block_multihead_attention_xpu(
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
cum_offsets,
cu_seqlens_q,
cu_seqlens_k,
block_tables,
cache_k_per_batch_maxs,
cache_v_per_batch_maxs,
pre_key_cache=None,
pre_value_cache=None,
cache_k_quant_scales=None,
cache_v_quant_scales=None,
cache_k_dequant_scales=None,
cache_v_dequant_scales=None,
qkv_out_scale=None,
qkv_bias=None,
out_shift=None,
out_smooth=None,
max_enc_len_this_time=None,
max_dec_len_this_time=None,
rope_emb=None,
mask=None,
tgt_mask=None,
max_seq_len=-1,
block_size=64,
use_neox_style=False,
use_dynamic_cachekv_quant=False,
quant_round_type=1,
quant_max_bound=127.0,
quant_min_bound=-127.0,
out_scale=-1,
compute_dtype="default",
):
qkv: Tensor,
key_cache: Tensor,
value_cache: Tensor,
seq_lens_encoder: Tensor,
seq_lens_decoder: Tensor,
seq_lens_this_time: Tensor,
padding_offsets: Tensor,
cum_offsets: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
block_tables: Tensor,
cache_k_per_batch_maxs: Tensor,
cache_v_per_batch_maxs: Tensor,
pre_key_cache: Tensor | None = None,
pre_value_cache: Tensor | None = None,
cache_k_quant_scales: Tensor | None = None,
cache_v_quant_scales: Tensor | None = None,
cache_k_dequant_scales: Tensor | None = None,
cache_v_dequant_scales: Tensor | None = None,
qkv_out_scale: Tensor | None = None,
qkv_bias: Tensor | None = None,
out_shift: Tensor | None = None,
out_smooth: Tensor | None = None,
max_enc_len_this_time: Tensor | None = None,
max_dec_len_this_time: Tensor | None = None,
rope_emb: Tensor | None = None,
mask: Tensor | None = None,
tgt_mask: Tensor | None = None,
max_seq_len: int = -1,
block_size: int = 64,
use_neox_style: bool = False,
use_dynamic_cachekv_quant: bool = False,
quant_round_type: int = 1,
quant_max_bound: float = 127.0,
quant_min_bound: float = -127.0,
out_scale: int = -1,
compute_dtype: str = "default",
) -> Tensor:
Comment on lines +437 to +439
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

if in_dynamic_mode():
return _C_ops.block_multihead_attention_xpu(
qkv,
Expand Down
Loading