
Commit 16d6294: clean up
merrymercy committed Nov 21, 2018
1 parent 66717fc
Showing 2 changed files with 34 additions and 23 deletions.
34 changes: 19 additions & 15 deletions src/relay/op/tensor/unary.cc
@@ -78,21 +78,25 @@ struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
};

TVM_REGISTER_API("relay.op._make.clip")
-.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {
-  auto attrs = make_node<ClipAttrs>();
-  attrs->a_min = a_min;
-  attrs->a_max = a_max;
-  static const Op& op = Op::Get("clip");
-  return CallNode::make(op, {a}, Attrs(attrs), {});
-});
-
-RELAY_REGISTER_OP("clip")
-.describe(R"code(Clip tensor values.
-This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype.
-)code" TVM_ADD_FILELINE)
-.set_num_inputs(1)
-.add_argument("tensor", "Tensor", "The input tensor.")
-.set_support_level(3);
+.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {
+  auto attrs = make_node<ClipAttrs>();
+  attrs->a_min = a_min;
+  attrs->a_max = a_max;
+  static const Op& op = Op::Get("clip");
+  return CallNode::make(op, {a}, Attrs(attrs), {});
+});
+
+RELAY_REGISTER_UNARY_OP("clip")
+.describe(R"code(Clip tensor values.
+This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_type_rel("Identity", IdentityRel)
+.set_attr<TOpPattern>("TOpPattern", kElemWise)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
+.set_support_level(3);

RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise.
…
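For reference, a minimal sketch of how the op registered above is reached from the Python frontend. This assumes the usual relay.clip wrapper that forwards to relay.op._make.clip; wrapper names and module paths may differ across TVM versions.

from tvm import relay

# Build a small graph around the clip op; values outside [0, 6] are clamped.
x = relay.var("x", shape=(2, 3), dtype="float32")
y = relay.clip(x, a_min=0.0, a_max=6.0)
f = relay.Function([x], y)
print(f)  # prints the Relay IR of the function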
23 changes: 15 additions & 8 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -3,10 +3,9 @@
from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay.ir_pass import *
-from tvm.relay.testing import layers

def test_alter_op():
"""Test alter an operator"""
"""Test directly replacing an operator with a new one"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight = relay.var('weight', shape=(64, 64, 3, 3))
@@ -46,7 +45,7 @@ def expected():


def test_alter_return_none():
"""Test do nothing"""
"""Test doing nothing by returning 'None' """
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        y = relay.nn.global_max_pool2d(x)
@@ -71,7 +70,9 @@ def alter_conv2d(attrs, inputs, tinfos):


def test_alter_layout():
"""Test alternate the layout"""
"""Test alternating the layout of a conv2d.
The layout of broadcast operators and the weight should be changed accordingly.
"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        bias = relay.var("bias")
@@ -130,7 +131,10 @@ def expected():


def test_alter_layout_dual_path():
"""Test alternate the layout"""
"""
Test alternating the layout with two outputs.
One path continues to use the new layout while one path fall backs to old layout.
"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
@@ -147,7 +151,7 @@ def before():
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.batch_flatten(y)
        ret = relay.Tuple([y1, y2])
-        y = relay.Function([x, weight1, weight2], ret)
+        y = relay.Function(free_vars(ret), ret)
        return y
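The free_vars change above avoids listing function parameters by hand. A minimal illustration, assuming free_vars is the helper from tvm.relay.ir_pass (pulled in by the wildcard import at the top of this file):

from tvm import relay
from tvm.relay.ir_pass import free_vars

x = relay.var("x", shape=(1, 64, 56, 56))
y = relay.nn.relu(x)
# free_vars(y) returns the free variables of the expression ([x] here),
# so the parameter list tracks the body automatically.
f = relay.Function(free_vars(y), y)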

    @register_alter_op_layout("nn.conv2d", level=103)
@@ -178,7 +182,7 @@ def expected():
        y2 = relay.layout_transform(y, "NCHW16c", "NCHW")
        y2 = relay.nn.batch_flatten(y2)
        ret = relay.Tuple([y1, y2])
-        y = relay.Function([x, weight1, weight2], ret)
+        y = relay.Function(free_vars(ret), ret)
        return y

    a = before()
@@ -192,7 +196,10 @@ def expected():
    assert(alpha_equal(a, b))

def test_alter_layout_resnet():
"""Test alternate the layout"""
"""Test alternating the layout of a residual block
This also tests the elimination of duplicated transformation.
If a same transformation applies to a same node twice, only one transformation will be created.
"""
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
…
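All of these tests follow the same pattern: register an alteration hook for an op, build a function, run the alter_op_layout pass, and compare against a hand-written expected function. A condensed sketch, assuming the 2018-era tvm.relay.ir_pass API used in this file; the level value and the 'data_layout' attribute name are illustrative and may differ by TVM version.

from tvm import relay
from tvm.relay.op import register_alter_op_layout
from tvm.relay.ir_pass import alter_op_layout, infer_type, free_vars

# Each registration needs a distinct level; 200 is an arbitrary choice here.
@register_alter_op_layout("nn.conv2d", level=200)
def alter_conv2d(attrs, inputs, tinfos):
    data, weight = inputs
    new_attrs = dict(attrs)
    new_attrs['data_layout'] = 'NCHW16c'  # request a tiled layout for conv2d
    return relay.nn.conv2d(data, weight, **new_attrs)

x = relay.var("x", shape=(1, 64, 56, 56))
weight = relay.var("weight", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
f = relay.Function(free_vars(y), y)

# Type information must be available before the pass runs; the pass then
# inserts layout_transform ops around the altered conv2d as needed.
f = infer_type(f)
f = alter_op_layout(f)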
