Skip to content

Commit

Permalink
Add operator isnan (#3979)
Browse files Browse the repository at this point in the history
* add expr `isnan`

* move to intrinsic

* doc & add to topi

* fix error from ci
  • Loading branch information
hgt312 authored and tqchen committed Sep 22, 2019
1 parent 88cd1b1 commit 16d4da4
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/intrin.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ tvm.intrin
tvm.trunc
tvm.round
tvm.abs
tvm.isnan

.. autofunction:: tvm.call_packed
.. autofunction:: tvm.call_pure_intrin
Expand All @@ -52,3 +53,4 @@ tvm.intrin
.. autofunction:: tvm.trunc
.. autofunction:: tvm.round
.. autofunction:: tvm.abs
.. autofunction:: tvm.isnan
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ List of operators
topi.trunc
topi.round
topi.abs
topi.isnan
topi.exp
topi.tanh
topi.log
Expand Down Expand Up @@ -127,6 +128,7 @@ topi
.. autofunction:: topi.trunc
.. autofunction:: topi.round
.. autofunction:: topi.abs
.. autofunction:: topi.isnan
.. autofunction:: topi.exp
.. autofunction:: topi.tanh
.. autofunction:: topi.log
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,12 @@ TVM_DLL Expr pow(Expr x, Expr y);
 * \return The absolute value of input data x
*/
TVM_DLL Expr abs(Expr x);
/*!
* \brief Check if x is NaN.
* \param x The input data
* \return The result expression.
*/
TVM_DLL Expr isnan(Expr x);

/*!
* \brief sum of of source expression over axis
Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,7 @@ class Call : public ExprNode {
static constexpr const char* likely = "likely";
static constexpr const char* glsl_texture_store = "glsl_texture_store";
static constexpr const char* prefetch = "prefetch";
static constexpr const char* isnan = "isnan";

/*! \brief Vectorizable intrinsic list. */
static const char* vectorizable_intrinsics[];
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,22 @@ def round(x):
return _make.round(x)


def isnan(x):
    """Check if input value is NaN.

    Parameters
    ----------
    x : Expr
        Input argument.

    Returns
    -------
    y : Expr
        The result (a boolean expression).
    """
    return _make.isnan(x)


def power(x, y):
"""x power y
Expand Down
3 changes: 3 additions & 0 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ TVM_REGISTER_API("_Var")
TVM_REGISTER_API("make.abs")
.set_body_typed(tvm::abs);

TVM_REGISTER_API("make.isnan")
.set_body_typed(tvm::isnan);

TVM_REGISTER_API("make.floor")
.set_body_typed(tvm::floor);

Expand Down
6 changes: 6 additions & 0 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,12 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
os << " *)(&(";
this->PrintExpr(op->args[0], os);
os << ")))";
} else if (op->is_intrinsic(Call::isnan)) {
os << "(";
this->PrintExpr(op->args[0], os);
os << " != ";
this->PrintExpr(op->args[0], os);
os << ")";
} else {
if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
Expand Down
4 changes: 4 additions & 0 deletions src/codegen/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) {
} else if (op->is_intrinsic(Call::reinterpret)) {
llvm::Type * target = LLVMType(op->type);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->is_intrinsic(Call::isnan)) {
// TODO(hgt312): set fast math flag
llvm::Value* a = MakeValue(op->args[0]);
return builder_->CreateFCmpUNO(a, a);
} else if (op->is_intrinsic("vectorlow")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
Expand Down
24 changes: 24 additions & 0 deletions src/lang/expr_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,30 @@ Expr abs(Expr x) {
}
}

/*!
 * \brief Build an expression checking whether x is NaN.
 *
 * Integer inputs can never be NaN, so they fold to constant false.
 * Float immediates are folded at construction time. float16 operands
 * are widened to float32 before the intrinsic call, since backends
 * generally lack a native fp16 isnan.
 *
 * \param x The input expression.
 * \return A boolean expression with the same lane count as x.
 */
Expr isnan(Expr x) {
  Type t = Bool(x.type().lanes());
  if (x.type().is_int() || x.type().is_uint()) {
    // Integers are never NaN.
    return make_const(t, false);
  } else if (x.type().is_float()) {
    using ir::FloatImm;
    // Constant-fold when the operand is a float immediate.
    const FloatImm* fx = x.as<FloatImm>();
    if (fx) {
      return make_const(t, std::isnan(fx->value));
    }
    if (x.type().bits() == 16) {
      // Promote fp16 to fp32: most targets have no fp16 isnan.
      return ir::Call::make(t, ir::Call::isnan,
                            {cast(Float(32, t.lanes()), std::move(x))},
                            ir::Call::PureIntrinsic);
    } else {
      return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic);
    }
  } else {
    // LOG(FATAL) aborts, so do not claim the op is "skipped"; the old
    // message also lacked a space before "not supported".
    LOG(FATAL) << "Data type " << x.type()
               << " not supported for isnan op.";
    return x;  // unreachable; silences missing-return warnings
  }
}

Expr sum(Expr source, Array<IterVar> rdom) {
Var x("x", source.type()), y("y", source.type());
Expr result = ir::Add::make(x, y);
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_lang_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
import numpy as np


def test_const():
x = tvm.const(1, "int32")
print(x.dtype)
Expand All @@ -39,33 +40,38 @@ def test_scalar_dtype_inference():
assert tvm.convert(1).dtype == 'int32'
assert tvm.convert(1.0).dtype == 'float32'


def test_make():
x = tvm.const(1, "int32")
y = tvm.var("x")
z = x + y
assert isinstance(tvm.max(x, y), tvm.expr.Max)
assert isinstance(tvm.min(x, y), tvm.expr.Min)


def test_ir():
x = tvm.const(1, "int32")
y = tvm.make.IntImm('int32', 1)
z = x + y
stmt = tvm.make.Evaluate(z)
assert isinstance(stmt, tvm.stmt.Evaluate)


def test_ir2():
x = tvm.var("n")
a = tvm.var("array", tvm.handle)
st = tvm.make.Store(a, x + 1, 1)
assert isinstance(st, tvm.stmt.Store)
assert(st.buffer_var == a)


def test_let():
x = tvm.var('x')
y = tvm.var('y')
stmt = tvm.make.LetStmt(
x, 10, tvm.make.Evaluate(x + 1));


def test_cast():
x = tvm.var('x', dtype="float32")
y = x.astype("int32")
Expand Down Expand Up @@ -104,10 +110,12 @@ def test_stmt():
tvm.stmt.For.Serial, 0,
x)


def test_dir():
x = tvm.var('x')
dir(x)


def test_dtype():
x = tvm.var('x')
assert x.dtype == 'int32'
Expand Down Expand Up @@ -158,6 +166,7 @@ def test_all():
'(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % (
x.name, y.name, y.name, z.name, x.name, z.name)


def test_bitwise():
x = tvm.var('x')
y = tvm.var('y')
Expand All @@ -172,6 +181,18 @@ def test_bitwise():
assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2"


def test_isnan():
    """isnan lowering: fp32 passthrough, fp16 widening, int folding, vector dtype."""
    fp32 = tvm.var('x', 'float32')
    checked = tvm.isnan(fp32)
    assert str(checked) == 'isnan(x)'
    assert str(checked.dtype) == 'bool'
    fp16 = tvm.var('y', 'float16')
    assert str(tvm.isnan(fp16)) == 'isnan(float32(y))'
    integer = tvm.var('z', 'int32')
    assert str(tvm.isnan(integer)) == '(bool)0'
    vector = tvm.var('k', 'int8x2')
    assert str(tvm.isnan(vector).dtype) == 'uint1x2'


def test_equality():
a = tvm.var('a')
b = tvm.var('b')
Expand Down Expand Up @@ -203,5 +224,6 @@ def test_equality_string_imm():
test_any()
test_all()
test_bitwise()
test_isnan()
test_equality()
test_equality_string_imm()
1 change: 1 addition & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);

/*
* \brief Fast_tanh_float implementation from Eigen
Expand Down
17 changes: 17 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,23 @@ def abs(x):
return tvm.compute(x.shape, lambda *i: tvm.abs(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def isnan(x):
    """Check if value of x is NaN, element-wise.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result (boolean tensor of the same shape as x).
    """
    return tvm.compute(x.shape, lambda *i: tvm.isnan(x(*i)))


@tvm.tag_scope(tag=tag.ELEMWISE)
def round(x):
"""Round elements of x to nearest integer.
Expand Down
41 changes: 41 additions & 0 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,46 @@ def check_device(device):
for device in get_all_backend():
check_device(device)

def test_isnan(
        low,
        high,
        shape=(20, 3),
        dtype=tvm.float32,
        check_round=False,
        skip_name_check=False,
):
    """End-to-end check of topi.isnan against np.isnan on all backends.

    NOTE(review): the signature mirrors test_apply's; check_round and
    skip_name_check appear kept only for that symmetry — the sole call
    site below (test_isnan(-100, 100)) leaves them at their defaults.
    """
    # Symbolic extents; concrete sizes are bound from the input array at run time.
    m = tvm.var("m")
    l = tvm.var("l")
    A = tvm.placeholder((m, l), dtype=dtype, name="A")

    B = topi.isnan(A)
    assert tuple(B.shape) == tuple(A.shape)
    if not skip_name_check:
        assert B.op.body[0].name == "isnan"
    a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
    # Poison half of the elements with NaN so both outcomes are exercised.
    a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan
    # avoid round check too close to boundary
    if check_round:
        a_np += ((np.fmod(a_np, 1) - 0.5) < 1e-6) * 1e-5
    b_np = np.isnan(a_np)

    def check_device(device):
        # Build and run on one target; silently skip targets not enabled in this build.
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)
        foo = tvm.build(s, [A, B], device, name="isnan")
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros_like(b_np), ctx)
        foo(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

    for device in get_all_backend():
        check_device(device)

test_apply(topi.floor, "floor", np.floor, -100, 100)
test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
Expand All @@ -88,6 +128,7 @@ def check_device(device):
test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
test_isnan(-100, 100)


def test_cast():
Expand Down

0 comments on commit 16d4da4

Please sign in to comment.