diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index a7e63f6d4139..5d722581257f 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -147,7 +147,10 @@ void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const std::vector<NDArray> &inputs,
                                     const std::vector<OpReqType> &req,
                                     const std::vector<NDArray> &outputs) {
-  if (SupportMKLDNNFC(inputs[0])) {
+  // TODO(rongzha1): disable due to flakiness in cpp test IMPERATIVE.FullyConnectedOp
+  // Will be fixed when we decide to enable the backward of FC.
+  bool mkldnn_fc_backward_enable = false;
+  if (mkldnn_fc_backward_enable && SupportMKLDNNFC(inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     MKLDNNFCBackward(attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute, attrs, ctx, inputs, req,
@@ -229,10 +232,12 @@ static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
   uint32_t out_expected = param.no_bias ? 2 : 3;
   CHECK_EQ(in_attrs->size(), 3U);
   CHECK_EQ(out_attrs->size(), out_expected);
+  // TODO(zhengda) let's disable MKLDNN for FullyConnected for now.
+  // It seems there is a bug.
   bool dispatched = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
     dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+                                     dispatch_mode, DispatchMode::kFCompute);
   }
   if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index 7627d02c4702..1403cd114201 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -290,6 +290,24 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad));
 
   CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
+  if (req[fullc::kData]) {
+    mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
+        data, weight, out_grad, fwd_pd);
+    auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
+        ipBwdData_pd.diff_dst_desc());
+    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
+                                       ipBwdData_pd.diff_src_desc(),
+                                       req[fullc::kData]);
+    mkldnn_args_map_t args = {
+      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+      {MKLDNN_ARG_WEIGHTS, *weight_mem},
+      {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
+    };
+
+    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
+    CommitOutput(in_grad[fullc::kData], in_grad_mem);
+  }
   if (req[fullc::kWeight]) {
     mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd = GetFCBwdWeights(
         data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
@@ -318,24 +336,6 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
     CommitOutput(in_grad[fullc::kBias], in_grad_bias);
   }
-  if (req[fullc::kData]) {
-    mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
-        data, weight, out_grad, fwd_pd);
-    auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdData_pd.diff_dst_desc());
-    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
-    auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
-                                       ipBwdData_pd.diff_src_desc(),
-                                       req[fullc::kData]);
-    mkldnn_args_map_t args = {
-      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
-      {MKLDNN_ARG_WEIGHTS, *weight_mem},
-      {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
-    };
-
-    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
-    CommitOutput(in_grad[fullc::kData], in_grad_mem);
-  }
   MKLDNNStream::Get()->Submit();
 }
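Note on the mkldnn_fully_connected.cc hunks: they only move the `req[fullc::kData]` block ahead of the `req[fullc::kWeight]` block inside MKLDNNFCBackward; the primitive setup itself is unchanged. The functional change lives in fully_connected.cc, where the MKLDNN backward path is hard-wired off (`mkldnn_fc_backward_enable = false`) and the storage-type inference now returns `DispatchMode::kFCompute`, so FC gradients dispatch to the default `FullyConnectedGradCompute` kernel rather than the MKLDNN `FComputeEx` path.

For readers unfamiliar with the two oneDNN primitives registered in the moved block (`inner_product_backward_data` and `inner_product_backward_weights`), the sketch below shows the math they implement, written as naive loops in a standalone C++ program. All names are illustrative, not MXNet or oneDNN identifiers; shapes follow the FC convention `y = x * W^T + b` with `x: [N, K]`, `W: [M, K]`, `dy: [N, M]`.

```cpp
// Standalone sketch of FC backward math (illustrative, not MXNet/oneDNN code).
#include <cstdio>
#include <vector>

// inner_product_backward_data computes dx = dy * W:
//   dx[n][k] = sum_m dy[n][m] * W[m][k]
void FCBackwardData(const std::vector<float>& dy, const std::vector<float>& w,
                    std::vector<float>* dx, int N, int M, int K) {
  for (int n = 0; n < N; ++n)
    for (int k = 0; k < K; ++k) {
      float acc = 0.f;
      for (int m = 0; m < M; ++m) acc += dy[n * M + m] * w[m * K + k];
      (*dx)[n * K + k] = acc;
    }
}

// inner_product_backward_weights computes dW = dy^T * x and (with bias) db:
//   dW[m][k] = sum_n dy[n][m] * x[n][k],   db[m] = sum_n dy[n][m]
void FCBackwardWeights(const std::vector<float>& dy, const std::vector<float>& x,
                       std::vector<float>* dw, std::vector<float>* db,
                       int N, int M, int K) {
  for (int m = 0; m < M; ++m) {
    for (int k = 0; k < K; ++k) {
      float acc = 0.f;
      for (int n = 0; n < N; ++n) acc += dy[n * M + m] * x[n * K + k];
      (*dw)[m * K + k] = acc;
    }
    float bacc = 0.f;
    for (int n = 0; n < N; ++n) bacc += dy[n * M + m];
    (*db)[m] = bacc;
  }
}

int main() {
  const int N = 2, M = 3, K = 4;
  std::vector<float> x(N * K, 1.f), w(M * K, 0.5f), dy(N * M, 2.f);
  std::vector<float> dx(N * K), dw(M * K), db(M);
  FCBackwardData(dy, w, &dx, N, M, K);
  FCBackwardWeights(dy, x, &dw, &db, N, M, K);
  std::printf("dx[0]=%g dw[0]=%g db[0]=%g\n", dx[0], dw[0], db[0]);
  return 0;
}
```

Running it prints `dx[0]=3 dw[0]=4 db[0]=4`, which matches `dx = dy * W`, `dW = dy^T * x`, and `db = sum_n dy` computed by hand for the constant inputs above. The data-gradient and weight-gradient primitives are independent of each other (both consume only `dy` plus one forward input), which is why the two blocks can be reordered freely in MKLDNNFCBackward.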