diff --git a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc index 446b8aac398a0..ba3ce00547ae5 100644 --- a/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/matmul_v2_mkldnn_op.cc @@ -166,16 +166,17 @@ class MatMulV2MKLDNNKernel } } - if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2) { - for (size_t i = 0; i < x_dims.size() - 2; ++i) { + if (x_dims.size() > 2 && y_dims.size() > 2) { + for (size_t i = 0; i < x_bd_dims.size() - 2; ++i) { PADDLE_ENFORCE_EQ( - x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true, - paddle::platform::errors::InvalidArgument( - "Tensor dimensions are incorrect for broadcasting." - "Dimensions in X and Y must be same or equal to 1, but " - "received x_dim[%d]=%d and y_dims[%d]= %d", - i, x_dims[i], i, y_dims[i])); - out_dims[i] = std::max(x_dims[i], y_dims[i]); + x_bd_dims[i] == y_bd_dims[i] || x_bd_dims[i] == 1 || + y_bd_dims[i] == 1, + true, paddle::platform::errors::InvalidArgument( + "Tensor dimensions are incorrect for broadcasting." + "Dimensions in X and Y must be same or equal to 1, but " + "received x_dim[%d]=%d and y_dims[%d]= %d", + i, x_bd_dims[i], i, y_bd_dims[i])); + out_dims[i] = std::max(x_bd_dims[i], y_bd_dims[i]); } out->Resize(make_ddim(out_dims)); } @@ -237,11 +238,11 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { dy_tmp->mutable_data(ctx.GetPlace()); } - void ReduceSumForMatmulGradOutput(const ExecutionContext& ctx, - const MKLDNNDeviceContext& dev_ctx, - const mkldnn::engine onednn_engine, - const Tensor* dx_tmp, Tensor* dx, - std::vector dx_dims) const { + void ReduceSumForMatmulGradOutput( + const ExecutionContext& ctx, const MKLDNNDeviceContext& dev_ctx, + const dnnl::engine onednn_engine, const Tensor* dx_tmp, Tensor* dx, + std::vector& dx_dims, + const std::vector& squeezed_dims) const { paddle::platform::ReductionMKLDNNHandler handler( dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine, ctx.GetPlace(), dx_tmp, dx, dx_dims); @@ -257,6 +258,9 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { reduction_p->execute(astream, reduction_args); astream.wait(); + + dx->set_format(paddle::platform::GetMKLDNNFormat( + dst_memory_p->get_desc().reshape(squeezed_dims))); } std::vector ExtendDimsWithOnes(const std::vector& dims, @@ -356,21 +360,21 @@ class MatMulV2GradMKLDNNKernel : public MatMulV2MKLDNNKernel { if (x_dims != dx_bd_dims) { ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dx_tmp, dx, - x_dims); + x_dims, + paddle::framework::vectorize(x->dims())); } else { *dx = std::move(dx_tmp); } if (y_dims != dy_bd_dims) { ReduceSumForMatmulGradOutput(ctx, dev_ctx, onednn_engine, &dy_tmp, dy, - y_dims); + y_dims, + paddle::framework::vectorize(y->dims())); } else { *dy = std::move(dy_tmp); } - dx->set_layout(paddle::framework::DataLayout::kMKLDNN); - dx->set_format(x->format()); - dy->set_layout(paddle::framework::DataLayout::kMKLDNN); - dy->set_format(y->format()); + dx->Resize(x->dims()); + dy->Resize(y->dims()); } }; } // anonymous namespace diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py index 1fc5fb494bb15..2fe28c934b1bc 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_v2_mkldnn_op.py @@ -278,6 +278,14 @@ def config(self): self.trans_y = True +class TestMatMulV2MatrixXMatrix3Dx4DOneDNNOp(TestMatMulV2VectorXVectorOneDNNOp): + def config(self): + self.x_shape = (1, 1, 32, 16) + self.y_shape = (16, 16, 16) + self.trans_x = False + self.trans_y = False + + # BF16 TESTS def create_bf16_test_class(parent): @OpTestTool.skip_if_not_cpu_bf16()