From 21229fd526e580120d83816eeaf4d9dcab211399 Mon Sep 17 00:00:00 2001
From: lwkhahaha <124662571+lwkhahaha@users.noreply.github.com>
Date: Thu, 8 Aug 2024 17:30:07 +0800
Subject: [PATCH] Update block_multihead_attention.py

---
 .../functional/block_multihead_attention.py   | 60 +++++++++----------
 1 file changed, 30 insertions(+), 30 deletions(-)

diff --git a/python/paddle/incubate/nn/functional/block_multihead_attention.py b/python/paddle/incubate/nn/functional/block_multihead_attention.py
index 4ae8ea5690b7a..c2cad4c5c6bea 100644
--- a/python/paddle/incubate/nn/functional/block_multihead_attention.py
+++ b/python/paddle/incubate/nn/functional/block_multihead_attention.py
@@ -35,21 +35,21 @@ def block_multihead_attention(
     cu_seqlens_q: Tensor,
     cu_seqlens_k: Tensor,
     block_tables: Tensor,
-    pre_key_cache: Tensor|None = None,
-    pre_value_cache: Tensor|None = None,
-    cache_k_quant_scales: Tensor|None = None,
-    cache_v_quant_scales: Tensor|None = None,
-    cache_k_dequant_scales: Tensor|None = None,
-    cache_v_dequant_scales: Tensor|None = None,
-    qkv_out_scale: Tensor|None = None,
-    qkv_bias: Tensor|None = None,
-    out_shift: Tensor|None = None,
-    out_smooth: Tensor|None = None,
-    max_enc_len_this_time: Tensor|None = None,
-    max_dec_len_this_time: Tensor|None = None,
-    rope_emb: Tensor|None = None,
-    mask: Tensor|None = None,
-    tgt_mask: Tensor|None = None,
+    pre_key_cache: Tensor | None = None,
+    pre_value_cache: Tensor | None = None,
+    cache_k_quant_scales: Tensor | None = None,
+    cache_v_quant_scales: Tensor | None = None,
+    cache_k_dequant_scales: Tensor | None = None,
+    cache_v_dequant_scales: Tensor | None = None,
+    qkv_out_scale: Tensor | None = None,
+    qkv_bias: Tensor | None = None,
+    out_shift: Tensor | None = None,
+    out_smooth: Tensor | None = None,
+    max_enc_len_this_time: Tensor | None = None,
+    max_dec_len_this_time: Tensor | None = None,
+    rope_emb: Tensor | None = None,
+    mask: Tensor | None = None,
+    tgt_mask: Tensor | None = None,
     max_seq_len: int = -1,
     block_size: int = 64,
     use_neox_style: bool = False,
@@ -412,21 +412,21 @@ def block_multihead_attention_xpu(
     block_tables: Tensor,
     cache_k_per_batch_maxs: Tensor,
     cache_v_per_batch_maxs: Tensor,
-    pre_key_cache: Tensor|None = None,
-    pre_value_cache: Tensor|None = None,
-    cache_k_quant_scales: Tensor|None = None,
-    cache_v_quant_scales: Tensor|None = None,
-    cache_k_dequant_scales: Tensor|None = None,
-    cache_v_dequant_scales: Tensor|None = None,
-    qkv_out_scale: Tensor|None = None,
-    qkv_bias: Tensor|None = None,
-    out_shift: Tensor|None = None,
-    out_smooth: Tensor|None = None,
-    max_enc_len_this_time: Tensor|None = None,
-    max_dec_len_this_time: Tensor|None = None,
-    rope_emb: Tensor|None = None,
-    mask: Tensor|None = None,
-    tgt_mask: Tensor|None = None,
+    pre_key_cache: Tensor | None = None,
+    pre_value_cache: Tensor | None = None,
+    cache_k_quant_scales: Tensor | None = None,
+    cache_v_quant_scales: Tensor | None = None,
+    cache_k_dequant_scales: Tensor | None = None,
+    cache_v_dequant_scales: Tensor | None = None,
+    qkv_out_scale: Tensor | None = None,
+    qkv_bias: Tensor | None = None,
+    out_shift: Tensor | None = None,
+    out_smooth: Tensor | None = None,
+    max_enc_len_this_time: Tensor | None = None,
+    max_dec_len_this_time: Tensor | None = None,
+    rope_emb: Tensor | None = None,
+    mask: Tensor | None = None,
+    tgt_mask: Tensor | None = None,
     max_seq_len: int = -1,
     block_size: int = 64,
     use_neox_style: bool = False,
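
For context, the patch above is purely stylistic: it rewrites PEP 604 optional-type
annotations from `Tensor|None` to `Tensor | None`, adding the spaces around `|` that
PEP 8 recommends for binary operators. A minimal sketch of the convention, assuming a
hypothetical function `my_op` (only `paddle.Tensor` is a real Paddle name):

    # Hypothetical example; `my_op` is not part of Paddle.
    from __future__ import annotations  # lets `X | None` parse on Python < 3.10

    from paddle import Tensor

    def my_op(
        x: Tensor,
        mask: Tensor | None = None,  # preferred: spaces around `|`
    ) -> Tensor:
        # Body is a placeholder; the point is the annotation style above.
        return x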