Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[operator] Integrate matmul primitive from oneDNN in batch dot #20340

Merged
merged 5 commits into from
Jun 24, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const
bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input,
const NDArray &output);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs, const NDArray &output);
} // namespace op

static int GetTypeSize(int dtype) {
Expand Down
65 changes: 65 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_batch_dot-inl.h
*/

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_

#if MXNET_USE_ONEDNN == 1

#include <numeric>
#include <utility>
#include <vector>
#include "../../tensor/dot-inl.h"
#include "./mkldnn_base-inl.h"
#include "./mkldnn_ops-inl.h"

namespace mxnet {
namespace op {

using batch_dot_fwd_t = mkldnn::matmul;
using batch_dot_fwd_pd_t = mkldnn::matmul::primitive_desc;

// Cache key for forward primitives: DotParam fields plus the signatures of
// the input/output NDArrays (shapes and dtypes).
// Modernized from `typedef` to `using` for consistency with the aliases above.
using BatchDotSignature = ParamOpSign<DotParam>;

/*!
 * \brief oneDNN matmul-based forward primitive for the batch_dot operator.
 *
 * An instance holds the matmul primitive and its primitive descriptor, built
 * once for a given (param, input/output signature) combination and reused
 * across invocations through the thread-local cache behind GetCached().
 */
class MKLDNNBatchDotFwd {
public:
// Fetch (or create and cache) the primitive matching `param` and the
// shapes/dtypes of `inputs`/`outputs`.
static MKLDNNBatchDotFwd &GetCached(const DotParam &param,
const std::vector<NDArray> &inputs,
const std::vector<NDArray> &outputs);

// Builds the matmul primitive descriptor and primitive from the parameter
// set and the concrete input/output arrays.
MKLDNNBatchDotFwd(const DotParam &param, const std::vector<NDArray> &inputs,
const std::vector<NDArray> &outputs);

// Runs the primitive: inputs[0] is src, inputs[1] is weights, outputs[0]
// is dst; req[0] is forwarded to the output-memory creation.
void Execute(const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs);

private:
std::shared_ptr<batch_dot_fwd_t> fwd;     // the matmul primitive
std::shared_ptr<batch_dot_fwd_pd_t> fwd_pd;  // its primitive descriptor
};

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H__
132 changes: 132 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_batch_dot.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_batch_dot.cc
*/

#if MXNET_USE_ONEDNN == 1

#include "./mkldnn_batch_dot-inl.h"

namespace mxnet {
namespace op {

// Whether the oneDNN matmul path can handle this batch_dot invocation.
// oneDNN cannot operate on empty tensors, and this integration only covers
// float32 and bfloat16 inputs; everything else falls back to the native CPU
// implementation.
bool SupportMKLDNNBatchDot(const std::vector<NDArray> &inputs,
                           const NDArray &output) {
  const bool no_empty_tensors = inputs[0].shape().Size() != 0 &&
                                inputs[1].shape().Size() != 0 &&
                                output.shape().Size() != 0;
  const int dtype = inputs[0].dtype();
  const bool supported_dtype =
      dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16;
  return no_empty_tensors && supported_dtype;
}

// Entry point used by the operator registration: look up (or build) the
// cached forward primitive for this parameter/shape combination and run it.
void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                           const std::vector<NDArray> &inputs,
                           const std::vector<OpReqType> &req,
                           const std::vector<NDArray> &outputs) {
  const auto &param = nnvm::get<DotParam>(attrs.parsed);
  MKLDNNBatchDotFwd::GetCached(param, inputs, outputs)
      .Execute(inputs, req, outputs);
}

// Returns a thread-local cached primitive keyed by the parameter set and the
// signatures of both inputs and the output; builds and inserts one on a miss.
// Thread-local storage avoids any locking around the cache.
MKLDNNBatchDotFwd &MKLDNNBatchDotFwd::GetCached(
    const DotParam &param, const std::vector<NDArray> &inputs,
    const std::vector<NDArray> &outputs) {
  using batch_dot_fwd_map =
      std::unordered_map<BatchDotSignature, MKLDNNBatchDotFwd, OpHash>;
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local batch_dot_fwd_map fwds;
#else
  static MX_THREAD_LOCAL batch_dot_fwd_map fwds;
#endif

  // The key encodes DotParam plus shape/dtype of lhs, rhs and the output.
  BatchDotSignature key(param);
  key.AddSign(inputs[0]);
  key.AddSign(inputs[1]);
  key.AddSign(outputs[0]);

  auto it = fwds.find(key);
  if (it == fwds.end()) {
    it = AddToCache(&fwds, key, MKLDNNBatchDotFwd(param, inputs, outputs));
  }
  return it->second;
}

// Builds the oneDNN matmul primitive descriptor and primitive.
// All leading (batch) dimensions are collapsed into one, since the primitive
// is set up here as a 3D {batch, rows, cols} matmul.
MKLDNNBatchDotFwd::MKLDNNBatchDotFwd(const DotParam &param,
                                     const std::vector<NDArray> &inputs,
                                     const std::vector<NDArray> &outputs) {
  const auto shape = inputs[0].shape();
  const int ndim = shape.ndim();
  auto bigDim = shape[0];
  // Signed loop index: `ndim` is an int, and the previous `size_t` counter
  // mixed signedness in `i < ndim - 2` (which would wrap to a huge unsigned
  // value if ndim < 3 ever slipped through).
  for (int i = 1; i < ndim - 2; ++i) {
    bigDim *= shape[i];
  }

  // Builds a 3D memory descriptor for one operand. A transposed operand is
  // expressed through strides (format_tag::acb) rather than by physically
  // reordering the data; otherwise the format is left to oneDNN (any).
  const auto GetMemoryDesc = [&ndim, &bigDim](const NDArray &tensor,
                                              const bool transpose) {
    const auto tshape = tensor.shape();  // renamed: no longer shadows `shape`
    if (transpose) {
      return mkldnn::memory::desc(
          mkldnn::memory::dims{bigDim, tshape[ndim - 1], tshape[ndim - 2]},
          get_mkldnn_type(tensor.dtype()), mkldnn::memory::format_tag::acb);
    } else {
      return mkldnn::memory::desc(
          mkldnn::memory::dims{bigDim, tshape[ndim - 2], tshape[ndim - 1]},
          get_mkldnn_type(tensor.dtype()), mkldnn::memory::format_tag::any);
    }
  };

  mkldnn::memory::desc data_md = GetMemoryDesc(inputs[0], param.transpose_a);
  mkldnn::memory::desc weights_md = GetMemoryDesc(inputs[1], param.transpose_b);
  // Output rows come from (possibly transposed) lhs, columns from rhs.
  mkldnn::memory::desc out_md({bigDim, data_md.dims()[1], weights_md.dims()[2]},
                              get_mkldnn_type(outputs[0].dtype()),
                              mkldnn::memory::format_tag::any);
  mkldnn::matmul::desc fwd_desc(data_md, weights_md, out_md);
  fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(
      fwd_desc, mxnet::CpuEngine::Get()->get_engine());
  fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
}

void MKLDNNBatchDotFwd::Execute(const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
bartekkuncer marked this conversation as resolved.
Show resolved Hide resolved
auto engine = mxnet::CpuEngine::Get()->get_engine();
auto data = mkldnn::memory(fwd_pd->src_desc(), engine,
reinterpret_cast<void *>(inputs[0].data().dptr_));
auto weights =
mkldnn::memory(fwd_pd->weights_desc(), engine,
reinterpret_cast<void *>(inputs[1].data().dptr_));
mkldnn_output_t out_mem =
CreateMKLDNNMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);

mkldnn_args_map_t args = {
{MKLDNN_ARG_SRC, data},
{MKLDNN_ARG_WEIGHTS, weights},
{MKLDNN_ARG_DST, *out_mem.second},
};

MKLDNNStream::Get()->RegisterPrimArgs(*fwd, args);
CommitOutput(outputs[0], out_mem);
MKLDNNStream::Get()->Submit();
}

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_ONEDNN == 1
6 changes: 6 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

/* For batch dot */
void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs);

void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
const mkldnn::memory &out);

Expand Down
18 changes: 17 additions & 1 deletion src/operator/tensor/dot-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ struct DotParam : public dmlc::Parameter<DotParam> {
.add_enum("csr", kCSRStorage)
.set_default(dmlc::optional<int>());
}
// Equality over all three parameter fields; used by the primitive-cache key.
bool operator==(const DotParam& other) const {
  return transpose_a == other.transpose_a &&
         transpose_b == other.transpose_b &&
         forward_stype == other.forward_stype;
}
std::string ForwardStype2String(int forward_stype) {
switch (forward_stype) {
case kDefaultStorage:
Expand Down Expand Up @@ -1482,5 +1487,16 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,

} // namespace op
} // namespace mxnet

namespace std {
/*!
 * \brief Hash specialization for DotParam so it can key unordered containers
 * (e.g. the thread-local MKLDNN forward-primitive cache).
 */
template<>
struct hash<mxnet::op::DotParam> {
  // const-qualified: the C++ "Hash" named requirement demands the hasher be
  // callable as (possibly const) — the original non-const operator() did not
  // satisfy it.
  size_t operator()(const mxnet::op::DotParam& val) const {
    size_t ret = 0;
    ret = dmlc::HashCombine(ret, val.transpose_a);
    ret = dmlc::HashCombine(ret, val.transpose_b);
    ret = dmlc::HashCombine(ret, val.forward_stype);
    return ret;
  }
};
}  // namespace std
#endif // MXNET_OPERATOR_TENSOR_DOT_INL_H_
37 changes: 37 additions & 0 deletions src/operator/tensor/dot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
*/

#include "./dot-inl.h"
#if MXNET_USE_ONEDNN == 1
#include "./../nn/mkldnn/mkldnn_base-inl.h"
#include "./../nn/mkldnn/mkldnn_ops-inl.h"
#endif // MXNET_USE_ONEDNN

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -111,6 +115,34 @@ NNVM_REGISTER_OP(_backward_dot)
.set_attr<FComputeEx>("FComputeEx<cpu>", DotBackwardEx<cpu>)
.add_arguments(DotParam::__FIELDS__());

#if MXNET_USE_ONEDNN == 1
// FComputeEx dispatcher for batch_dot on CPU: uses the oneDNN path when
// supported, otherwise falls back to the reference implementation (with
// layout conversion for MKLDNN-formatted arrays).
static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
  // Guard clause: unsupported cases go straight to the fallback.
  if (!SupportMKLDNNBatchDot(inputs, outputs[0])) {
    FallBackCompute(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
    return;
  }
  MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
  MKLDNNRun(MKLDNNBatchDotForward, attrs, ctx, inputs, req, outputs);
  MKLDNN_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
}

// FInferStorageType for batch_dot: delegates to the common MKLDNN storage
// type inference helper (support_mkldnn = true).
bool BatchDotStorageType(const nnvm::NodeAttrs& attrs,
                         const int dev_mask,
                         DispatchMode* dispatch_mode,
                         std::vector<int>* in_attrs,
                         std::vector<int>* out_attrs) {
  // batch_dot always has exactly two inputs (lhs, rhs) and one output.
  CHECK_EQ(in_attrs->size(), 2);
  CHECK_EQ(out_attrs->size(), 1);

  const bool dispatched = MKLDNNStorageType(attrs, dev_mask, true,
                                            dispatch_mode, in_attrs, out_attrs);
  return dispatched;
}
#endif

NNVM_REGISTER_OP(batch_dot)
.add_alias("_npx_batch_dot")
.describe(R"doc(Batchwise dot product.
Expand Down Expand Up @@ -140,6 +172,11 @@ which is computed by::
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<FCompute>("FCompute<cpu>", BatchDotForward_<cpu>)
#if MXNET_USE_ONEDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FInferStorageType>("FInferStorageType", BatchDotStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchDotComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
Expand Down