From 31a5a817bac2465da56bd9f4141649866ba31e41 Mon Sep 17 00:00:00 2001
From: Bartosz Kuncer
Date: Fri, 11 Jun 2021 11:39:06 +0200
Subject: [PATCH] Fix test_npx_batch_dot

---
 src/operator/nn/mkldnn/mkldnn_base-inl.h   |  2 +-
 src/operator/nn/mkldnn/mkldnn_batch_dot.cc | 40 ++++++++++++++--------
 src/operator/tensor/dot.cc                 |  2 +-
 3 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 6eaaa06c41eb..6e865818b155 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -213,7 +213,7 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const
 bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input,
                              const NDArray &output);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
-bool SupportMKLDNNBatchDot(const NDArray &input);
+bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs, const NDArray &output);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
index 3da9ff890bbc..05c42d316c51 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
@@ -28,9 +28,12 @@
 namespace mxnet {
 namespace op {
 
-bool SupportMKLDNNBatchDot(const NDArray &data) {
-  return data.dtype() == mshadow::kFloat32 ||
-         data.dtype() == mshadow::kBfloat16;
+bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs, const NDArray &output) {
+  return inputs[0].shape().Size() != 0 &&
+         inputs[1].shape().Size() != 0 &&
+         output.shape().Size() != 0 &&
+         (inputs[0].dtype() == mshadow::kFloat32 ||
+          inputs[0].dtype() == mshadow::kBfloat16);
 }
 
 void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs,
@@ -69,25 +72,32 @@ MKLDNNBatchDotFwd &MKLDNNBatchDotFwd::GetCached(const DotParam &param,
 MKLDNNBatchDotFwd::MKLDNNBatchDotFwd(const DotParam &param,
                                      const std::vector<NDArray> &inputs,
                                      const std::vector<NDArray> &outputs) {
-  auto GetMemoryDesc = [](const NDArray& tensor, const bool transpose) {
+  auto shape = inputs[0].shape();
+  auto ndim = shape.ndim();
+  auto bigDim = shape[0];
+  for (size_t i = 1; i < ndim - 2; ++i) {
+    bigDim *= shape[i];
+  }
+
+  auto GetMemoryDesc = [&ndim, &bigDim](const NDArray& tensor, const bool transpose) {
+    auto shape = tensor.shape();
     if (transpose) {
-      auto shape = tensor.shape();
-      auto ndim = shape.ndim();
-      auto bigDim = shape[0];
-      for (size_t i = 1; i < ndim - 2; ++i) {
-        bigDim *= shape[i];
-      }
       return mkldnn::memory::desc(mkldnn::memory::dims{bigDim, shape[ndim - 1], shape[ndim - 2]},
                                   get_mkldnn_type(tensor.dtype()),
                                   mkldnn::memory::format_tag::acb);
     } else {
-      return tensor.GetMKLDNNData()->get_desc();
+      return mkldnn::memory::desc(mkldnn::memory::dims{bigDim, shape[ndim - 2], shape[ndim - 1]},
+                                  get_mkldnn_type(tensor.dtype()),
+                                  mkldnn::memory::format_tag::any);
     }
   };
-  mkldnn::matmul::desc fwd_desc(GetMemoryDesc(inputs[0], param.transpose_a),
-                                GetMemoryDesc(inputs[1], param.transpose_b),
-                                outputs[0].GetMKLDNNData()->get_desc());
+  mkldnn::memory::desc data_md = GetMemoryDesc(inputs[0], param.transpose_a);
+  mkldnn::memory::desc weights_md = GetMemoryDesc(inputs[1], param.transpose_b);
+  mkldnn::memory::desc out_md({bigDim, data_md.dims()[1], weights_md.dims()[2]},
+                              get_mkldnn_type(outputs[0].dtype()),
+                              mkldnn::memory::format_tag::any);
+  mkldnn::matmul::desc fwd_desc(data_md, weights_md, out_md);
   fwd_pd = std::make_shared<mkldnn::matmul::primitive_desc>(fwd_desc,
                                                             mxnet::CpuEngine::Get()->get_engine());
   fwd = std::make_shared<mkldnn::matmul>(*fwd_pd);
 }
@@ -102,7 +112,7 @@ void MKLDNNBatchDotFwd::Execute(const std::vector<NDArray> &inputs,
                                 reinterpret_cast<void*>(inputs[1].data().dptr_));
   mkldnn_output_t out_mem = CreateMKLDNNMem(outputs[0], fwd_pd->dst_desc(),
                                             req[0], &inputs[0]);
-  mkldnn_args_map_t args = {
+  mkldnn_args_map_t args = {
     {MKLDNN_ARG_SRC, data},
     {MKLDNN_ARG_WEIGHTS, weights},
     {MKLDNN_ARG_DST, *out_mem.second},
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index 53960e7be23d..e3df64c79e0c 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -121,7 +121,7 @@ static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
                                  const std::vector<NDArray>& inputs,
                                  const std::vector<OpReqType>& req,
                                  const std::vector<NDArray>& outputs) {
-  if (SupportMKLDNNBatchDot(inputs[0])) {
+  if (SupportMKLDNNBatchDot(inputs, outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNRun(MKLDNNBatchDotForward, attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(BatchDotForward_, attrs, ctx, inputs, req, outputs);