[RELAY]prelu op support (apache#2016)
siju-samuel authored and tqchen committed Oct 29, 2018
1 parent 1260671 commit bfc8c68
Showing 6 changed files with 130 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -74,6 +74,7 @@ This level enables additional math and transform operators.

tvm.relay.zeros
tvm.relay.nn.leaky_relu
tvm.relay.nn.prelu
tvm.relay.zeros_like
tvm.relay.ones
tvm.relay.ones_like
@@ -183,6 +184,7 @@ Level 2 Definitions
Level 3 Definitions
-------------------
.. autofunction:: tvm.relay.nn.leaky_relu
.. autofunction:: tvm.relay.nn.prelu
.. autofunction:: tvm.relay.floor
.. autofunction:: tvm.relay.ceil
.. autofunction:: tvm.relay.trunc
11 changes: 11 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -278,6 +278,17 @@ struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
};


/*! \brief Attributes for prelu operator */
struct PReluAttrs : public tvm::AttrsNode<PReluAttrs> {
int axis;

TVM_DECLARE_ATTRS(PReluAttrs, "relay.attrs.PReluAttrs") {
TVM_ATTR_FIELD(axis).set_default(1)
.describe("Specify which shape axis the channel is specified.");
}
};
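As a quick illustration (a sketch grounded in the tests added later in this commit; the shapes are assumptions), a non-default axis set through these attrs shows up in the Relay text format:

from tvm import relay

x = relay.var("x", relay.TensorType((1, 2, 2, 3), "float32"))
a = relay.var("alpha", relay.TensorType((3,), "float32"))
z = relay.nn.prelu(x, a, axis=3)
assert "axis" in z.astext()  # the PReluAttrs axis field is printed with the call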


/*! \brief Attributes used in dropout operator */
struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
double rate;
1 change: 1 addition & 0 deletions include/tvm/relay/type.h
@@ -280,6 +280,7 @@ class TypeReporterNode : public Node {
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*!
* \brief Assert shape expression comparison.
* \note Use Assert only if any of the condition's inputs is symbolic.
* \param cond The condition of operation.
* \return false if the assertion can be proven to have failed,
* true if the solver can still proceed.
27 changes: 27 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -528,6 +528,33 @@ def leaky_relu(data, alpha):
return _make.leaky_relu(data, alpha)


def prelu(data, alpha, axis=1):
"""This operator takes data as input and does Leaky version
of a Rectified Linear Unit.
.. math::
`y = x > 0 ? x : alpha * x`
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
alpha : tvm.relay.Expr
Slope coefficient for the negative half axis.
axis : int, optional
Specify which shape axis the channel is specified.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.prelu(data, alpha, axis)
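A minimal usage sketch of the Python API added above (illustrative only; the NHWC layout and shapes are assumptions chosen for the example):

from tvm import relay

# NHWC layout, so the channel axis is 3; alpha carries one slope per channel
x = relay.var("x", relay.TensorType((1, 32, 32, 16), "float32"))
alpha = relay.var("alpha", relay.TensorType((16,), "float32"))
y = relay.nn.prelu(x, alpha, axis=3)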


def pad(data,
pad_width,
pad_value=0.0):
56 changes: 56 additions & 0 deletions src/relay/op/nn/nn.cc
@@ -171,6 +171,62 @@ RELAY_REGISTER_OP("nn.leaky_relu")
.add_type_rel("Identity", IdentityRel);


TVM_REGISTER_NODE_TYPE(PReluAttrs);

bool PReluRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;

const PReluAttrs* param = attrs.as<PReluAttrs>();
CHECK(param != nullptr);

CHECK(param->axis < static_cast<int>(data->shape.size()))
<< "Wrong axis (" << param->axis << ")value.";

// assign alpha type
Array<IndexExpr> alpha_shape({data->shape[param->axis]});
reporter->Assign(types[1], TensorTypeNode::make(alpha_shape, data->dtype));

// assign output type
reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
return true;
}
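The relation above assigns alpha the type `(data.shape[axis],)` with data's dtype and gives the output the same type as the input. A sketch using the test-style API from this commit (variable names are illustrative):

import tvm
from tvm import relay

# With an (n, c, h, w) input and the default axis=1, PReluRel infers alpha
# as a rank-1 tensor of length c and keeps the input shape for the output.
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
a = relay.var("alpha", relay.IncompleteType())  # type left for PReluRel to fill in
z = relay.ir_pass.infer_type(relay.nn.prelu(x, a, axis=1))
# z.args[1].checked_type == relay.TensorType((c,), "float32")
# z.checked_type == relay.TensorType((n, c, h, w), "float32")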

// Positional relay function to create prelu operator used by frontend FFI.
Expr MakePRelu(Expr data,
Expr alpha,
int axis) {
auto attrs = make_node<PReluAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.prelu");
return CallNode::make(op, {data, alpha}, Attrs(attrs), {});
}


TVM_REGISTER_API("relay.op.nn._make.prelu")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakePRelu, args, rv);
});


RELAY_REGISTER_OP("nn.prelu")
.describe(R"code(Parametric version of a Rectified Linear Unit.
It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
and computes the output as :math:`PReLU(x) = x > 0 ? x : alpha * x`,
where :math:`*` is a channelwise multiplication for each sample in the batch.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.PReluAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "Input data.")
.add_argument("alpha", "Tensor", "Input channelwise alpha.")
.set_support_level(3)
.add_type_rel("PRelu", PReluRel);


TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
auto make_func = [](Expr data, int axis) {
39 changes: 33 additions & 6 deletions tests/python/relay/test_op_level3.py
@@ -188,13 +188,39 @@ def test_full_like():
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")

def test_infer_type_leaky_relu():
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.leaky_relu(x, alpha=0.1)
assert "alpha=0.1" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w), "float32")

def verify_infer_type_prelu(data, alpha, axis, output, dtype="float32"):
x = relay.var("data", relay.TensorType(data, dtype))
if alpha:
y = relay.var("alpha", relay.TensorType(alpha, dtype))
else:
y = relay.var("alpha", relay.IncompleteType())
z = relay.nn.prelu(x, y, axis=axis)
zz = relay.ir_pass.infer_type(z)
if axis != 1:
assert "axis" in z.astext()
assert zz.checked_type == relay.ty.TensorType(output, dtype)
if not alpha:
axis = axis if axis else 1
alpha_shape = (data[axis],)
assert zz.args[1].checked_type == relay.TensorType(alpha_shape, dtype)

def test_infer_type_prelu():
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
verify_infer_type_prelu((n, c, h, w), (c,), 1, (n, c, h, w))
verify_infer_type_prelu((n, h, w, c), (c,), 3, (n, h, w, c))
verify_infer_type_prelu((n, c, h, w), None, 1, (n, c, h, w))
verify_infer_type_prelu((n, h, w, c), None, 3, (n, h, w, c))
verify_infer_type_prelu((1, 3, 2, 2), (3,), 1, (1, 3, 2, 2))
verify_infer_type_prelu((1, 2, 2, 3), (3,), 3, (1, 2, 2, 3))
verify_infer_type_prelu((1, 3, 2, 2), None, 1, (1, 3, 2, 2))
verify_infer_type_prelu((1, 2, 2, 3), None, 3, (1, 2, 2, 3))

if __name__ == "__main__":
test_cast()
@@ -208,6 +234,7 @@ def test_infer_type_leaky_relu():
test_full()
test_full_like()
test_infer_type_leaky_relu()
test_infer_type_prelu()
test_squeeze_infer_type()
test_squeeze_bad_axes_infer_type()
test_split_infer_type()
