From d87f7ab4eacec2d003caa8a4afc1cfe8ce3ab85b Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 28 Jul 2021 14:19:20 +0200 Subject: [PATCH] CI fix --- .../fluid/operators/mkldnn/reshape_mkldnn_op.cc | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc index d8f0d69d73db9..244430e69f234 100644 --- a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc @@ -36,9 +36,18 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { const auto& onednn_engine = dev_ctx.GetEngine(); auto* x = ctx.Input("X"); + auto* xshape = ctx.Output("XShape"); auto* out = ctx.Output("Out"); - auto x_dims = x->dims(); + framework::DDim x_dims; + // if reshape or squeeze + if (ctx.Type().find("2") == std::string::npos) { + x_dims = x->dims(); + } else { + auto xshape_dims = xshape->dims(); + x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + } + auto x_vec_dims = framework::vectorize(x_dims); framework::DDim out_dims; @@ -210,9 +219,10 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel { auto* dx = ctx.Output(framework::GradVarName("X")); framework::DDim x_dims; - if (ctx.Type() != "squeeze2_grad") { + // if reshape or squeeze + if (ctx.Type().find("2") == std::string::npos) { x_dims = dx->dims(); - } else if (ctx.Type() == "squeeze2_grad") { + } else { auto xshape_dims = ctx.Input("XShape")->dims(); x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); }