From 108ffd2cce0154235b5a05188430fa1c7eb6c89e Mon Sep 17 00:00:00 2001 From: Bartosz Kuncer Date: Wed, 9 Jun 2021 17:17:29 +0200 Subject: [PATCH] [operator] Integrate matmul primitive from oneDNN in batch dot --- src/operator/nn/mkldnn/mkldnn_base-inl.h | 1 + src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h | 66 ++++++++++ src/operator/nn/mkldnn/mkldnn_batch_dot.cc | 118 ++++++++++++++++++ src/operator/nn/mkldnn/mkldnn_ops-inl.h | 6 + src/operator/tensor/dot-inl.h | 18 ++- src/operator/tensor/dot.cc | 37 ++++++ 6 files changed, 245 insertions(+), 1 deletion(-) create mode 100644 src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h create mode 100644 src/operator/nn/mkldnn/mkldnn_batch_dot.cc diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 5385b5b3a1e9..6eaaa06c41eb 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -213,6 +213,7 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const bool SupportMKLDNNLogSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output); bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data); +bool SupportMKLDNNBatchDot(const NDArray &input); } // namespace op static int GetTypeSize(int dtype) { diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h new file mode 100644 index 000000000000..c4cd74c0764b --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_batch_dot-inl.h @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_batch_dot-inl.h
+ */
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <mkldnn.hpp>
+#include <memory>
+#include <vector>
+#include "./mkldnn_base-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "../../tensor/dot-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using batch_dot_fwd_t = mkldnn::matmul;
+using batch_dot_fwd_pd_t = mkldnn::matmul::primitive_desc;
+
+typedef ParamOpSign<DotParam> BatchDotSignature;
+
+class MKLDNNBatchDotFwd {
+ public:
+  static MKLDNNBatchDotFwd &GetCached(const DotParam &param,
+                                      const std::vector<NDArray> &inputs,
+                                      const std::vector<NDArray> &outputs);
+
+  MKLDNNBatchDotFwd(const DotParam &param,
+                    const std::vector<NDArray> &inputs,
+                    const std::vector<NDArray> &outputs);
+
+  void Execute(const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<NDArray> &outputs);
+
+ private:
+  std::shared_ptr<batch_dot_fwd_t> fwd;
+  std::shared_ptr<batch_dot_fwd_pd_t> fwd_pd;
+};
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_DOT_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_dot.cc b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
new file mode 100644
index 000000000000..3da9ff890bbc
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_batch_dot.cc
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_batch_dot.cc
+ */
+
+#if MXNET_USE_ONEDNN == 1
+
+#include "./mkldnn_batch_dot-inl.h"
+
+namespace mxnet {
+namespace op {
+
+bool SupportMKLDNNBatchDot(const NDArray &data) {
+  return data.dtype() == mshadow::kFloat32 ||
+         data.dtype() == mshadow::kBfloat16;
+}
+
+void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs,
+                           const OpContext &ctx,
+                           const std::vector<NDArray> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<NDArray> &outputs) {
+  const DotParam &param = nnvm::get<DotParam>(attrs.parsed);
+  MKLDNNBatchDotFwd &fwd = MKLDNNBatchDotFwd::GetCached(param, inputs, outputs);
+  fwd.Execute(inputs, req, outputs);
+}
+
+MKLDNNBatchDotFwd &MKLDNNBatchDotFwd::GetCached(const DotParam &param,
+                                                const std::vector<NDArray> &inputs,
+                                                const std::vector<NDArray> &outputs) {
+  using batch_dot_fwd_map = std::unordered_map<BatchDotSignature, MKLDNNBatchDotFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local batch_dot_fwd_map fwds;
+#else
+  static MX_THREAD_LOCAL batch_dot_fwd_map fwds;
+#endif
+
+  BatchDotSignature key(param);
+  key.AddSign(inputs[0]);
+  key.AddSign(inputs[1]);
+  key.AddSign(outputs[0]);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    const MKLDNNBatchDotFwd fwd(param, inputs, outputs);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+MKLDNNBatchDotFwd::MKLDNNBatchDotFwd(const DotParam &param,
+                                     const std::vector<NDArray> &inputs,
+                                     const std::vector<NDArray> &outputs) {
+  auto GetMemoryDesc = [](const NDArray& tensor,
const bool transpose) {
+    if (transpose) {
+      auto shape = tensor.shape();
+      auto ndim = shape.ndim();
+      auto bigDim = shape[0];
+      for (size_t i = 1; i < ndim - 2; ++i) {
+        bigDim *= shape[i];
+      }
+      return mkldnn::memory::desc(mkldnn::memory::dims{bigDim, shape[ndim - 1], shape[ndim - 2]},
+                                  get_mkldnn_type(tensor.dtype()),
+                                  mkldnn::memory::format_tag::acb);
+    } else {
+      return tensor.GetMKLDNNData()->get_desc();
+    }
+  };
+
+  mkldnn::matmul::desc fwd_desc(GetMemoryDesc(inputs[0], param.transpose_a),
+                                GetMemoryDesc(inputs[1], param.transpose_b),
+                                outputs[0].GetMKLDNNData()->get_desc());
+  fwd_pd = std::make_shared<batch_dot_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
+  fwd = std::make_shared<batch_dot_fwd_t>(*fwd_pd);
+}
+
+void MKLDNNBatchDotFwd::Execute(const std::vector<NDArray> &inputs,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &outputs) {
+  auto engine = mxnet::CpuEngine::Get()->get_engine();
+  auto data = mkldnn::memory(fwd_pd->src_desc(), engine,
+                             reinterpret_cast<void*>(inputs[0].data().dptr_));
+  auto weights = mkldnn::memory(fwd_pd->weights_desc(), engine,
+                                reinterpret_cast<void*>(inputs[1].data().dptr_));
+  mkldnn_output_t out_mem = CreateMKLDNNMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);
+
+  mkldnn_args_map_t args = {
+    {MKLDNN_ARG_SRC, data},
+    {MKLDNN_ARG_WEIGHTS, weights},
+    {MKLDNN_ARG_DST, *out_mem.second},
+  };
+
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd, args);
+  CommitOutput(outputs[0], out_mem);
+  MKLDNNStream::Get()->Submit();
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 890e111de914..294a965e86d1 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -128,6 +128,12 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs);
 
+/* For batch dot */
+void MKLDNNBatchDotForward(const nnvm::NodeAttrs &attrs, const 
OpContext &ctx,
+                           const std::vector<NDArray> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<NDArray> &outputs);
+
 void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
                const mkldnn::memory &out);
 
diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h
index 7b39d9a087bb..f60a18d90444 100644
--- a/src/operator/tensor/dot-inl.h
+++ b/src/operator/tensor/dot-inl.h
@@ -65,6 +65,11 @@ struct DotParam : public dmlc::Parameter<DotParam> {
     .add_enum("csr", kCSRStorage)
     .set_default(dmlc::optional<int>());
   }
+  bool operator==(const DotParam& other) const {
+    return this->transpose_a == other.transpose_a &&
+           this->transpose_b == other.transpose_b &&
+           this->forward_stype == other.forward_stype;
+  }
   std::string ForwardStype2String(int forward_stype) {
     switch (forward_stype) {
       case kDefaultStorage:
@@ -1482,5 +1487,16 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
 
 }  // namespace op
 }  // namespace mxnet
-
+namespace std {
+template<>
+struct hash<mxnet::op::DotParam> {
+  size_t operator()(const mxnet::op::DotParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.transpose_a);
+    ret = dmlc::HashCombine(ret, val.transpose_b);
+    ret = dmlc::HashCombine(ret, val.forward_stype);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_OPERATOR_TENSOR_DOT_INL_H_
diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc
index b3f6331067ea..53960e7be23d 100644
--- a/src/operator/tensor/dot.cc
+++ b/src/operator/tensor/dot.cc
@@ -23,6 +23,10 @@
  */
 
 #include "./dot-inl.h"
+#if MXNET_USE_ONEDNN == 1
+#include "./../nn/mkldnn/mkldnn_base-inl.h"
+#include "./../nn/mkldnn/mkldnn_ops-inl.h"
+#endif  // MXNET_USE_ONEDNN
 
 namespace mxnet {
 namespace op {
@@ -111,6 +115,34 @@ NNVM_REGISTER_OP(_backward_dot)
 .set_attr<FComputeEx>("FComputeEx<cpu>", DotBackwardEx<cpu>)
 .add_arguments(DotParam::__FIELDS__());
 
+#if MXNET_USE_ONEDNN == 1
+static void BatchDotComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<NDArray>& 
outputs) {
+  if (SupportMKLDNNBatchDot(inputs[0])) {
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNRun(MKLDNNBatchDotForward, attrs, ctx, inputs, req, outputs);
+    MKLDNN_OPCHECK_RUN(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  FallBackCompute(BatchDotForward_<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+inline static bool BatchDotStorageType(const nnvm::NodeAttrs& attrs,
+                                       const int dev_mask,
+                                       DispatchMode* dispatch_mode,
+                                       std::vector<int>* in_attrs,
+                                       std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2);
+  CHECK_EQ(out_attrs->size(), 1);
+
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
+                           out_attrs);
+}
+#endif
+
 NNVM_REGISTER_OP(batch_dot)
 .add_alias("_npx_batch_dot")
 .describe(R"doc(Batchwise dot product.
@@ -140,6 +172,11 @@ which is computed by::
   })
 .set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
 .set_attr<FCompute>("FCompute<cpu>", BatchDotForward_<cpu>)
+#if MXNET_USE_ONEDNN == 1
+.set_attr<bool>("TIsMKLDNN", true)
+.set_attr<FInferStorageType>("FInferStorageType", BatchDotStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BatchDotComputeExCPU)
+#endif
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::ObjectPtr& n,
      const std::vector<nnvm::NodeEntry>& ograds) {