Skip to content

Commit

Permalink
Support for sign (apache#2775)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashutoshparkhi authored and wweic committed Mar 20, 2019
1 parent f8c5afe commit b277e1d
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ List of operators
topi.negative
topi.floor
topi.ceil
topi.sign
topi.trunc
topi.round
topi.abs
Expand Down Expand Up @@ -96,6 +97,7 @@ topi
.. autofunction:: topi.identity
.. autofunction:: topi.floor
.. autofunction:: topi.ceil
.. autofunction:: topi.sign
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.abs
Expand Down
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ This level enables additional math and transform operators.
tvm.relay.squeeze
tvm.relay.floor
tvm.relay.ceil
tvm.relay.sign
tvm.relay.trunc
tvm.relay.clip
tvm.relay.round
Expand Down Expand Up @@ -213,6 +214,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.squeeze
.. autofunction:: tvm.relay.floor
.. autofunction:: tvm.relay.ceil
.. autofunction:: tvm.relay.sign
.. autofunction:: tvm.relay.trunc
.. autofunction:: tvm.relay.clip
.. autofunction:: tvm.relay.round
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
register_schedule("ceil", schedule_broadcast)
register_schedule("trunc", schedule_broadcast)
register_schedule("round", schedule_broadcast)
register_schedule("sign", schedule_broadcast)
register_schedule("abs", schedule_broadcast)
register_schedule("tanh", schedule_broadcast)
register_schedule("logical_not", schedule_broadcast)
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,20 @@ def abs(data):
"""
return _make.abs(data)

def sign(data):
"""Compute element-wise absolute of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sign(data)

def tanh(data):
"""Compute element-wise tanh of data.
Expand Down
10 changes: 10 additions & 0 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ RELAY_REGISTER_UNARY_OP("round")
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));

RELAY_REGISTER_UNARY_OP("sign")
.describe(R"code(Returns the sign of input array, computed element-wise.
.. numpy::
sign(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign));


RELAY_REGISTER_UNARY_OP("abs")
.describe(R"code(Returns the abs of input array, computed element-wise.
Expand Down
3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def test_unary_identity():
(relay.round, np.round),
(relay.abs, np.abs),
(relay.copy, None), # np.copy
(relay.negative, np.negative)]:
(relay.negative, np.negative),
(relay.sign, np.sign)]:
shape = (8, 9, 4)
x = relay.var("x", relay.TensorType(shape, "float32"))
y = op(x)
Expand Down
22 changes: 22 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,28 @@ inline Tensor logical_not(const Tensor& x,
}, name, tag);
}

/*!
* \brief Returns the sign of the tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sign
*/
inline Tensor sign(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
Expr zero = make_zero(x->dtype);
Expr one = make_const(x->dtype, 1);
Expr minus_one = make_const(x->dtype, -1);
auto s1 = tvm::ir::Select::make((x(i) < zero), minus_one, zero);
auto s2 = tvm::ir::Select::make((x(i) > zero), one, s1);
return s2;
}, name, tag);
}

/*!
* \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max]
Expand Down
15 changes: 15 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from . import tag
from . import cpp

@tvm.tag_scope(tag=tag.ELEMWISE)
def identity(x):
Expand Down Expand Up @@ -107,6 +108,20 @@ def ceil(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i)))

def sign(x):
"""Returns -1, 0, 1 based on sign of x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return cpp.sign(x)

@tvm.tag_scope(tag=tag.ELEMWISE)
def trunc(x):
Expand Down
5 changes: 5 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ TVM_REGISTER_GLOBAL("topi.elemwise_sum")
*rv = elemwise_sum(args[0]);
});

TVM_REGISTER_GLOBAL("topi.sign")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sign(args[0]);
});

TVM_REGISTER_GLOBAL("topi.full")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full(args[0], args[1], args[2]);
Expand Down
6 changes: 4 additions & 2 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ def test_ewise():

shape = (20, 3)

def test_apply(func, name, f_numpy, low, high, check_round=False):
def test_apply(func, name, f_numpy, low, high, check_round=False, skip_name_check=False):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name
if not skip_name_check:
assert B.op.body[0].name == name
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
# avoid round check too close to boundary
if check_round:
Expand Down Expand Up @@ -49,6 +50,7 @@ def check_device(device):

test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100, check_round=True)
Expand Down

0 comments on commit b277e1d

Please sign in to comment.