Add support for Tensorflow operators log1p, cos, sin (apache#3614)
The patch adds support for the TensorFlow operators log1p, cos, and sin.
TensorFlow log1p is documented at https://www.tensorflow.org/api_docs/python/tf/math/log1p
TensorFlow cos is documented at https://www.tensorflow.org/api_docs/python/tf/math/cos
TensorFlow sin is documented at https://www.tensorflow.org/api_docs/python/tf/math/sin
alexgl-github authored and wweic committed Aug 9, 2019
1 parent 1b834d3 commit cddfa3e
Showing 19 changed files with 212 additions and 5 deletions.
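For orientation, here is a minimal sketch of how the new frontend support is exercised end to end, written against the TF 1.x API that the tests in this patch use. The shape, node names, and fetch list are illustrative, not part of the patch:

import tensorflow as tf
from tvm import relay

# Build a TF 1.x graph exercising all three newly supported ops.
with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.float32, (2, 3), name="x")
    tf.sin(tf.cos(tf.log1p(x)), name="out")
    graph_def = graph.as_graph_def()

# Convert to Relay: 'Cos' and 'Sin' map directly to Relay ops,
# while 'Log1p' is decomposed into log(1 + x) by the converter.
mod, params = relay.frontend.from_tensorflow(
    graph_def, shape={"x": (2, 3)}, outputs=["out"])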
3 changes: 2 additions & 1 deletion include/tvm/expr_operator.h
@@ -517,7 +517,8 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(popcount);

TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(sin);

// Implementation details after this
inline bool is_const(const Expr& x) {
29 changes: 29 additions & 0 deletions python/tvm/intrin.py
@@ -258,6 +258,35 @@ def log(x):
    """
    return call_pure_intrin(x.dtype, "log", x)

def cos(x):
    """Take cos of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "cos", x)

def sin(x):
    """Take sin of input x.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result.
    """
    return call_pure_intrin(x.dtype, "sin", x)

def sqrt(x):
    """Take square root of input x.
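A short sketch (assuming an LLVM-enabled TVM build) of the new tvm.cos/tvm.sin intrinsics at the expression level; each emits a pure intrinsic call that the per-target rules added later in this patch lower to a concrete implementation:

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute(A.shape, lambda i: tvm.cos(A[i]), name="B")  # new intrinsic
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")  # lowered via the llvm intrinsic rule below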
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1301,6 +1301,13 @@ def _impl(inputs, attr, params):
        return _op.prod(inputs[0], int(axis), keepdims=keepdims)
    return _impl

def _log1p():
    # op description: https://www.tensorflow.org/api_docs/python/tf/math/log1p
    def _impl(inputs, attr, params):
        one = tvm.relay.const(1, attr['T'].name)
        add_out = _get_relay_op('add')(inputs[0], one)
        return _get_relay_op('log')(add_out)
    return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []
@@ -1354,6 +1361,9 @@ def _impl(inputs, attr, params):
    'Less'       : _broadcast('less'),
    'LessEqual'  : _broadcast('less_equal'),
    'Log'        : AttrCvt('log'),
    'Log1p'      : _log1p(),
    'Cos'        : AttrCvt('cos'),
    'Sin'        : AttrCvt('sin'),
    'LogicalAnd' : _logical('logical_and'),
    'LogicalOr'  : _logical('logical_or'),
    'LogicalNot' : _logical('logical_not'),
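Because Relay has no dedicated log1p operator, the converter rewrites Log1p as log(1 + x). A quick numpy sketch of the trade-off: the decomposition is accurate over the input range the tests below use, though a true log1p keeps more precision for inputs near zero (not a case this patch targets):

import numpy as np

x = np.float32(0.5)
print(np.log1p(x), np.log(1.0 + x))    # agree to float32 precision

tiny = np.float32(1e-8)
print(np.log1p(tiny))                  # ~1e-08
print(np.log(np.float32(1.0) + tiny))  # 0.0, since 1 + tiny rounds to 1 in float32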
3 changes: 3 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -25,6 +25,9 @@
schedule_elemwise = schedule_injective

register_schedule("log", schedule_broadcast)
register_schedule("log1p", schedule_broadcast)
register_schedule("cos", schedule_broadcast)
register_schedule("sin", schedule_broadcast)
register_schedule("exp", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast)
register_schedule("rsqrt", schedule_broadcast)
14 changes: 13 additions & 1 deletion python/tvm/relay/op/_tensor_grad.py
@@ -20,7 +20,7 @@
from ..expr import const
from .op import register_gradient
from .transform import collapse_sum_like, broadcast_to_like, where
from .tensor import exp, negative, power, less
from .tensor import exp, negative, power, less, cos, sin
from .tensor import zeros_like, ones_like
from . import nn as _nn

@@ -31,6 +31,18 @@ def log_grad(orig, grad):
    x = orig.args[0]
    return [grad * ones_like(x) / x]

@register_gradient("cos")
def cos_grad(orig, grad):
"""Returns [grad * (-sin(x))]"""
x = orig.args[0]
ones = ones_like(x)
return [grad * (-ones * sin(x))]

@register_gradient("sin")
def sin_grad(orig, grad):
"""Returns [grad * cos(x)]"""
x = orig.args[0]
return [grad * cos(x)]

@register_gradient("exp")
def exp_grad(orig, grad):
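The two new gradients encode the standard identities d/dx cos(x) = -sin(x) and d/dx sin(x) = cos(x); a quick finite-difference sketch in numpy confirms the formulas the registrations use:

import numpy as np

x = np.linspace(-3.0, 3.0, 13)
eps = 1e-6
fd_cos = (np.cos(x + eps) - np.cos(x - eps)) / (2 * eps)
fd_sin = (np.sin(x + eps) - np.sin(x - eps)) / (2 * eps)
assert np.allclose(fd_cos, -np.sin(x), atol=1e-4)  # matches cos_grad
assert np.allclose(fd_sin, np.cos(x), atol=1e-4)   # matches sin_grad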
29 changes: 29 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -46,6 +46,35 @@ def log(data):
    """
    return _make.log(data)

def cos(data):
    """Compute elementwise cos of data.

    Parameters
    ----------
    data : relay.Expr
        The input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.cos(data)

def sin(data):
    """Compute elementwise sin of data.

    Parameters
    ----------
    data : relay.Expr
        The input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.sin(data)

def exp(data):
    """Compute elementwise exp of data.
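A small sketch of the new Relay wrappers in use; relay.Module is the module type as of this commit, and the shape/dtype here are illustrative:

from tvm import relay

x = relay.var("x", shape=(3,), dtype="float32")
f = relay.Function([x], relay.sin(relay.cos(x)))
mod = relay.Module.from_expr(f)
print(mod)  # both ops type-check as elementwise float32 calls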
6 changes: 6 additions & 0 deletions src/codegen/intrin_rule.cc
@@ -37,6 +37,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>);

6 changes: 6 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
@@ -95,6 +95,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.sin")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
.set_body(DispatchExtern<CUDAMath>);

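DispatchExtern&lt;CUDAFastMath&gt; is expected to route float32 calls to the fast hardware intrinsics (__cosf, __sinf). One way to inspect this, assuming a CUDA-enabled build; the schedule here is illustrative:

import tvm

n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
B = tvm.compute(A.shape, lambda i: tvm.sin(A[i]), name="B")
s = tvm.create_schedule(B.op)
bx, tx = s[B].split(B.op.axis[0], factor=64)
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))
f = tvm.build(s, [A, B], "cuda")
print(f.imported_modules[0].get_source())  # generated kernel should call __sinf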
6 changes: 6 additions & 0 deletions src/codegen/llvm/intrin_rule_llvm.cc
@@ -86,6 +86,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin")
.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);

} // namespace llvm
} // namespace codegen
} // namespace tvm
6 changes: 6 additions & 0 deletions src/codegen/llvm/intrin_rule_nvptx.cc
@@ -79,6 +79,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
.set_body(DispatchExternLibDevice);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.sin")
.set_body(DispatchExternLibDevice);

} // namespace llvm
} // namespace codegen
} // namespace tvm
22 changes: 22 additions & 0 deletions src/relay/op/tensor/unary.cc
@@ -53,6 +53,28 @@ RELAY_REGISTER_UNARY_OP("log")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));


RELAY_REGISTER_UNARY_OP("cos")
.describe(R"code(Returns the cos of input array, computed element-wise.
.. math::
Y = cos(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos));


RELAY_REGISTER_UNARY_OP("sin")
.describe(R"code(Returns the sin of input array, computed element-wise.
.. math::
Y = sin(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin));


RELAY_REGISTER_UNARY_OP("exp")
.describe(R"code(Returns the exp input array, computed element-wise.
27 changes: 27 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -1869,6 +1869,30 @@ def test_forward_log():
    tf.log(in_data, name="log")
    compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')

def test_forward_log1p():
    """test operator Log1p """
    np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
    tf.reset_default_graph()
    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
    tf.log1p(in_data, name="log1p")
    compare_tf_with_tvm([np_data], ['in_data:0'], 'log1p:0')

def test_forward_cos():
    """test operator cos """
    np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
    tf.reset_default_graph()
    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
    tf.cos(in_data, name="cos")
    compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')

def test_forward_sin():
    """test operator sin """
    np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
    tf.reset_default_graph()
    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
    tf.sin(in_data, name="sin")
    compare_tf_with_tvm([np_data], ['in_data:0'], 'sin:0')

def test_forward_negative():
    """test tf operator Neg """
    np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32)
@@ -2159,6 +2183,9 @@ def test_placeholder():
    test_forward_pow_exp()
    test_forward_sign()
    test_forward_log()
    test_forward_log1p()
    test_forward_cos()
    test_forward_sin()
    test_forward_negative()
    test_forward_divide()
    test_forward_abs()
4 changes: 3 additions & 1 deletion tests/python/relay/test_op_grad_level1.py
@@ -56,7 +56,9 @@ def check_single_op(opfunc, ref):
            (tvm.relay.tanh, lambda x: 1 - np.tanh(x) * np.tanh(x)),
            (tvm.relay.sqrt, lambda x: 0.5 * np.power(x, -0.5)),
            (tvm.relay.abs, lambda x: np.where(x < 0, -np.ones_like(x), np.ones_like(x))),
            (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
            (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
            (tvm.relay.sin, lambda x: np.cos(x))]:
        check_single_op(opfunc, ref)


4 changes: 3 additions & 1 deletion tests/python/relay/test_op_level1.py
@@ -71,7 +71,9 @@ def check_single_op(opfunc, ref):
            (tvm.relay.rsqrt, rsqrt),
            (tvm.relay.sigmoid, sigmoid),
            (tvm.relay.tanh, np.tanh),
            (relay.nn.relu, relu),
            (tvm.relay.cos, np.cos),
            (tvm.relay.sin, np.sin)]:
        check_single_op(opfunc, ref)


2 changes: 2 additions & 0 deletions topi/include/topi/elemwise.h
@@ -54,6 +54,8 @@ TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin);

/*
* \brief Fast_tanh_float implementation from Eigen
32 changes: 31 additions & 1 deletion topi/python/topi/math.py
@@ -90,6 +90,37 @@ def tanh(x):
    """
    return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))

@tvm.tag_scope(tag=tag.ELEMWISE)
def cos(x):
    """Take cos of input x.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.cos(x(*i)))

@tvm.tag_scope(tag=tag.ELEMWISE)
def sin(x):
    """Take sin of input x.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.sin(x(*i)))

@tvm.tag_scope(tag=tag.ELEMWISE)
def floor(x):
@@ -206,7 +237,6 @@ def log(x):
    """
    return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def sqrt(x):
    """Take square root of input x.
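An end-to-end sketch of the new topi wrapper, mirroring what test_topi_math below automates; the fixed shape and llvm target here are assumptions for illustration:

import numpy as np
import tvm
import topi

A = tvm.placeholder((32,), name="A", dtype="float32")
B = topi.cos(A)
s = tvm.create_schedule(B.op)
f = tvm.build(s, [A, B], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(-np.pi, np.pi, size=32).astype("float32"), ctx)
b = tvm.nd.array(np.zeros(32, dtype="float32"), ctx)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), np.cos(a.asnumpy()), rtol=1e-5)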
10 changes: 10 additions & 0 deletions topi/src/topi.cc
@@ -148,6 +148,16 @@ TVM_REGISTER_GLOBAL("topi.exp")
*rv = exp(args[0]);
});

TVM_REGISTER_GLOBAL("topi.cos")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cos(args[0]);
});

TVM_REGISTER_GLOBAL("topi.sin")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sin(args[0]);
});

TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
2 changes: 2 additions & 0 deletions topi/tests/python/test_topi_basic.py
@@ -41,6 +41,8 @@ def test_apply(func, name):
    test_apply(topi.log, "log")
    test_apply(topi.sqrt, "sqrt")
    test_apply(topi.rsqrt, "rsqrt")
    test_apply(topi.sin, "sin")
    test_apply(topi.cos, "cos")


if __name__ == "__main__":
2 changes: 2 additions & 0 deletions topi/tests/python/test_topi_math.py
@@ -84,6 +84,8 @@ def check_device(device):
    test_apply(topi.log, "log", np.log, 0, 100)
    test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
    test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
    test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
    test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)


def test_cast():
