Skip to content

Commit

Permalink
Fix Activation backward shape inference
Browse files Browse the repository at this point in the history
  • Loading branch information
larroy committed Nov 27, 2018
1 parent a8a7f03 commit 33b436c
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 40 deletions.
20 changes: 12 additions & 8 deletions src/operator/elemwise_op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,29 +128,33 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
if (n_out != -1)
out_size = static_cast<size_t>(n_out);

auto deduce = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
CHECK_LE(in_size, in_attrs->size());
CHECK_LE(out_size, out_attrs->size());
auto deduce = [&](const std::vector<AttrType>& vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&dattr, (*vec)[i]))
CHECK(assign(&dattr, vec.at(i)))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string((*vec)[i]);
<< ", got " << attr_string(vec.at(i));
}
};
deduce(in_attrs, in_size, "input");
if (reverse_infer) deduce(out_attrs, out_size, "output");
deduce(*in_attrs, in_size, "input");
if (reverse_infer)
deduce(*out_attrs, out_size, "output");

auto write = [&](std::vector<AttrType> *vec, size_t size, const char *name) {
for (size_t i = 0; i < size; ++i) {
CHECK(assign(&(*vec)[i], dattr))
CHECK(assign(&(vec->at(i)), dattr))
<< "Incompatible attr in node " << attrs.name << " at " << i << "-th "
<< name << ": " << "expected " << attr_string(dattr)
<< ", got " << attr_string((*vec)[i]);
<< ", got " << attr_string(vec->at(i));
}
};
write(in_attrs, in_size, "input");
write(out_attrs, out_size, "output");

if (is_none(dattr)) return false;
if (is_none(dattr))
return false;
return true;
}

Expand Down
12 changes: 5 additions & 7 deletions src/operator/nn/activation-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ enum ActivationOpInputs {kData};
enum ActivationOpOutputs {kOut};
enum ActivationOpResource {kTempSpace};
enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};

// Get the number of inputs to the gradient depending on the activation type
int ActivationGradNumInputs(int act_type);
} // activation

struct ActivationParam : public dmlc::Parameter<ActivationParam> {
Expand Down Expand Up @@ -199,13 +202,8 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
#else
bool softsign = param.act_type == activation::kSoftSign;
CHECK_EQ(inputs.size(), softsign ? 3U : 2U);
#endif
const int act_type = param.act_type;
CHECK_EQ(inputs.size(), activation::ActivationGradNumInputs(act_type));
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
ActivationGradComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
Expand Down
113 changes: 97 additions & 16 deletions src/operator/nn/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,54 @@
namespace mxnet {
namespace op {

namespace activation {

int ActivationGradNumInputs(int act_type) {
#if MXNET_USE_CUDNN == 1
// check activation.cu \sa ActivationGradCompute
switch (act_type) {
case kReLU:
case kSoftReLU:
return 2;
case kSoftSign:
case kTanh:
case kSigmoid:
return 3;
default:
CHECK(false) << "missing activation type";
}
#elif MXNET_USE_MKLDNN == 1
// \sa ActivationGradComputeExCPU
switch (act_type) {
case kReLU:
return 2;
case kSigmoid:
case kTanh:
case kSoftReLU:
case kSoftSign:
return 3;
default:
CHECK(false) << "missing activation type";
}
#else
// check activation-inl.h \sa ActivationGradComputeImpl
switch (act_type) {
case kReLU:
case kSigmoid:
case kTanh:
case kSoftReLU:
return 2;
case kSoftSign:
return 3;
default:
CHECK(false) << "missing activation type";
}
#endif
// unreachable
return -1;
}
} // namespace activation

DMLC_REGISTER_PARAMETER(ActivationParam);

// This will determine the order of the inputs for backward computation.
Expand All @@ -48,18 +96,52 @@ struct ActivationGrad {
heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0});

const NodeAttrs& attrs = n->attrs;
using namespace activation;
int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
if (act_type == activation::kSoftSign) {
// for softsign need the inputs to compute the activation.
heads.push_back(n->inputs[activation::kData]);
}

#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
#if MXNET_USE_CUDNN == 1
// for ReLU, no need to pass input data. This enables inplace optimization during the
// forward pass.
if (act_type != activation::kReLU &&
act_type != activation::kSoftSign) {
heads.push_back(n->inputs[activation::kData]);
// check activation.cu \sa ActivationGradCompute
switch (act_type) {
case kReLU:
case kSoftReLU:
break;
case kSoftSign:
case kTanh:
case kSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
CHECK(false) << "missing activation type";
}
#elif MXNET_USE_MKLDNN == 1
// \sa ActivationGradComputeExCPU
switch (act_type) {
case kReLU:
break;
case kSoftSign:
case kTanh:
case kSoftReLU:
case kSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
CHECK(false) << "missing activation type";
}

#else
// check activation-inl.h \sa ActivationGradComputeImpl
switch (act_type) {
case kSoftSign:
heads.push_back(n->inputs[activation::kData]);
break;
case kReLU:
case kTanh:
case kSoftReLU:
case kSigmoid:
break;
default:
CHECK(false) << "missing activation type";
}
#endif
return MakeGradNode(op_name, n, heads, n->attrs.dict);
Expand Down Expand Up @@ -133,6 +215,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
}
#endif


MXNET_OPERATOR_REGISTER_UNARY(Activation)
.describe(R"code(Applies an activation function element-wise to the input.
Expand Down Expand Up @@ -163,18 +246,16 @@ The following activation functions are supported:

NNVM_REGISTER_OP(_backward_Activation)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
// for ReLU activation, the backward pass only needs ograd and output
if (act_type == activation::kReLU) return 2;
return 3;
})
const int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
return activation::ActivationGradNumInputs(act_type);
})
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
#endif
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
Expand Down
22 changes: 14 additions & 8 deletions src/operator/nn/activation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
const int act_type = param.act_type;

// SoftReLU and kSoftSign are both not supported by CUDNN yet
if (param.act_type == activation::kSoftReLU) {
if (act_type == activation::kSoftReLU) {
ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else if (param.act_type == activation::kSoftSign) {
} else if (act_type == activation::kSoftSign) {
ActivationForward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(ctx,
inputs[0], req[0], outputs[0]);
} else {
Expand All @@ -76,22 +77,27 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
const int act_type = param.act_type;
CHECK_EQ(inputs.size(), ActivationGradNumInputs(act_type));
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

// both SoftReLU and SoftSign not supported by CUDNN yet
if (param.act_type == activation::kSoftReLU) {
if (act_type == activation::kSoftReLU) {
ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
} else if (param.act_type == activation::kSoftSign) {
} else if (act_type == activation::kSoftSign) {
ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
ctx, inputs[0], inputs[2], req[0], outputs[0]);
} else {
} else if (act_type == activation::kReLU) {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
get_cudnn_op<DType>(param).Backward(ctx, inputs[0], relu ? inputs[1] : inputs[2],
get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[1],
inputs[1], req[0], outputs[0]);
});
} else {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[2],
inputs[1], req[0], outputs[0]);
});
}
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/operator/activation_perf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
"softrelu",
"softsign"
};
for(const string& activation : activations) {
for (const string& activation : activations) {
kwargs_t activation_args = {{"act_type", activation}};
test::op::CoreOperatorRunner<float> runner;
runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
Expand Down

0 comments on commit 33b436c

Please sign in to comment.