This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1234] Fix shape inference problems in Activation backward #13409

Merged: 5 commits, Dec 4, 2018
20 changes: 12 additions & 8 deletions src/operator/elemwise_op_common.h
@@ -128,29 +128,33 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
if (n_out != -1)
out_size = static_cast<size_t>(n_out);

auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
Member commented:
Can we revert this change? The common practice is that if we change the elements of the vector, we pass a pointer instead of a const reference.

larroy (Contributor, Author) commented on Nov 27, 2018:
Right, but if you look closely, we are not changing the vector in this function, so that convention does not apply here.
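
A minimal standalone sketch of the convention under discussion (illustrative names, not MXNet code): pass by pointer when the callee mutates the vector, by const reference when it only reads it.

#include <vector>

// Read-only access: a const reference is sufficient.
int Sum(const std::vector<int>& v) {
  int total = 0;
  for (int x : v) total += x;
  return total;
}

// Mutating access: a pointer makes the mutation explicit at the call site.
void Scale(std::vector<int>* v, int factor) {
  for (int& x : *v) x *= factor;
}

int main() {
  std::vector<int> v{1, 2, 3};
  const int s = Sum(v);  // v is only read, no & needed
  Scale(&v, s);          // &v signals that v may be modified
  return 0;
}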

CHECK_LE(in_size, in_attrs->size());
CHECK_LE(out_size, out_attrs->size());
auto deduce = [&](const std::vector<AttrType>& vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&dattr, (*vec)[i]))
CHECK(assign(&dattr, vec.at(i)))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string((*vec)[i]);
<< ", got " << attr_string(vec.at(i));
}
};
deduce(in_attrs, in_size, "input");
if (reverse_infer) deduce(out_attrs, out_size, "output");
deduce(*in_attrs, in_size, "input");
if (reverse_infer)
deduce(*out_attrs, out_size, "output");

auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(*vec)[i], dattr))
CHECK(assign(&(vec->at(i)), dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string((*vec)[i]);
<< ", got " << attr_string(vec->at(i));
}
};
write(in_attrs, in_size, "input");
write(out_attrs, out_size, "output");

if (is_none(dattr)) return false;
if (is_none(dattr))
return false;
return true;
}

12 changes: 5 additions & 7 deletions src/operator/nn/activation-inl.h
@@ -48,6 +48,9 @@ enum ActivationOpInputs {kData};
enum ActivationOpOutputs {kOut};
enum ActivationOpResource {kTempSpace};
enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};

// Get the number of inputs to the gradient depending on the activation type
int GradNumInputs(int act_type);
} // activation

struct ActivationParam : public dmlc::Parameter<ActivationParam> {
@@ -199,13 +202,8 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
#else
bool softsign = param.act_type == activation::kSoftSign;
CHECK_EQ(inputs.size(), softsign ? 3U : 2U);
#endif
const int act_type = param.act_type;
CHECK_EQ(inputs.size(), activation::GradNumInputs(act_type));
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
ActivationGradComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
79 changes: 48 additions & 31 deletions src/operator/nn/activation.cc
@@ -30,38 +30,63 @@
#if MXNET_USE_MKLDNN == 1
#include "./mkldnn/mkldnn_base-inl.h"
#include "./mkldnn/mkldnn_ops-inl.h"
#endif // MXNET_USE_MKLDNN
#endif // MXNET_USE_MKLDNN == 1
#include "../operator_common.h"
#include "../../common/utils.h"

namespace mxnet {
namespace op {

namespace activation {

int GradNumInputs(int act_type) {
// check activation.cu \sa ActivationGradCompute
switch (act_type) {
case kReLU:
return 2;
case kSoftReLU:
case kSoftSign:
case kTanh:
case kSigmoid:
return 3;
default:
CHECK(false) << "missing activation type";
}
// unreachable
return -1;
}

} // namespace activation

DMLC_REGISTER_PARAMETER(ActivationParam);

// This will determine the order of the inputs for backward computation.
struct ActivationGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
// ograds, output...
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0});

const NodeAttrs& attrs = n->attrs;
using namespace activation;
int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
if (act_type == activation::kSoftSign) {
// for softsign need the inputs to compute the activation.
heads.push_back(n->inputs[activation::kData]);
}

#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
// for ReLU, no need to pass input data. This enables inplace optimization during the
// forward pass.
if (act_type != activation::kReLU &&
act_type != activation::kSoftSign) {
heads.push_back(n->inputs[activation::kData]);
// check activation.cu \sa ActivationGradCompute
switch (act_type) {
case kReLU:
break;
case kSoftReLU:
case kSoftSign:
case kTanh:
case kSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
CHECK(false) << "missing activation type";
}
#endif
return MakeGradNode(op_name, n, heads, n->attrs.dict);
}
};
@@ -89,21 +114,19 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type));
if (SupportMKLDNN(inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
MKLDNNActivationBackward(attrs, ctx, inputs[0], relu ? inputs[1] : inputs[2], req[0],
const bool relu = param.act_type == activation::kReLU;
MKLDNNActivationBackward(attrs, ctx, inputs.at(0), relu ? inputs.at(1) : inputs.at(2), req[0],
outputs[0]);
MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
FallBackCompute(ActivationGradComputeImpl<cpu>, attrs, ctx, inputs, req, outputs);
}
#endif

#if MXNET_USE_MKLDNN == 1
inline static bool ActivationStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
@@ -122,16 +145,12 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
if (param.act_type != activation::kReLU) {
CHECK_EQ(in_attrs->size(), 3U);
} else {
// for ReLU activation, the backward pass only needs ograd and output
CHECK_EQ(in_attrs->size(), 2U);
}
CHECK_EQ(in_attrs->size(), activation::GradNumInputs(param.act_type));
return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNAct(param),
dispatch_mode, in_attrs, out_attrs);
}
#endif
#endif // MXNET_USE_MKLDNN == 1


MXNET_OPERATOR_REGISTER_UNARY(Activation)
.describe(R"code(Applies an activation function element-wise to the input.
@@ -163,18 +182,16 @@ The following activation functions are supported:

NNVM_REGISTER_OP(_backward_Activation)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
// for ReLU activation, the backward pass only needs ograd and output
if (act_type == activation::kReLU) return 2;
return 3;
})
const int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
return activation::GradNumInputs(act_type);
})
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
#endif
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
30 changes: 18 additions & 12 deletions src/operator/nn/activation.cu
@@ -54,12 +54,13 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
const int act_type = param.act_type;

// SoftReLU and kSoftSign are both not supported by CUDNN yet
if (param.act_type == activation::kSoftReLU) {
if (act_type == activation::kSoftReLU) {
ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (param.act_type == activation::kSoftSign) {
} else if (act_type == activation::kSoftSign) {
ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else {
@@ -76,23 +77,28 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
const int act_type = param.act_type;
CHECK_EQ(inputs.size(), activation::GradNumInputs(act_type));
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

// both SoftReLU and SoftSign not supported by CUDNN yet
if (param.act_type == activation::kSoftReLU) {
if (act_type == activation::kSoftReLU) {
ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
} else if (param.act_type == activation::kSoftSign) {
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kSoftSign) {
ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
ctx, inputs[0], inputs[2], req[0], outputs[0]);
} else {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]);
} else if (act_type == activation::kReLU) {
MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, {
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
get_cudnn_op<DType>(param).Backward(ctx, inputs[0], relu ? inputs[1] : inputs[2],
inputs[1], req[0], outputs[0]);
get_cudnn_op<DType>(param).Backward(ctx, inputs.at(0), inputs.at(1),
inputs.at(1), req[0], outputs[0]);
});
} else {
MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, {
get_cudnn_op<DType>(param).Backward(ctx, inputs.at(0), inputs.at(2),
inputs.at(1), req[0], outputs[0]);
});
}
}
26 changes: 20 additions & 6 deletions tests/cpp/operator/activation_perf.cc
@@ -38,13 +38,27 @@ const kwargs_t basic_activation_args = { };
* \brief Generic bidirectional sanity test
*/
TEST(ACTIVATION_PERF, ExecuteBidirectional) {
using namespace std;
TShape shape({5, 5});
kwargs_t kwargs = basic_activation_args;
kwargs.push_back({"act_type", "tanh"});

test::op::CoreOperatorRunner<float> runner;
runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
kwargs, "Activation", "_backward_Activation"), 1);
vector<string> activations = {
"relu",
"sigmoid",
"tanh",
"softrelu",
"softsign"
};
for (const string& activation : activations) {
kwargs_t activation_args = {{"act_type", activation}};
test::op::CoreOperatorRunner<float> runner;
runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
activation_args, "Activation", "_backward_Activation"), 1);
}
TaoLv (Member) commented on Dec 1, 2018:
I tried this test case on my machine with both MKL-DNN and CUDNN enabled, without the last commit. All five activations run into the MXNet CPU implementation (ActivationCompute<cpu> and ActivationGradCompute<cpu>). If isGPU=true is set, all five activations run into the GPU implementation (ActivationCompute<gpu> and ActivationGradCompute<gpu>). It seems no MKL-DNN implementation was triggered for either isGPU=false or isGPU=true.

Member commented:
@larroy How can we reproduce "What I observed is that we are calculating the backward in MKLDNN even if we have CUDNN"?

larroy (Contributor, Author) commented on Dec 3, 2018:
@TaoLv you are right, this statement is not true in the master branch; I retract it.

We can use the patch below to reproduce the bug and see what runs and what doesn't. Apply it, then run:

build/tests/mxnet_unit_tests --gtest_filter="ACTIVATION_PERF.ExecuteBidirectional"

diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index 2705177..c656e33 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -88,6 +88,7 @@ void ActivationForward(const OpContext &ctx, const TBlob &in_data,
                        const OpReqType &req, const TBlob &out_data) {
   using namespace mshadow;
   using namespace mshadow::expr;
+  std::cout << "ActivationForward" << std::endl;
   Stream<xpu> *s = ctx.get_stream<xpu>();
   const size_t sz = in_data.shape_.Size();
   if (sz) {
@@ -104,6 +105,7 @@ template<typename xpu, typename ForwardOp, typename BackwardOp>
 void ActivationBackward(const OpContext &ctx, const TBlob &out_grad,
                         const TBlob &out_data, const OpReqType &req,
                         const TBlob &in_grad) {
+  std::cout << "ActivationBackward" << std::endl;
   using namespace mshadow;
   using namespace mshadow::expr;
   Stream<xpu> *s = ctx.get_stream<xpu>();
@@ -123,6 +125,7 @@ template<typename xpu>
 void ActivationComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                            const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
+  std::cout << "ActivationComputeImpl" << std::endl;
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   switch (param.act_type) {
     case activation::kReLU:
@@ -154,6 +157,7 @@ template<typename xpu>
 void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                            const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
+  std::cout << "ActivationGradComputeImpl" << std::endl;
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   switch (param.act_type) {
     case activation::kReLU:
@@ -187,6 +191,7 @@ void ActivationCompute(const nnvm::NodeAttrs& attrs,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
                        const std::vector<TBlob>& outputs) {
+  std::cout << "ActivationCompute" << std::endl;
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   ActivationComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
@@ -198,6 +203,7 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
+  std::cout << "ActivationGradCompute" << std::endl;
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
 #if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
   bool relu = param.act_type == activation::kReLU;
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index ba44ebd..cb5d5ab 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -72,6 +72,7 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
+  std::cout << "ActivationComputeExCPU" << std::endl;
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   if (SupportMKLDNN(inputs[0])) {
@@ -88,6 +89,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
+  std::cout << "ActivationGradComputeExCPU" << std::endl;
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   bool relu = param.act_type == activation::kReLU;
   CHECK_EQ(inputs.size(), relu ? 2U : 3U);
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index 8892cc3..5aaeb78 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -51,6 +51,7 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
     const std::vector<TBlob>& inputs,
     const std::vector<OpReqType>& req,
     const std::vector<TBlob>& outputs) {
+  std::cout << "ActivationCompute (GPU)" << std::endl;
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
@@ -75,6 +76,7 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
+  std::cout << "ActivationGradCompute (GPU)" << std::endl;
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   bool relu = param.act_type == activation::kReLU;
   CHECK_EQ(inputs.size(), relu ? 2U : 3U);
diff --git a/tests/cpp/operator/activation_perf.cc b/tests/cpp/operator/activation_perf.cc
index 1bd8ca8..bba8a3e 100644
--- a/tests/cpp/operator/activation_perf.cc
+++ b/tests/cpp/operator/activation_perf.cc
@@ -38,13 +38,27 @@ const kwargs_t basic_activation_args = { };
  * \brief Generic bidirectional sanity test
  */
 TEST(ACTIVATION_PERF, ExecuteBidirectional) {
+  using namespace std;
   TShape shape({5, 5});
-  kwargs_t kwargs = basic_activation_args;
-  kwargs.push_back({"act_type", "tanh"});
-
-  test::op::CoreOperatorRunner<float> runner;
-  runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
-          kwargs, "Activation", "_backward_Activation"), 1);
+  vector<string> activations = {
+    "relu",
+    "sigmoid",
+    "tanh",
+    "softrelu",
+    "softsign"
+  };
+  for (const string& activation : activations) {
+    kwargs_t activation_args = {{"act_type", activation}};
+    test::op::CoreOperatorRunner<float> runner;
+    runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
+            activation_args, "Activation", "_backward_Activation"), 1);
+  }
+  for (const string& activation : activations) {
+    kwargs_t activation_args = {{"act_type", activation}};
+    test::op::CoreOperatorRunner<float> runner;
+    runner.RunBidirectional(true, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
+            activation_args, "Activation", "_backward_Activation"), 1);
+  }
 }
 
 /*!

Member commented:
Applied this patch to the mxnet master branch and got the error below:

[lvtao@mlt-gpu207 mxnet-official]$ ./mxnet_unit_tests --gtest_filter="ACTIVATION_PERF.ExecuteBidirectional"
Found CUDA Device #: 0 properties: 6.0

Note: Google Test filter = ACTIVATION_PERF.ExecuteBidirectional
[==========] Running 1 test from 1 test case.
[----------] Global test environment set-up.
[----------] 1 test from ACTIVATION_PERF
[ RUN      ] ACTIVATION_PERF.ExecuteBidirectional
unknown file: Failure
C++ exception with description "[10:31:44] src/operator/tensor/./../elemwise_op_common.h:176: Check failed: in_attrs->size() == static_cast<size_t>(n_in) (2 vs. 3)  in operator

Stack trace returned 10 entries:
[bt] (0) ./mxnet_unit_tests(dmlc::StackTrace()+0x4a) [0x42f64a]
[bt] (1) ./mxnet_unit_tests(dmlc::LogMessageFatal::~LogMessageFatal()+0x21) [0x42fca1]
[bt] (2) lib/libmxnet.so(bool mxnet::op::ElemwiseType<3l, 1l>(nnvm::NodeAttrs const&, std::vector<int, std::allocator<int> >*, std::vector<int, std::allocator<int> >*)+0xe2) [0x7f1f0691bd52]
[bt] (3) ./mxnet_unit_tests() [0x477593]
[bt] (4) ./mxnet_unit_tests() [0x43a007]
[bt] (5) ./mxnet_unit_tests() [0x479a20]
[bt] (6) ./mxnet_unit_tests() [0x5e4838]
[bt] (7) ./mxnet_unit_tests() [0x5e5bc8]
[bt] (8) ./mxnet_unit_tests() [0x5e6119]
[bt] (9) ./mxnet_unit_tests() [0x632904]

" thrown in the test body.
[  FAILED  ] ACTIVATION_PERF.ExecuteBidirectional (27 ms)
[----------] 1 test from ACTIVATION_PERF (27 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test case ran. (27 ms total)
[  PASSED  ] 0 tests.
[  FAILED  ] 1 test, listed below:
[  FAILED  ] ACTIVATION_PERF.ExecuteBidirectional

 1 FAILED TEST
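
For context, the failure comes from the fixed arity in the backward registration: ElemwiseShape<3, 1> and ElemwiseType<3, 1> assert exactly three inputs, while the ReLU backward node only carries two (ograd and output). A simplified standalone sketch of that check, assuming a trimmed-down signature rather than the real template in elemwise_op_common.h:

#include <iostream>
#include <vector>

// Stand-in for the arity check that fired above: a hard-coded n_in requires the
// node to carry exactly that many inputs; n_in == -1 (the approach taken in this
// PR) adapts to the node's actual input count.
bool ArityOk(int n_in, const std::vector<int>& in_attrs) {
  const size_t expected = (n_in == -1) ? in_attrs.size() : static_cast<size_t>(n_in);
  if (in_attrs.size() != expected) {
    std::cerr << "Check failed: in_attrs->size() == n_in ("
              << in_attrs.size() << " vs. " << expected << ")\n";
    return false;
  }
  return true;
}

int main() {
  std::vector<int> relu_backward_inputs(2);  // ograd, output
  ArityOk(3, relu_backward_inputs);          // reproduces the "(2 vs. 3)" failure
  ArityOk(-1, relu_backward_inputs);         // mirrors ElemwiseShape/Type<-1, 1>
  return 0;
}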

for (const string& activation : activations) {
kwargs_t activation_args = {{"act_type", activation}};
test::op::CoreOperatorRunner<float> runner;
runner.RunBidirectional(true, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
activation_args, "Activation", "_backward_Activation"), 1);
}
}

/*!
12 changes: 6 additions & 6 deletions tests/python/unittest/test_gluon.py
@@ -2411,7 +2411,7 @@ def hybrid_forward(self, F, x):
x_reshape = x.reshape(self.reshape)
out = self.act(x_reshape)
return out
acts = ["relu", "sigmoid", "tanh", "softrelu"]
acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
for act in acts:
x = mx.nd.random.uniform(-1, 1, shape=(4, 16, 32, 32))
shape = (4, 32, 32, -1)
@@ -2433,7 +2433,7 @@ def hybrid_forward(self, F, x):
out = self.act(x_slice)
return out

acts = ["relu", "sigmoid", "tanh", "softrelu"]
acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
for act in acts:
x = mx.nd.random.uniform(-1, 1, shape=(8, 32, 64, 64))
slice = [(0, 16, 32, 32), (4, 32, 64, 64)]
@@ -2457,7 +2457,7 @@ def hybrid_forward(self, F, x):
y_reshape = y.reshape(self.reshape[1])
out = self.act1(y_reshape)
return out
acts = ["relu", "sigmoid", "tanh", "softrelu"]
acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
for idx0, act0 in enumerate(acts):
for idx1, act1 in enumerate(acts):
if idx1 == idx0:
@@ -2484,7 +2484,7 @@ def hybrid_forward(self, F, x):
y_slice = y.slice(begin=self.slice[1][0], end=self.slice[1][1])
out = self.act1(y_slice)
return out
acts = ["relu", "sigmoid", "tanh", "softrelu"]
acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
for idx0, act0 in enumerate(acts):
for idx1, act1 in enumerate(acts):
if idx1 == idx0:
@@ -2512,7 +2512,7 @@ def hybrid_forward(self, F, x):
y_slice = y.slice(begin=self.slice[0], end=self.slice[1])
out = self.act1(y_slice)
return out
acts = ["relu", "sigmoid", "tanh", "softrelu"]
acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
for idx0, act0 in enumerate(acts):
for idx1, act1 in enumerate(acts):
if idx1 == idx0:
@@ -2541,7 +2541,7 @@ def hybrid_forward(self, F, x):
y_reshape = y.reshape(self.reshape)
out = self.act1(y_reshape)
return out
acts = ["relu", "sigmoid", "tanh", "softrelu"]
acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"]
for idx0, act0 in enumerate(acts):
for idx1, act1 in enumerate(acts):
if idx1 == idx0: