Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix axis Bug in MKLDNN Softmax #11335

Merged
merged 5 commits into from
Jun 20, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "../softmax-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"
#include "../../tensor/broadcast_reduce_op.h"

#if MXNET_USE_MKLDNN == 1
namespace mxnet {
Expand All @@ -38,11 +39,13 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
auto input_mem = in_data.GetMKLDNNData();
mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
int axis = CheckAxis(param.axis, in_data.shape().ndim());

auto cpu_engine = data_mpd.get_engine();
auto prop = ctx.is_train
? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
mkldnn::softmax_forward::desc desc = mkldnn::softmax_forward::desc(prop,
data_md, param.axis);
data_md, axis);
mkldnn::softmax_forward::primitive_desc pdesc(desc, cpu_engine);

auto output_memory = out_data.GetMKLDNNData();
Expand Down
3 changes: 1 addition & 2 deletions src/operator/nn/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
// It seems MKLDNN softmax doesn't support training.
// and it only supports non-negative axis.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please also remove this comment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not using the latest MKLDNN.

if (SupportMKLDNN(inputs[0]) && !ctx.is_train && param.axis >= 0) {
if (SupportMKLDNN(inputs[0]) && !ctx.is_train) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
Expand Down