Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix test_npx_batch_dot
Browse files Browse the repository at this point in the history
  • Loading branch information
bartekkuncer committed Jun 11, 2021
1 parent 108ffd2 commit 6b9c0b5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const
bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input,
const NDArray &output);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
bool SupportMKLDNNBatchDot(const NDArray &input);
bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs, const NDArray &output);
} // namespace op

static int GetTypeSize(int dtype) {
Expand Down
42 changes: 28 additions & 14 deletions src/operator/nn/mkldnn/mkldnn_batch_dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
namespace mxnet {
namespace op {

bool SupportMKLDNNBatchDot(const NDArray &data) {
return data.dtype() == mshadow::kFloat32 ||
data.dtype() == mshadow::kBfloat16;
/*!
 * \brief Tells whether batch_dot may be dispatched to the MKLDNN implementation.
 *
 * The MKLDNN path is accepted only when neither of the two inputs nor the output
 * has zero elements, and the first input is float32 or bfloat16.
 *
 * \param inputs operands of batch_dot; inputs[0] and inputs[1] are inspected
 * \param output result tensor of batch_dot
 * \return true when the MKLDNN path can be used, false otherwise
 */
bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs, const NDArray &output) {
  // Reject as soon as any participating tensor turns out to be empty.
  if (inputs[0].shape().Size() == 0 ||
      inputs[1].shape().Size() == 0 ||
      output.shape().Size() == 0) {
    return false;
  }

  // Only the dtype of the first operand is consulted.
  const auto lhs_dtype = inputs[0].dtype();
  return lhs_dtype == mshadow::kFloat32 || lhs_dtype == mshadow::kBfloat16;
}

void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs,
Expand Down Expand Up @@ -69,25 +72,36 @@ MKLDNNBatchDotFwd &MKLDNNBatchDotFwd::GetCached(const DotParam &param,
/*!
 * \brief Builds the MKLDNN matmul primitive descriptor and primitive used for batch_dot.
 *
 * Every dimension of inputs[0] except the last two is folded into one batch
 * dimension (bigDim); both operands and the output are then described as 3-D
 * tensors that share that batch size.
 *
 * NOTE(review): this listing looks like a rendered diff with the +/- markers
 * stripped, so pre-change and post-change lines appear interleaved below —
 * confirm against the real file before relying on the exact statement order.
 *
 * \param param   supplies the transpose_a / transpose_b flags
 * \param inputs  inputs[0] and inputs[1] are the two operands
 * \param outputs outputs[0] supplies the dtype of the result descriptor
 */
MKLDNNBatchDotFwd::MKLDNNBatchDotFwd(const DotParam &param,
const std::vector<NDArray> &inputs,
const std::vector<NDArray> &outputs) {
auto GetMemoryDesc = [](const NDArray& tensor, const bool transpose) {
// Fold all dimensions of inputs[0] except the last two into a single batch dimension.
auto shape = inputs[0].shape();
auto ndim = shape.ndim();
auto bigDim = shape[0];
// NOTE(review): i is size_t; if ndim is a signed type and below 2, `ndim - 2` would wrap
// in this comparison — presumably a rank of at least 3 is guaranteed upstream; confirm.
for (size_t i = 1; i < ndim - 2; ++i) {
bigDim *= shape[i];
}

// Describes a tensor as a 3-D {bigDim, rows, cols} memory descriptor. bigDim is taken
// from inputs[0] and reused for every tensor handed to this lambda.
auto GetMemoryDesc = [&bigDim](const NDArray& tensor, const bool transpose) {
auto shape = tensor.shape();
auto ndim = shape.ndim();
// Transposed operand: the last two dimensions are swapped and the acb format tag is used.
if (transpose) {
auto shape = tensor.shape();
auto ndim = shape.ndim();
auto bigDim = shape[0];
for (size_t i = 1; i < ndim - 2; ++i) {
bigDim *= shape[i];
}
return mkldnn::memory::desc(mkldnn::memory::dims{bigDim, shape[ndim - 1], shape[ndim - 2]},
get_mkldnn_type(tensor.dtype()),
mkldnn::memory::format_tag::acb);
} else {
return tensor.GetMKLDNNData()->get_desc();
// Non-transposed operand: the last two dimensions keep their order and
// format_tag::any is requested.
return mkldnn::memory::desc(mkldnn::memory::dims{bigDim, shape[ndim - 2], shape[ndim - 1]},
get_mkldnn_type(tensor.dtype()),
mkldnn::memory::format_tag::any);
}
};

mkldnn::matmul::desc fwd_desc(GetMemoryDesc(inputs[0], param.transpose_a),
GetMemoryDesc(inputs[1], param.transpose_b),
outputs[0].GetMKLDNNData()->get_desc());
// The output descriptor is {bigDim, dims()[1] of data_md, dims()[2] of weights_md}
// with the dtype of outputs[0]; it is built from the operand descriptors rather than
// read from the output array.
mkldnn::memory::desc data_md = GetMemoryDesc(inputs[0], param.transpose_a);
mkldnn::memory::desc weights_md = GetMemoryDesc(inputs[1], param.transpose_b);
mkldnn::matmul::desc fwd_desc(data_md,
weights_md,
mkldnn::memory::desc(mkldnn::memory::dims{bigDim,
data_md.dims()[1],
weights_md.dims()[2]},
get_mkldnn_type(outputs[0].dtype()),
mkldnn::memory::format_tag::any));
// Create the primitive descriptor on the engine returned by CpuEngine, then the primitive.
fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
}
Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNBatchDot(inputs[0])) {
if (SupportMKLDNNBatchDot(inputs, outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNRun(MKLDNNBatchDotForward, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
Expand Down

0 comments on commit 6b9c0b5

Please sign in to comment.