Skip to content

Commit

Permalink
Update block_multihead_attention.py
Browse files · Browse the repository at this point in the history
  • Loading branch information
lwkhahaha authored Aug 8, 2024
1 parent 949d56a commit 21229fd
Showing 1 changed file with 30 additions and 30 deletions.
60 changes: 30 additions & 30 deletions python/paddle/incubate/nn/functional/block_multihead_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,21 @@ def block_multihead_attention(
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
block_tables: Tensor,
pre_key_cache: Tensor|None = None,
pre_value_cache: Tensor|None = None,
cache_k_quant_scales: Tensor|None = None,
cache_v_quant_scales: Tensor|None = None,
cache_k_dequant_scales: Tensor|None = None,
cache_v_dequant_scales: Tensor|None = None,
qkv_out_scale: Tensor|None = None,
qkv_bias: Tensor|None = None,
out_shift: Tensor|None = None,
out_smooth: Tensor|None = None,
max_enc_len_this_time: Tensor|None = None,
max_dec_len_this_time: Tensor|None = None,
rope_emb: Tensor|None = None,
mask: Tensor|None = None,
tgt_mask: Tensor|None = None,
pre_key_cache: Tensor | None = None,
pre_value_cache: Tensor | None = None,
cache_k_quant_scales: Tensor | None = None,
cache_v_quant_scales: Tensor | None = None,
cache_k_dequant_scales: Tensor | None = None,
cache_v_dequant_scales: Tensor | None = None,
qkv_out_scale: Tensor | None = None,
qkv_bias: Tensor | None = None,
out_shift: Tensor | None = None,
out_smooth: Tensor | None = None,
max_enc_len_this_time: Tensor | None = None,
max_dec_len_this_time: Tensor | None = None,
rope_emb: Tensor | None = None,
mask: Tensor | None = None,
tgt_mask: Tensor | None = None,
max_seq_len: int = -1,
block_size: int = 64,
use_neox_style: bool = False,
Expand Down Expand Up @@ -412,21 +412,21 @@ def block_multihead_attention_xpu(
block_tables: Tensor,
cache_k_per_batch_maxs: Tensor,
cache_v_per_batch_maxs: Tensor,
pre_key_cache: Tensor|None = None,
pre_value_cache: Tensor|None = None,
cache_k_quant_scales: Tensor|None = None,
cache_v_quant_scales: Tensor|None = None,
cache_k_dequant_scales: Tensor|None = None,
cache_v_dequant_scales: Tensor|None = None,
qkv_out_scale: Tensor|None = None,
qkv_bias: Tensor|None = None,
out_shift: Tensor|None = None,
out_smooth: Tensor|None = None,
max_enc_len_this_time: Tensor|None = None,
max_dec_len_this_time: Tensor|None = None,
rope_emb: Tensor|None = None,
mask: Tensor|None = None,
tgt_mask: Tensor|None = None,
pre_key_cache: Tensor | None = None,
pre_value_cache: Tensor | None = None,
cache_k_quant_scales: Tensor | None = None,
cache_v_quant_scales: Tensor | None = None,
cache_k_dequant_scales: Tensor | None = None,
cache_v_dequant_scales: Tensor | None = None,
qkv_out_scale: Tensor | None = None,
qkv_bias: Tensor | None = None,
out_shift: Tensor | None = None,
out_smooth: Tensor | None = None,
max_enc_len_this_time: Tensor | None = None,
max_dec_len_this_time: Tensor | None = None,
rope_emb: Tensor | None = None,
mask: Tensor | None = None,
tgt_mask: Tensor | None = None,
max_seq_len: int = -1,
block_size: int = 64,
use_neox_style: bool = False,
Expand Down

0 comments on commit 21229fd

Please sign in to comment.