Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support #5395

Merged
merged 3 commits into from
Apr 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ def _impl(inputs, input_types):
return _impl


def _log1p():
def _impl(inputs, input_types):
# 1_plus_log x = log(x + 1)
one = _expr.const(1, dtype="float32")
return _op.log(inputs[0] + one)
return _impl


def _arange():
def _impl(inputs, input_types):
if len(inputs) == 5:
Expand Down Expand Up @@ -1642,11 +1650,16 @@ def _get_convert_map(prelude):
"aten::abs" : _unary("abs"),
"aten::neg" : _unary("negative"),
"aten::cos" : _unary("cos"),
"aten::cosh" : _unary("cosh"),
"aten::sin" : _unary("sin"),
"aten::sinh" : _unary("sinh"),
"aten::tan" : _unary("tan"),
"aten::tanh" : _unary("tanh"),
"aten::atan" : _unary("atan"),
"aten::log" : _unary("log"),
"aten::log2" : _unary("log2"),
"aten::log10" : _unary("log10"),
"aten::log1p" : _log1p(),
"aten::exp" : _unary("exp"),
"aten::erf" : _unary("erf"),
"aten::trunc" : _unary("trunc"),
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@


register_broadcast_schedule("log")
register_broadcast_schedule("log2")
register_broadcast_schedule("log10")
register_broadcast_schedule("tan")
register_broadcast_schedule("cos")
register_broadcast_schedule("cosh")
register_broadcast_schedule("sin")
register_broadcast_schedule("sinh")
register_broadcast_schedule("atan")
register_broadcast_schedule("exp")
register_broadcast_schedule("erf")
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
from .reduce import sum as _sum
from .tensor import (
cos,
cosh,
exp,
less,
negative,
ones_like,
power,
sin,
sinh,
zeros_like,
equal,
shape_of,
Expand Down Expand Up @@ -61,6 +63,24 @@ def log_grad(orig, grad):
return [grad * ones_like(x) / x]


@register_gradient("log2")
def log2_grad(orig, grad):
"""Returns [grad * 1 / (log(2) * x)]"""
x = orig.args[0]
ones = ones_like(x)
two = const(2.0)
return [grad * ones / (log(two) * x)]


@register_gradient("log10")
def log10_grad(orig, grad):
"""Returns [grad * 1 / (log(10) * x)]"""
x = orig.args[0]
ones = ones_like(x)
ten = const(10.0)
return [grad * ones / (log(ten) * x)]


@register_gradient("tan")
def tan_grad(orig, grad):
"""Returns [grad / (cos^2(x))]"""
Expand All @@ -76,12 +96,26 @@ def cos_grad(orig, grad):
return [grad * (-ones * sin(x))]


@register_gradient("cosh")
def cosh_grad(orig, grad):
"""Returns [grad * (-sinh(x))]"""
x = orig.args[0]
ones = ones_like(x)
return [grad * (-ones * sinh(x))]


@register_gradient("sin")
def sin_grad(orig, grad):
"""Returns [grad * cos(x)]"""
x = orig.args[0]
return [grad * cos(x)]

@register_gradient("sinh")
def sinh_grad(orig, grad):
"""Returns [grad * cosh(x)]"""
x = orig.args[0]
return [grad * cosh(x)]

@register_gradient("atan")
def atan_grad(orig, grad):
"""Returns [grad * 1 / (1 + x ^ 2)]"""
Expand Down
60 changes: 60 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,36 @@ def log(data):
"""
return _make.log(data)

def log2(data):
"""Compute elementwise log to the base 2 of data.

Parameters
----------
data : relay.Expr
The input data

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log2(data)

def log10(data):
"""Compute elementwise log to the base 10 of data.

Parameters
----------
data : relay.Expr
The input data

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log10(data)

def tan(data):
"""Compute elementwise tan of data.

Expand Down Expand Up @@ -77,6 +107,21 @@ def cos(data):
"""
return _make.cos(data)

def cosh(data):
"""Compute elementwise cosh of data.

Parameters
----------
data : relay.Expr
The input data

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.cosh(data)

def sin(data):
"""Compute elementwise sin of data.

Expand All @@ -92,6 +137,21 @@ def sin(data):
"""
return _make.sin(data)

def sinh(data):
"""Compute elementwise sinh of data.

Parameters
----------
data : relay.Expr
The input data

Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sinh(data)

def atan(data):
"""Compute elementwise atan of data.

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def where(condition, x, y):
Returns
-------
result : relay.Expr
The selected array.
The selected array.

Examples
--------
Expand Down
1 change: 1 addition & 0 deletions python/tvm/te/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import sinh, cosh, log2, log10
from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from tvm.tir import isnan, isfinite, isinf
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def sin(x):


def sinh(x):
"""Take sin of input x.
"""Take sinh of input x.

Parameters
----------
Expand Down
44 changes: 44 additions & 0 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ RELAY_REGISTER_UNARY_OP("log")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));


RELAY_REGISTER_UNARY_OP("log2")
.describe(R"code(Returns the log to base 2 of input array, computed element-wise.

.. math::
log2(x)

)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2));


RELAY_REGISTER_UNARY_OP("log10")
.describe(R"code(Returns the log to base 10 of input array, computed element-wise.

.. math::
log10(x)

)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10));


RELAY_REGISTER_UNARY_OP("tan")
.describe(R"code(Returns the tan of input array, computed element-wise.

Expand All @@ -73,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("cos")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos));


RELAY_REGISTER_UNARY_OP("cosh")
.describe(R"code(Returns the cosh of input array, computed element-wise.

.. math::
Y = cosh(X)

)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh));


RELAY_REGISTER_UNARY_OP("sin")
.describe(R"code(Returns the sin of input array, computed element-wise.

Expand All @@ -84,6 +117,17 @@ RELAY_REGISTER_UNARY_OP("sin")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin));


RELAY_REGISTER_UNARY_OP("sinh")
.describe(R"code(Returns the sinh of input array, computed element-wise.

.. math::
Y = sinh(X)

)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh));


RELAY_REGISTER_UNARY_OP("atan")
.describe(R"code(Returns the atan of input array, computed element-wise.

Expand Down
12 changes: 12 additions & 0 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p")
.set_body(DispatchExtern<FloatSuffix>);

Expand All @@ -49,9 +55,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
.set_body(DispatchExtern<FloatSuffix>);

Expand Down
25 changes: 25 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,6 +1868,26 @@ class Neg1(Module):
def forward(self, *args):
return torch.neg(args[0])

class Sinh1(Module):
def forward(self, *args):
return torch.sinh(args[0])

class Cosh1(Module):
def forward(self, *args):
return torch.cosh(args[0])

class Log2_1(Module):
def forward(self, *args):
return torch.log2(args[0])

class Log10_1(Module):
def forward(self, *args):
return torch.log10(args[0])

class Log1p_1(Module):
def forward(self, *args):
return torch.log1p(args[0])

input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(Sqrt1().float().eval(), input_data=input_data)
Expand All @@ -1876,11 +1896,16 @@ def forward(self, *args):
verify_model(Floor1().float().eval(), input_data=input_data)
verify_model(Round1().float().eval(), input_data=input_data)
verify_model(Cos1().float().eval(), input_data=input_data)
verify_model(Cosh1().float().eval(), input_data=input_data)
verify_model(Sin1().float().eval(), input_data=input_data)
verify_model(Sinh1().float().eval(), input_data=input_data)
verify_model(Tan1().float().eval(), input_data=input_data)
verify_model(Tanh1().float().eval(), input_data=input_data)
verify_model(ATanh1().float().eval(), input_data=input_data)
verify_model(Log1().float().eval(), input_data=input_data)
verify_model(Log2_1().float().eval(), input_data=input_data)
verify_model(Log10_1().float().eval(), input_data=input_data)
verify_model(Log1p_1().float().eval(), input_data=input_data)
verify_model(Exp1().float().eval(), input_data=input_data)
verify_model(Erf1().float().eval(), input_data=input_data)
verify_model(Trunc1().float().eval(), input_data=input_data)
Expand Down
6 changes: 5 additions & 1 deletion tests/python/relay/test_op_grad_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def check_single_op(opfunc, ref):
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]:
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0))),
(tvm.relay.log2, lambda x: 1 / (np.log(2) * x)),
(tvm.relay.log10, lambda x: 1 / (np.log(10) * x)),
(tvm.relay.cosh, lambda x: -1.0 * np.sinh(x)),
(tvm.relay.sinh, lambda x: np.cosh(x))]:
check_single_op(opfunc, ref)


Expand Down
4 changes: 4 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ TOPI_DECLARE_UNARY_OP(erf);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
TOPI_DECLARE_UNARY_OP(log2);
TOPI_DECLARE_UNARY_OP(log10);
TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(cosh);
TOPI_DECLARE_UNARY_OP(tan);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(sinh);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh);
Expand Down
Loading