Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Intrinsic] Add log1p, ldexp, atan2, hypot, nextafter, copysign #5312

Merged
merged 2 commits into from
Apr 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@

from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10
from .op import cos, sin, cosh, sinh, tan, tanh, atan
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil
from .op import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import cos, sin, cosh, sinh, tan, tanh, atan, atan2
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf, copysign
from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
from .op import comm_reducer, min, max, sum

Expand Down
113 changes: 113 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,23 @@ def log10(x):
"""
return call_pure_intrin(x.dtype, "log10", x)


def log1p(x):
"""Take log(x + 1) with respect to input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "log1p", x)


def tan(x):
"""Take tan of input x.

Expand Down Expand Up @@ -552,6 +569,26 @@ def atan(x):
"""
return call_pure_intrin(x.dtype, "atan", x)


def atan2(x1, x2):
"""Take arctan2(x1, x2).

Parameters
----------
x1 : PrimExpr
Input argument.

x2 : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x1.dtype, "atan2", x1, x2)


def sqrt(x):
"""Take square root of input x.

Expand Down Expand Up @@ -690,6 +727,82 @@ def nearbyint(x):
return _ffi_api.nearbyint(x)


def nextafter(x1, x2):
"""Return the next floating-point value after x1 towards x2.

Parameters
----------
x1 : PrimExpr
Input argument.

x2 : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x1.dtype, "nextafter", x1, x2)


def hypot(x1, x2):
"""Equivalent to sqrt(x1**2 + x2**2), element-wise.

Parameters
----------
x1 : PrimExpr
Input argument.

x2 : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x1.dtype, "hypot", x1, x2)


def copysign(x1, x2):
"""Change the sign of x1 to that of x2, element-wise.

Parameters
----------
x1 : PrimExpr
Input argument.

x2 : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x1.dtype, "copysign", x1, x2)


def ldexp(x1, x2):
"""Returns x1 * (2 ** x2).

Parameters
----------
x1 : PrimExpr
Input argument.

x2 : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x1.dtype, "ldexp", x1, x2)


def isnan(x):
"""Check if input value is Nan.

Expand Down
18 changes: 18 additions & 0 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
.set_body(DispatchExtern<FloatSuffix>);

Expand All @@ -52,6 +55,21 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nextafter")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.copysign")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.ldexp")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sqrt")
.set_body(DispatchExtern<FloatSuffix>);

Expand Down
50 changes: 49 additions & 1 deletion tests/python/unittest/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_unary_intrin():
(tvm.tir.log10, lambda x : np.log10(x)),
(tvm.tir.sinh, lambda x : np.sinh(x)),
(tvm.tir.cosh, lambda x : np.cosh(x)),
(tvm.tir.log1p, lambda x : np.log1p(x)),
]
def run_test(tvm_intrin, np_func):
m = te.var("m",)
Expand All @@ -79,10 +80,57 @@ def run_test(tvm_intrin, np_func):
b.asnumpy(), np_func(a.asnumpy()), atol=1e-5, rtol=1e-5)

for func in test_funcs:
run_test(*func);
run_test(*func)


def test_binary_intrin():
test_funcs = [
(tvm.tir.atan2, lambda x1, x2 : np.arctan2(x1, x2)),
(tvm.tir.nextafter, lambda x1, x2 : np.nextafter(x1, x2)),
(tvm.tir.copysign, lambda x1, x2 : np.copysign(x1, x2)),
(tvm.tir.hypot, lambda x1, x2 : np.hypot(x1, x2)),
]
def run_test(tvm_intrin, np_func):
m = te.var("m",)
A = te.placeholder((m,), name='A')
B = te.placeholder((m,), name='B')
C = te.compute((m,), lambda *i: tvm_intrin(A(*i), B(*i)), name='C')
s = te.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "llvm")
ctx = tvm.cpu(0)
n = 10
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(B.dtype), ctx)
c = tvm.nd.array( \
np.random.uniform(size=n).astype(A.dtype), ctx)
f(a, b, c)
tvm.testing.assert_allclose(
c.asnumpy(), np_func(a.asnumpy(), b.asnumpy()), atol=1e-5, rtol=1e-5)

for func in test_funcs:
run_test(*func)


def test_ldexp():
m = te.var("m",)
A = te.placeholder((m,), name='A')
B = te.placeholder((m,), name='B', dtype="int32")
C = te.compute((m,), lambda *i: tvm.tir.ldexp(A(*i), B(*i)), name='C')
s = te.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "llvm")
ctx = tvm.cpu(0)
n = 10
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.randint(0, 5, size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
f(a, b, c)
tvm.testing.assert_allclose(
c.asnumpy(), np.ldexp(a.asnumpy(), b.asnumpy()), atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
test_nearbyint()
test_unary_intrin()
test_round_intrinsics_on_int()
test_binary_intrin()
test_ldexp()