@@ -41,7 +41,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
     const paddle::Tensor &encoder_seq_lod_cpu,
     const paddle::Tensor &encoder_batch_map_cpu,
     const paddle::Tensor &decoder_context_len_cpu,
-    const paddle::Tensor &decoder_batch_map_cpu) {
+    const paddle::Tensor &decoder_batch_map_cpu,
+    const std::string &pos_emb_type = "NORMAL",
+    bool rope_3d = false) {
   phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
   auto dev_ctx =
       paddle::experimental::DeviceContextPool::Instance().Get(place);
@@ -72,6 +74,14 @@ std::vector<paddle::Tensor> BlockAttnKernel(
   int enc_batch = enc_batch_tensor.data<int32_t>()[0];
   int dec_batch = dec_batch_tensor.data<int32_t>()[0];
   int total_enc_len = total_enc_len_tensor.data<int32_t>()[0];
+  int rope_max_seqlen = 0;
+  int rope_3d_num_seqs = 1;
+  if (rope_3d) {
+    rope_max_seqlen = rotary_embs.dims()[3];
+    rope_3d_num_seqs = rotary_embs.dims()[0];
+  } else {
+    rope_max_seqlen = rotary_embs.dims()[2];
+  }

   auto block_attn_out =
       paddle::full({token_num, hidden_dim}, -1, qkv.type(), qkv.place());
@@ -151,10 +161,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
       prefix_lens_vp,    // start_tokens
       param.batch_size,  // batch_size
       1,                 // emb_batch_size
-      rotary_embs.dims()[2],  // max_seqlen
+      rope_max_seqlen,   // max_seqlen
       param.head_num, param.kv_head_num, param.head_dim,
       param.max_batch_size, block_size, max_block_per_seq, "BLHD",
-      "HLD", "NORMAL",
+      "HLD", pos_emb_type,
       !p_kcache_perhead_scale.defined()
           ? nullptr
           : p_kcache_perhead_scale.data<float>() +
@@ -246,10 +256,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
       vsl.slot_mapping_vp,  // real_batch
       param.batch_size,     // batch_size
       1,                    // emb_batch_size
-      rotary_embs.dims()[2],  // max_seqlen TODO!! double check
+      rope_max_seqlen,      // max_seqlen
       param.head_num, param.kv_head_num, param.head_dim,
       param.max_batch_size, block_size, max_block_per_seq, "BLHD", "HLD",
-      "NORMAL",
+      pos_emb_type,
       !p_kcache_perhead_scale.defined()
           ? nullptr
           : p_kcache_perhead_scale.data<float>() +
@@ -260,7 +270,9 @@ std::vector<paddle::Tensor> BlockAttnKernel(
       param.kv_head_num,  // v_cache_scale_inv
       nullptr,            // k_cache_zp
       nullptr,            // v_cache_zp
-      false);             // b_c8_pc
+      false,              // b_c8_pc
+      rope_3d,            // rope_3d
+      rope_3d_num_seqs);
   XFTBLOCK_CHECK_EQ(ret, api::SUCCESS);

   // attn decode
@@ -314,6 +326,7 @@ PD_BUILD_OP(block_attn)
314326 " decoder_context_len_cpu" ,
315327 " decoder_batch_map_cpu" ,
316328 })
329+ .Attrs({" pos_emb_type:std::string" , " rope_3d:bool" })
317330 .Outputs({" block_attn_out" })
318331 .SetKernelFn(PD_KERNEL(BlockAttnKernel))
319332 .SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
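
The new `rope_max_seqlen` / `rope_3d_num_seqs` branch is read from `rotary_embs.dims()` once and then consumed at both attention call sites. A minimal sketch (hypothetical, not part of this patch) of how that shape convention could be factored into a shared helper, assuming the same axis layout the kernel above uses (sequence length on axis 2 for the default rope table, and on axis 3 with the per-sequence count on axis 0 when `rope_3d` is set):

```cpp
// Hypothetical helper, not in the patch: wraps the rotary-table shape
// handling so both the prefill and decode call sites read it the same way.
struct RopeTableInfo {
  int max_seqlen = 0;
  int num_seqs = 1;  // only meaningful when rope_3d is true
};

inline RopeTableInfo GetRopeTableInfo(const paddle::Tensor &rotary_embs,
                                      bool rope_3d) {
  RopeTableInfo info;
  if (rope_3d) {
    // Assumed 3D-rope layout: per-sequence tables, seq length on axis 3.
    info.max_seqlen = rotary_embs.dims()[3];
    info.num_seqs = rotary_embs.dims()[0];
  } else {
    // Assumed default layout: shared table, seq length on axis 2.
    info.max_seqlen = rotary_embs.dims()[2];
  }
  return info;
}
```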