Skip to content

Commit

Permalink
Merge pull request #6 from xymyeah/paddlebox
Browse files Browse the repository at this point in the history
fix fused_seqpool_cvm_op_xpu op bug
  • Loading branch information
jack603047588 authored Sep 13, 2023
2 parents 495198b + 59b3ab2 commit de6dd8f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion paddle/fluid/operators/fused/fused_seqpool_cvm_op_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class FusedSeqpoolCVMOpXPUKernel : public framework::OpKernel<T> {
auto clk_coeff = ctx.Attr<float>("clk_coeff");
auto threshold = ctx.Attr<float>("threshold");
auto cvm_offset = ctx.Attr<int>("cvm_offset");
auto embed_thres_size = ctx.Attr<int>("embed_thres_size");

auto x0_lod = ins[0]->lod();
auto x0_dims = ins[0]->dims();
Expand All @@ -98,12 +99,13 @@ class FusedSeqpoolCVMOpXPUKernel : public framework::OpKernel<T> {
phi::Place l3_place = ctx.template device_context<DeviceContext>().GetL3Place();
int w = ins[0]->numel() / x0_dims[0];
if(use_cvm) {
if(clk_filter) w = w - 1;
PADDLE_ENFORCE_EQ(y_dims[1] % w, 0,
paddle::platform::errors::InvalidArgument(
"The output of dims[1] should be dividable of w"));
}
else{
PADDLE_ENFORCE_EQ(y_dims[1] % (w-2), 0,
PADDLE_ENFORCE_EQ(y_dims[1] % (w - cvm_offset - embed_thres_size), 0,
paddle::platform::errors::InvalidArgument(
"The output of dims[1] should be dividable of (w-2)"));
}
Expand Down

0 comments on commit de6dd8f

Please sign in to comment.