[Operator] Sign #2775

Merged · 1 commit · Mar 14, 2019
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
@@ -11,6 +11,7 @@ List of operators
topi.negative
topi.floor
topi.ceil
topi.sign
topi.trunc
topi.round
topi.abs
@@ -95,6 +96,7 @@ topi
.. autofunction:: topi.identity
.. autofunction:: topi.floor
.. autofunction:: topi.ceil
.. autofunction:: topi.sign
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.abs
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -81,6 +81,7 @@ This level enables additional math and transform operators.
tvm.relay.squeeze
tvm.relay.floor
tvm.relay.ceil
tvm.relay.sign
tvm.relay.trunc
tvm.relay.clip
tvm.relay.round
@@ -211,6 +212,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.squeeze
.. autofunction:: tvm.relay.floor
.. autofunction:: tvm.relay.ceil
.. autofunction:: tvm.relay.sign
.. autofunction:: tvm.relay.trunc
.. autofunction:: tvm.relay.clip
.. autofunction:: tvm.relay.round
1 change: 1 addition & 0 deletions python/tvm/relay/op/_tensor.py
@@ -16,6 +16,7 @@
register_schedule("ceil", schedule_broadcast)
register_schedule("trunc", schedule_broadcast)
register_schedule("round", schedule_broadcast)
register_schedule("sign", schedule_broadcast)
register_schedule("abs", schedule_broadcast)
register_schedule("tanh", schedule_broadcast)
register_schedule("logical_not", schedule_broadcast)
14 changes: 14 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -158,6 +158,20 @@ def abs(data):
"""
return _make.abs(data)

def sign(data):
    """Compute element-wise sign of data.

    Parameters
    ----------
    data : relay.Expr
        The input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.sign(data)

def tanh(data):
"""Compute element-wise tanh of data.
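A minimal usage sketch for the new Relay op, following the pattern used in tests/python/relay/test_op_level3.py (hedged: assumes a TVM build of this era, where type inference lives under relay.ir_pass):

from tvm import relay

# Build a one-op Relay expression around the new operator.
shape = (8, 9, 4)
x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.sign(x)  # element-wise sign; same shape and dtype as x

# Type inference should preserve the input type, as the level-3 test asserts.
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(shape, "float32")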
10 changes: 10 additions & 0 deletions src/relay/op/tensor/unary.cc
@@ -145,6 +145,16 @@ RELAY_REGISTER_UNARY_OP("round")
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::round));

RELAY_REGISTER_UNARY_OP("sign")
.describe(R"code(Returns the sign of input array, computed element-wise.

.. numpy::
sign(x)

)code" TVM_ADD_FILELINE)
.set_support_level(3)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sign));


RELAY_REGISTER_UNARY_OP("abs")
.describe(R"code(Returns the abs of input array, computed element-wise.
3 changes: 2 additions & 1 deletion tests/python/relay/test_op_level3.py
@@ -25,7 +25,8 @@ def test_unary_identity():
(relay.round, np.round),
(relay.abs, np.abs),
(relay.copy, None), # np.copy
(relay.negative, np.negative)]:
(relay.negative, np.negative),
(relay.sign, np.sign)]:
shape = (8, 9, 4)
x = relay.var("x", relay.TensorType(shape, "float32"))
y = op(x)
22 changes: 22 additions & 0 deletions topi/include/topi/elemwise.h
@@ -88,6 +88,28 @@ inline Tensor logical_not(const Tensor& x,
}, name, tag);
}

/*!
* \brief Returns the sign of the tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the sign
*/
inline Tensor sign(const Tensor& x,
                   std::string name = "tensor",
                   std::string tag = kElementWise) {
  return compute(x->shape, [&](const Array<Var>& i) {
    Expr zero = make_zero(x->dtype);
    Expr one = make_const(x->dtype, 1);
    Expr minus_one = make_const(x->dtype, -1);
    // Nested selects: negative -> -1, positive -> +1, everything else -> 0.
    auto s1 = tvm::ir::Select::make((x(i) < zero), minus_one, zero);
    auto s2 = tvm::ir::Select::make((x(i) > zero), one, s1);
    return s2;
  }, name, tag);
}

/*!
* \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max]
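The nested Selects in topi::sign above map negatives to -1 and positives to +1, with everything else falling through to 0. A NumPy sketch of the same select chain, for illustration only:

import numpy as np

def sign_ref(x):
    s1 = np.where(x < 0, -1.0, 0.0)  # inner Select: negative -> -1, else 0
    return np.where(x > 0, 1.0, s1)  # outer Select: positive -> +1, else s1

assert (sign_ref(np.array([-2.5, 0.0, 3.0])) == [-1.0, 0.0, 1.0]).all()

One consequence worth noting: NaN compares false on both branches, so this chain maps NaN to 0, whereas np.sign propagates NaN.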
15 changes: 15 additions & 0 deletions topi/python/topi/math.py
@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
from . import tag
from . import cpp

@tvm.tag_scope(tag=tag.ELEMWISE)
def identity(x):
@@ -107,6 +108,20 @@ def ceil(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.ceil(x(*i)))

def sign(x):
    """Returns -1, 0, 1 based on sign of x.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return cpp.sign(x)

@tvm.tag_scope(tag=tag.ELEMWISE)
def trunc(x):
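Unlike floor and ceil above, sign is not built from a tvm.compute lambda in Python; it delegates to the C++ kernel through the topi FFI (cpp.sign). That is also why the lowered op body carries no named intrinsic call, which the test change below accounts for with skip_name_check=True. A usage sketch, assuming the standalone topi package of this era:

import tvm
import topi

A = tvm.placeholder((20, 3), name="A", dtype="float32")
B = topi.sign(A)  # element-wise; values in {-1, 0, +1}, same shape as A
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")  # compile for CPU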
5 changes: 5 additions & 0 deletions topi/src/topi.cc
@@ -173,6 +173,11 @@ TVM_REGISTER_GLOBAL("topi.elemwise_sum")
*rv = elemwise_sum(args[0]);
});

TVM_REGISTER_GLOBAL("topi.sign")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sign(args[0]);
});

TVM_REGISTER_GLOBAL("topi.full")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full(args[0], args[1], args[2]);
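TVM_REGISTER_GLOBAL publishes the C++ kernel under the name "topi.sign", which is how the Python wrapper above reaches it. The registry can also be queried directly; an illustrative sketch (assumes a standard TVM build of this era):

import tvm

fsign = tvm.get_global_func("topi.sign")
A = tvm.placeholder((20, 3), name="A", dtype="float32")
B = fsign(A)  # equivalent to topi.sign(A); returns a tvm Tensor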
6 changes: 4 additions & 2 deletions topi/tests/python/test_topi_math.py
@@ -18,10 +18,11 @@ def test_ewise():

shape = (20, 3)

def test_apply(func, name, f_numpy, low, high, check_round=False):
def test_apply(func, name, f_numpy, low, high, check_round=False, skip_name_check=False):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name
if not skip_name_check:
    assert B.op.body[0].name == name
a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
# avoid round check too close to boundary
if check_round:
@@ -49,6 +50,7 @@ def check_device(device):

test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100, check_round=True)