diff --git a/paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc b/paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc index ed54b6aafdae6..ae0a320d87800 100644 --- a/paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc +++ b/paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc @@ -74,6 +74,7 @@ class FusedSeqpoolCVMOpXPUKernel : public framework::OpKernel { auto clk_coeff = ctx.Attr("clk_coeff"); auto threshold = ctx.Attr("threshold"); auto cvm_offset = ctx.Attr("cvm_offset"); + auto embed_thres_size = ctx.Attr("embed_thres_size"); auto x0_lod = ins[0]->lod(); auto x0_dims = ins[0]->dims(); @@ -98,12 +99,13 @@ class FusedSeqpoolCVMOpXPUKernel : public framework::OpKernel { phi::Place l3_place = ctx.template device_context().GetL3Place(); int w = ins[0]->numel() / x0_dims[0]; if(use_cvm) { + if(clk_filter) w = w - 1; PADDLE_ENFORCE_EQ(y_dims[1] % w, 0, paddle::platform::errors::InvalidArgument( "The output of dims[1] should be dividable of w")); } else{ - PADDLE_ENFORCE_EQ(y_dims[1] % (w-2), 0, + PADDLE_ENFORCE_EQ(y_dims[1] % (w - cvm_offset - embed_thres_size), 0, paddle::platform::errors::InvalidArgument( "The output of dims[1] should be dividable of (w-2)")); }