Skip to content

Commit

Permalink
CI fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed Jul 28, 2021
1 parent 7625f68 commit d87f7ab
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
const auto& onednn_engine = dev_ctx.GetEngine();

auto* x = ctx.Input<LoDTensor>("X");
auto* xshape = ctx.Output<LoDTensor>("XShape");
auto* out = ctx.Output<LoDTensor>("Out");

auto x_dims = x->dims();
framework::DDim x_dims;
// if reshape or squeeze
if (ctx.Type().find("2") == std::string::npos) {
x_dims = x->dims();
} else {
auto xshape_dims = xshape->dims();
x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
}

auto x_vec_dims = framework::vectorize(x_dims);

framework::DDim out_dims;
Expand Down Expand Up @@ -210,9 +219,10 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T> {
auto* dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));

framework::DDim x_dims;
if (ctx.Type() != "squeeze2_grad") {
// if reshape or squeeze
if (ctx.Type().find("2") == std::string::npos) {
x_dims = dx->dims();
} else if (ctx.Type() == "squeeze2_grad") {
} else {
auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
}
Expand Down

1 comment on commit d87f7ab

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on d87f7ab Jul 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍PR: #34219 Commit ID: d87f7ab contains failed CI.

Please sign in to comment.