Fix Activation backward shape inference
larroy committed Nov 29, 2018
1 parent f7bd997 commit a2f2d44
Showing 5 changed files with 133 additions and 48 deletions.
20 changes: 12 additions & 8 deletions src/operator/elemwise_op_common.h
@@ -128,29 +128,33 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
if (n_out != -1)
out_size = static_cast<size_t>(n_out);

auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
CHECK_LE(in_size, in_attrs->size());
CHECK_LE(out_size, out_attrs->size());
auto deduce = [&](const std::vector<AttrType>& vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&dattr, (*vec)[i]))
CHECK(assign(&dattr, vec.at(i)))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string((*vec)[i]);
<< ", got " << attr_string(vec.at(i));
}
};
deduce(in_attrs, in_size, "input");
if (reverse_infer) deduce(out_attrs, out_size, "output");
deduce(*in_attrs, in_size, "input");
if (reverse_infer)
deduce(*out_attrs, out_size, "output");

auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(*vec)[i], dattr))
CHECK(assign(&(vec->at(i)), dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string((*vec)[i]);
<< ", got " << attr_string(vec->at(i));
}
};
write(in_attrs, in_size, "input");
write(out_attrs, out_size, "output");

if (is_none(dattr)) return false;
if (is_none(dattr))
return false;
return true;
}
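
For context on this hunk, the sketch below mirrors the deduce/write pattern that ElemwiseAttr applies to per-element attributes, including the bounds checks added here before deducing. It is a standalone simplification with illustrative names (plain int attributes, -1 standing for "unknown", assert in place of CHECK), not code from the MXNet sources.

    // Standalone simplification of the deduce/write pattern in ElemwiseAttr.
    // Attributes are plain ints, -1 means "unknown"; assert replaces CHECK/CHECK_LE.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    bool ElemwiseAttrSketch(std::vector<int>* in_attrs, std::vector<int>* out_attrs,
                            std::size_t in_size, std::size_t out_size) {
      // Bounds checks corresponding to the CHECK_LE calls added in this hunk.
      assert(in_size <= in_attrs->size());
      assert(out_size <= out_attrs->size());

      int dattr = -1;
      auto deduce = [&](const std::vector<int>& vec, std::size_t size) {
        for (std::size_t i = 0; i < size; ++i) {
          if (vec.at(i) == -1) continue;              // unknown entry, nothing to deduce
          assert(dattr == -1 || dattr == vec.at(i));  // incompatible attributes fail here
          dattr = vec.at(i);
        }
      };
      deduce(*in_attrs, in_size);
      deduce(*out_attrs, out_size);

      auto write = [&](std::vector<int>* vec, std::size_t size) {
        for (std::size_t i = 0; i < size; ++i) vec->at(i) = dattr;  // propagate the deduced attr
      };
      write(in_attrs, in_size);
      write(out_attrs, out_size);

      return dattr != -1;  // false when nothing could be deduced yet
    }

    int main() {
      std::vector<int> in = {-1, 3};  // one input attribute already known
      std::vector<int> out = {-1};
      assert(ElemwiseAttrSketch(&in, &out, in.size(), out.size()));
      assert(in[0] == 3 && out[0] == 3);  // the known value was written back everywhere
      return 0;
    }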

12 changes: 5 additions & 7 deletions src/operator/nn/activation-inl.h
@@ -48,6 +48,9 @@ enum ActivationOpInputs {kData};
enum ActivationOpOutputs {kOut};
enum ActivationOpResource {kTempSpace};
enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};

// Get the number of inputs to the gradient depending on the activation type
int GradNumInputs(int act_type);
} // activation

struct ActivationParam : public dmlc::Parameter<ActivationParam> {
@@ -199,13 +202,8 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
#else
bool softsign = param.act_type == activation::kSoftSign;
CHECK_EQ(inputs.size(), softsign ? 3U : 2U);
#endif
const int act_type = param.act_type;
CHECK_EQ(inputs.size(), activation::GradNumInputs(act_type));
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
ActivationGradComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
125 changes: 101 additions & 24 deletions src/operator/nn/activation.cc
@@ -37,29 +37,112 @@
namespace mxnet {
namespace op {

namespace activation {

int GradNumInputs(int act_type) {
#if MXNET_USE_CUDNN == 1
// check activation.cu \sa ActivationGradCompute
switch (act_type) {
case kReLU:
case kSoftReLU:
return 2;
case kSoftSign:
case kTanh:
case kSigmoid:
return 3;
default:
CHECK(false) << "missing activation type";
}
#elif MXNET_USE_MKLDNN == 1
// \sa ActivationGradComputeExCPU
switch (act_type) {
case kReLU:
return 2;
case kSigmoid:
case kTanh:
case kSoftReLU:
case kSoftSign:
return 3;
default:
CHECK(false) << "missing activation type";
}
#else
// check activation-inl.h \sa ActivationGradComputeImpl
switch (act_type) {
case kReLU:
case kSigmoid:
case kTanh:
case kSoftReLU:
return 2;
case kSoftSign:
return 3;
default:
CHECK(false) << "missing activation type";
}
#endif
// unreachable
return -1;
}
} // namespace activation
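
As a sanity check on the mapping above, the standalone program below reproduces the cuDNN branch of activation::GradNumInputs and asserts the expected arities. It is a sketch with illustrative names and assert in place of CHECK, not code from this commit.

    // Sketch of the cuDNN branch of activation::GradNumInputs; the enum order
    // mirrors ActivationOpType in activation-inl.h.
    #include <cassert>

    enum ActType { kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign };

    int GradNumInputsCudnnSketch(int act_type) {
      switch (act_type) {
        case kReLU:
        case kSoftReLU:
          return 2;   // backward sees {ograd, output}
        case kSoftSign:
        case kTanh:
        case kSigmoid:
          return 3;   // backward additionally receives the forward input data
        default:
          return -1;  // unknown activation type
      }
    }

    int main() {
      assert(GradNumInputsCudnnSketch(kReLU) == 2);
      assert(GradNumInputsCudnnSketch(kTanh) == 3);
      assert(GradNumInputsCudnnSketch(kSoftSign) == 3);
      return 0;
    }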

DMLC_REGISTER_PARAMETER(ActivationParam);

// This will determine the order of the inputs for backward computation.
struct ActivationGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
// ograds, output...
std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0});

const NodeAttrs& attrs = n->attrs;
using namespace activation;
int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
if (act_type == activation::kSoftSign) {
// for softsign need the inputs to compute the activation.
heads.push_back(n->inputs[activation::kData]);
}

#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
#if MXNET_USE_CUDNN == 1
// for ReLU, no need to pass input data. This enables inplace optimization during the
// forward pass.
if (act_type != activation::kReLU &&
act_type != activation::kSoftSign) {
heads.push_back(n->inputs[activation::kData]);
// check activation.cu \sa ActivationGradCompute
switch (act_type) {
case kReLU:
case kSoftReLU:
break;
case kSoftSign:
case kTanh:
case kSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
CHECK(false) << "missing activation type";
}
#elif MXNET_USE_MKLDNN == 1
// \sa ActivationGradComputeExCPU
switch (act_type) {
case kReLU:
break;
case kSoftSign:
case kTanh:
case kSoftReLU:
case kSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
CHECK(false) << "missing activation type";
}

#else
// check activation-inl.h \sa ActivationGradComputeImpl
switch (act_type) {
case kSoftSign:
heads.push_back(n->inputs[activation::kData]);
break;
case kReLU:
case kTanh:
case kSoftReLU:
case kSigmoid:
break;
default:
CHECK(false) << "missing activation type";
}
#endif
return MakeGradNode(op_name, n, heads, n->attrs.dict);
@@ -89,11 +172,11 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type));
if (SupportMKLDNN(inputs[0])) {
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
const bool relu = param.act_type == activation::kReLU;
MKLDNNActivationBackward(attrs, ctx, inputs[0], relu ? inputs[1] : inputs[2], req[0],
outputs[0]);
MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -122,17 +205,13 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
if (param.act_type != activation::kReLU) {
CHECK_EQ(in_attrs->size(), 3U);
} else {
// for ReLU activation, the backward pass only needs ograd and output
CHECK_EQ(in_attrs->size(), 2U);
}
CHECK_EQ(in_attrs->size(), activation::GradNumInputs(param.act_type));
return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNAct(param),
dispatch_mode, in_attrs, out_attrs);
}
#endif


MXNET_OPERATOR_REGISTER_UNARY(Activation)
.describe(R"code(Applies an activation function element-wise to the input.
@@ -163,18 +242,16 @@ The following activation functions are supported:

NNVM_REGISTER_OP(_backward_Activation)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
// for ReLU activation, the backward pass only needs ograd and output
if (act_type == activation::kReLU) return 2;
return 3;
})
const int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
return activation::GradNumInputs(act_type);
})
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
#endif
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
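
The registration changes in this file are the core of the fix: set_num_inputs now reports the real arity through activation::GradNumInputs, and FInferShape/FInferType are registered as ElemwiseShape<-1, 1> and ElemwiseType<-1, 1> so inference adapts to that arity instead of assuming three inputs. The standalone sketch below checks the invariant that makes this work, namely that the heads pushed by ActivationGrad always match GradNumInputs; it models the plain-CPU branch (no cuDNN, no MKL-DNN) with illustrative names.

    // Consistency sketch for the plain-CPU build: the heads constructed by
    // ActivationGrad must agree with activation::GradNumInputs, since that
    // count drives shape/type inference of _backward_Activation.
    #include <cassert>
    #include <string>
    #include <vector>

    enum ActType { kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign };

    int GradNumInputsCpuSketch(int act_type) {
      return act_type == kSoftSign ? 3 : 2;  // only softsign needs the forward input
    }

    std::vector<std::string> GradHeadsCpuSketch(int act_type) {
      std::vector<std::string> heads = {"ograd", "output"};  // always passed first
      if (act_type == kSoftSign)
        heads.push_back("input");  // softsign_grad is evaluated on x, not on y
      return heads;
    }

    int main() {
      for (int t : {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign}) {
        assert(static_cast<int>(GradHeadsCpuSketch(t).size()) ==
               GradNumInputsCpuSketch(t));
      }
      return 0;
    }
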
22 changes: 14 additions & 8 deletions src/operator/nn/activation.cu
@@ -54,12 +54,13 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
const int act_type = param.act_type;

// SoftReLU and kSoftSign are both not supported by CUDNN yet
if (param.act_type == activation::kSoftReLU) {
if (act_type == activation::kSoftReLU) {
ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (param.act_type == activation::kSoftSign) {
} else if (act_type == activation::kSoftSign) {
ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else {
@@ -76,22 +77,27 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
const int act_type = param.act_type;
CHECK_EQ(inputs.size(), activation::GradNumInputs(act_type));
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

// both SoftReLU and SoftSign not supported by CUDNN yet
if (param.act_type == activation::kSoftReLU) {
if (act_type == activation::kSoftReLU) {
ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
} else if (param.act_type == activation::kSoftSign) {
} else if (act_type == activation::kSoftSign) {
ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
ctx, inputs[0], inputs[2], req[0], outputs[0]);
} else {
} else if (act_type == activation::kReLU) {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
get_cudnn_op<DType>(param).Backward(ctx, inputs[0], relu ? inputs[1] : inputs[2],
get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[1],
inputs[1], req[0], outputs[0]);
});
} else {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[2],
inputs[1], req[0], outputs[0]);
});
}
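
Given the counts above, the backward inputs on this path are {ograd, output} for the two-input types and {ograd, output, data} for the three-input ones, which is why the dispatch now reads inputs[1] for relu and softrelu and inputs[2] for sigmoid, tanh and softsign. The helper below is a hypothetical index-mapping sketch, not part of the commit, and assumes the cuDNN build this dispatch targets.

    // Index of the tensor passed as the second argument of the GPU backward call
    // in this dispatch: the output y when only {ograd, output} is available,
    // the forward input x when {ograd, output, data} is passed.
    #include <cassert>

    enum ActType { kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign };

    int GpuBackwardDataIndexSketch(int act_type) {
      switch (act_type) {
        case kReLU:      // cuDNN Backward; y stands in for x, which relu_grad ignores
        case kSoftReLU:  // softrelu_grad is evaluated on the output y
          return 1;
        case kSigmoid:
        case kTanh:      // cuDNN Backward with the real forward input x
        case kSoftSign:  // softsign_grad needs x
          return 2;
        default:
          return -1;
      }
    }

    int main() {
      assert(GpuBackwardDataIndexSketch(kReLU) == 1);
      assert(GpuBackwardDataIndexSketch(kSigmoid) == 2);
      return 0;
    }
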
2 changes: 1 addition & 1 deletion tests/cpp/operator/activation_perf.cc
@@ -47,7 +47,7 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
"softrelu",
"softsign"
};
for(const string& activation : activations) {
for (const string& activation : activations) {
kwargs_t activation_args = {{"act_type", activation}};
test::op::CoreOperatorRunner<float> runner;
runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
