Skip to content

Commit

Permalink
Add support for Tensorflow operators log1p and cos
Browse files Browse the repository at this point in the history
The patch adds support for Tensorflow operators log1p and cos
Tensorflow log1p is described at https://www.tensorflow.org/api_docs/python/tf/math/log1p
Tensorflow cos is described at https://www.tensorflow.org/api_docs/python/tf/math/cos
  • Loading branch information
alexgl-github committed Jul 24, 2019
1 parent 90eee08 commit 8233b01
Show file tree
Hide file tree
Showing 13 changed files with 183 additions and 1 deletion.
2 changes: 1 addition & 1 deletion include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(rsqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(popcount);

TVM_DECLARE_INTRIN_UNARY(cos);

// Implementation details after this
inline bool is_const(const Expr& x) {
Expand Down
29 changes: 29 additions & 0 deletions python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,35 @@ def log(x):
"""
return call_pure_intrin(x.dtype, "log", x)

def log1p(x):
"""Take log of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "log1p", x)

def cos(x):
"""Take cos of input x.
Parameters
----------
x : Expr
Input argument.
Returns
-------
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "cos", x)

def sqrt(x):
"""Take square root of input x.
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,8 @@ def _impl(inputs, attr, params):
'Less' : _broadcast('less'),
'LessEqual' : _broadcast('less_equal'),
'Log' : AttrCvt('log'),
'Log1p' : AttrCvt('log1p'),
'Cos' : AttrCvt('cos'),
'LogicalAnd' : _logical('logical_and'),
'LogicalOr' : _logical('logical_or'),
'LogicalNot' : _logical('logical_not'),
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
schedule_elemwise = schedule_injective

register_schedule("log", schedule_broadcast)
register_schedule("log1p", schedule_broadcast)
register_schedule("cos", schedule_broadcast)
register_schedule("exp", schedule_broadcast)
register_schedule("sqrt", schedule_broadcast)
register_schedule("rsqrt", schedule_broadcast)
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ def log_grad(orig, grad):
x = orig.args[0]
return [grad * ones_like(x) / x]

@register_gradient("log1p")
def log1p_grad(orig, grad):
"""Returns [grad * (1 / (x+1))]"""
x = orig.args[0]
return [grad * ones_like(x) / (x + ones_like(x))]

@register_gradient("cos")
def cos_grad(orig, grad):
"""Returns [grad * (-sin(x))]"""
x = orig.args[0]
ones = ones_like(x)
return [grad * (-ones * sin(x))]

@register_gradient("exp")
def exp_grad(orig, grad):
Expand Down
29 changes: 29 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,35 @@ def log(data):
"""
return _make.log(data)

def log1p(data):
"""Compute elementwise log of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log1p(data)

def cos(data):
"""Compute elementwise cos of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.cos(data)

def exp(data):
"""Compute elementwise exp of data.
Expand Down
14 changes: 14 additions & 0 deletions src/codegen/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,23 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.exp")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
Expr e = args[0];
const Call* call = e.as<Call>();
CHECK(call != nullptr);

auto one = make_const(call->args[0].type(), 1);
*rv = log(call->args[0] + one);
});


TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>);

Expand Down
3 changes: 3 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
.set_body(DispatchExtern<CUDAFastMath>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tanh")
.set_body(DispatchExtern<CUDAMath>);

Expand Down
22 changes: 22 additions & 0 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,28 @@ RELAY_REGISTER_UNARY_OP("log")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));


RELAY_REGISTER_UNARY_OP("log1p")
.describe(R"code(Returns the log1p input array, computed element-wise.
.. math::
log(x+1)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log1p));


RELAY_REGISTER_UNARY_OP("cos")
.describe(R"code(Returns the cos of input array, computed element-wise.
.. math::
Y = cos(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos));


RELAY_REGISTER_UNARY_OP("exp")
.describe(R"code(Returns the exp input array, computed element-wise.
Expand Down
19 changes: 19 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,22 @@ def test_forward_log():
tf.log(in_data, name="log")
compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')

def test_forward_log1p():
"""test operator Log1p """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.log1p(in_data, name="log1p")
compare_tf_with_tvm([np_data], ['in_data:0'], 'log1p:0')

def test_forward_cos():
"""test operator cos """
np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
tf.reset_default_graph()
in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
tf.cos(in_data, name="cos")
compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')

def test_forward_negative():
"""test tf operator Neg """
np_data = np.random.uniform(-100, 255, size=(224, 224, 3)).astype(np.float32)
Expand Down Expand Up @@ -2097,6 +2113,7 @@ def test_placeholder():
# ----
if __name__ == '__main__':


# Transforms
test_forward_transpose()
test_forward_reshape()
Expand Down Expand Up @@ -2140,6 +2157,8 @@ def test_placeholder():
test_forward_pow_exp()
test_forward_sign()
test_forward_log()
test_forward_log1p()
test_forward_cos()
test_forward_negative()
test_forward_divide()
test_forward_abs()
Expand Down
10 changes: 10 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos);

/*
* \brief Fast_tanh_float implementation from Eigen
Expand Down Expand Up @@ -214,6 +215,15 @@ inline Tensor rsqrt(const Tensor& x,
}, name, tag);
}

inline Tensor log1p(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
Expr one = make_const(x->dtype, 1);
return tvm::log(x(i)+one);
}, name, tag);
}

/*!
* \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max]
Expand Down
30 changes: 30 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ def tanh(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))

@tvm.tag_scope(tag=tag.ELEMWISE)
def cos(x):
"""Take cos of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.cos(x(*i)))

@tvm.tag_scope(tag=tag.ELEMWISE)
def floor(x):
Expand Down Expand Up @@ -206,6 +221,21 @@ def log(x):
"""
return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))

@tvm.tag_scope(tag=tag.ELEMWISE)
def log1p(x):
"""Take logarithm of input (x+1).
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(x.shape, lambda *i: tvm.log1p(x(*i)))

@tvm.tag_scope(tag=tag.ELEMWISE)
def sqrt(x):
Expand Down
10 changes: 10 additions & 0 deletions topi/src/topi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ TVM_REGISTER_GLOBAL("topi.exp")
*rv = exp(args[0]);
});

TVM_REGISTER_GLOBAL("topi.cos")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cos(args[0]);
});

TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
Expand All @@ -173,6 +178,11 @@ TVM_REGISTER_GLOBAL("topi.log")
*rv = log(args[0]);
});

TVM_REGISTER_GLOBAL("topi.log1p")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log1p(args[0]);
});

TVM_REGISTER_GLOBAL("topi.identity")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = identity(args[0]);
Expand Down

0 comments on commit 8233b01

Please sign in to comment.