Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MKLDNN] support using any format in pooling backward #17900

Merged
merged 2 commits into from
Apr 10, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
namespace mxnet {
namespace op {

static inline mkldnn::memory::data_type get_data_type(const mkldnn::memory::desc &md) {
return static_cast<mkldnn::memory::data_type>(md.data_type());
}

void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &output,
const mkldnn::memory::dims &kernel,
const mkldnn::memory::dims &strides,
Expand Down Expand Up @@ -82,7 +86,7 @@ void MKLDNNPoolingFwd::Execute(const NDArray &in_data,
auto engine = CpuEngine::Get()->get_engine();

if (workspace == nullptr) {
LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input";
}

auto ws = std::make_shared<mkldnn::memory>((*(this->fwd_pd_)).workspace_desc(),
Expand Down Expand Up @@ -332,20 +336,21 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,

auto it = pooling_bwds.find(key);
if (it == pooling_bwds.end()) {
NDArray diff_dst_buff = out_grad;
if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) {
diff_dst_buff = out_grad.Reorder2Default();
}
auto diff_dst_mem = diff_dst_buff.GetMKLDNNData();
auto input_mem = in_data.GetMKLDNNData();
const mkldnn::memory::desc data_md = input_mem->get_desc();
const mkldnn::memory::desc out_md = GetMemDesc(out_grad);
auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md);
const mkldnn::memory::desc diff_md = diff_dst_mem->get_desc();
auto data_md = input_mem->get_desc();

const mkldnn::memory::desc diff_in_md = GetMemDesc(in_grad);
const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();
const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
auto dst_dims = mkldnn::memory::dims(out_grad.shape().begin(), out_grad.shape().end());
auto any = mkldnn::memory::format_tag::any;
auto dst_md = mkldnn::memory::desc(dst_dims, get_data_type(data_md), any);

// fwd hint
auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, dst_md);

// creat bwd desc
auto diff_src_dims = mkldnn::memory::dims(in_grad.shape().begin(), in_grad.shape().end());
auto diff_src_md = mkldnn::memory::desc(diff_src_dims, get_data_type(data_md), any);
auto cpu_engine = CpuEngine::Get()->get_engine();;
auto alg = GetMKLDNNPoolAlgo(param);

const int kernel_ndims = param.kernel.ndim();
mkldnn::memory::dims kernel(kernel_ndims);
Expand All @@ -355,9 +360,11 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,

InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r);

const mkldnn::pooling_backward::desc desc(
alg, diff_in_md, diff_md, strides, kernel, pad_l, pad_r);
const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd);
// use dst_md as diff_dst_md with any format
auto bwd_desc = mkldnn::pooling_backward::desc(alg, diff_src_md, dst_md,
strides, kernel, pad_l, pad_r);
auto pdesc = mkldnn::pooling_backward::primitive_desc(bwd_desc, cpu_engine, fwd_pd);

MKLDNNPoolingBwd bwd(pdesc, with_workspace);
it = AddToCache(&pooling_bwds, key, bwd);
}
Expand All @@ -371,15 +378,15 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam &param,
if (req == kNullOp) {
return;
}

TmpMemMgr::Get()->Init(ctx.requested[0]);

auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
auto diff_src_mem =
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);

auto diff_dst_mem = out_grad.GetMKLDNNDataReorder(bwd.pd.diff_dst_desc());
auto diff_src_mem = CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
mkldnn_args_map_t args = {
{MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())},
{MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second },
{MKLDNN_ARG_DIFF_DST, *diff_dst_mem},
{MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second},
};
if (MKLDNNRequireWorkspace(param) && workspace != nullptr) {
args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData());
Expand Down