From 16d4da4d61427b292fbc2f8de8c14472b9f36e31 Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Mon, 23 Sep 2019 00:49:10 +0800 Subject: [PATCH] Add operator `isnan` (#3979) * add expr `isnan` * move to intrinsic * doc & add to topi * fix error from ci --- docs/api/python/intrin.rst | 2 ++ docs/api/python/topi.rst | 2 ++ include/tvm/expr_operator.h | 6 ++++ include/tvm/ir.h | 1 + python/tvm/intrin.py | 16 +++++++++ src/api/api_ir.cc | 3 ++ src/codegen/codegen_c.cc | 6 ++++ src/codegen/llvm/codegen_llvm.cc | 4 +++ src/lang/expr_operator.cc | 24 ++++++++++++++ tests/python/unittest/test_lang_basic.py | 22 +++++++++++++ topi/include/topi/elemwise.h | 1 + topi/python/topi/math.py | 17 ++++++++++ topi/tests/python/test_topi_math.py | 41 ++++++++++++++++++++++++ 13 files changed, 145 insertions(+) diff --git a/docs/api/python/intrin.rst b/docs/api/python/intrin.rst index e774bb74bd9a..da8b64243209 100644 --- a/docs/api/python/intrin.rst +++ b/docs/api/python/intrin.rst @@ -36,6 +36,7 @@ tvm.intrin tvm.trunc tvm.round tvm.abs + tvm.isnan .. autofunction:: tvm.call_packed .. autofunction:: tvm.call_pure_intrin @@ -52,3 +53,4 @@ tvm.intrin .. autofunction:: tvm.trunc .. autofunction:: tvm.round .. autofunction:: tvm.abs +.. autofunction:: tvm.isnan diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 3b058031438e..3483668a5b08 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -32,6 +32,7 @@ List of operators topi.trunc topi.round topi.abs + topi.isnan topi.exp topi.tanh topi.log @@ -127,6 +128,7 @@ topi .. autofunction:: topi.trunc .. autofunction:: topi.round .. autofunction:: topi.abs +.. autofunction:: topi.isnan .. autofunction:: topi.exp .. autofunction:: topi.tanh .. 
autofunction:: topi.log diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index f74210d7d8f0..b0e82e7fb50c 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -465,6 +465,12 @@ TVM_DLL Expr pow(Expr x, Expr y); * \return The aboslute value of input data x */ TVM_DLL Expr abs(Expr x); +/*! + * \brief Check if x is NaN. + * \param x The input data + * \return The result expression. + */ +TVM_DLL Expr isnan(Expr x); /*! * \brief sum of of source expression over axis diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 994570e7f9df..079f05f5a7f2 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -574,6 +574,7 @@ class Call : public ExprNode { static constexpr const char* likely = "likely"; static constexpr const char* glsl_texture_store = "glsl_texture_store"; static constexpr const char* prefetch = "prefetch"; + static constexpr const char* isnan = "isnan"; /*! \brief Vectorizable intrinsic list. */ static const char* vectorizable_intrinsics[]; diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 4fe9d18faefc..2a4ebfec135b 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -434,6 +434,22 @@ def round(x): return _make.round(x) +def isnan(x): + """Check if input value is Nan. + + Parameters + ---------- + x : Expr + Input argument. + + Returns + ------- + y : Expr + The result. 
+ """ + return _make.isnan(x) + + def power(x, y): """x power y diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 9e14048dbbfe..b1f9af4f6f75 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -38,6 +38,9 @@ TVM_REGISTER_API("_Var") TVM_REGISTER_API("make.abs") .set_body_typed(tvm::abs); +TVM_REGISTER_API("make.isnan") +.set_body_typed(tvm::isnan); + TVM_REGISTER_API("make.floor") .set_body_typed(tvm::floor); diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc index 043c64702edc..ecf62ab0cfac 100644 --- a/src/codegen/codegen_c.cc +++ b/src/codegen/codegen_c.cc @@ -576,6 +576,12 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) os << " *)(&("; this->PrintExpr(op->args[0], os); os << ")))"; + } else if (op->is_intrinsic(Call::isnan)) { + os << "("; + this->PrintExpr(op->args[0], os); + os << " != "; + this->PrintExpr(op->args[0], os); + os << ")"; } else { if (op->call_type == Call::Intrinsic || op->call_type == Call::PureIntrinsic) { diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index c30ac841437e..3a58963cc6e1 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -746,6 +746,10 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const Call* op) { } else if (op->is_intrinsic(Call::reinterpret)) { llvm::Type * target = LLVMType(op->type); return builder_->CreateBitCast(MakeValue(op->args[0]), target); + } else if (op->is_intrinsic(Call::isnan)) { + // TODO(hgt312): set fast math flag + llvm::Value* a = MakeValue(op->args[0]); + return builder_->CreateFCmpUNO(a, a); } else if (op->is_intrinsic("vectorlow")) { llvm::Value *v = MakeValue(op->args[0]); int l = v->getType()->getVectorNumElements(); diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 50da8a144c45..d7a40c133784 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -424,6 +424,30 @@ Expr abs(Expr x) { } } +Expr isnan(Expr x) { + Type t = 
Bool(x.type().lanes()); + if (x.type().is_int() || x.type().is_uint()) { + return make_const(t, false); + } else if (x.type().is_float()) { + using ir::FloatImm; + const FloatImm* fx = x.as<FloatImm>(); + if (fx) { + return make_const(t, std::isnan(fx->value)); + } + if (x.type().bits() == 16) { + return ir::Call::make(t, ir::Call::isnan, + {cast(Float(32, t.lanes()), std::move(x))}, + ir::Call::PureIntrinsic); + } else { + return ir::Call::make(t, ir::Call::isnan, {x}, ir::Call::PureIntrinsic); + } + } else { + LOG(FATAL) << "Data type " << x.type() + <<" not supported for isnan op. Skipping isnan op..."; + return x; + } +} + Expr sum(Expr source, Array<IterVar> rdom) { Var x("x", source.type()), y("y", source.type()); Expr result = ir::Add::make(x, y); diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index 7df92edcc3dd..8b54ef9534d6 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -17,6 +17,7 @@ import tvm import numpy as np + def test_const(): x = tvm.const(1, "int32") print(x.dtype) @@ -39,6 +40,7 @@ def test_scalar_dtype_inference(): assert tvm.convert(1).dtype == 'int32' assert tvm.convert(1.0).dtype == 'float32' + def test_make(): x = tvm.const(1, "int32") y = tvm.var("x") @@ -46,6 +48,7 @@ def test_make(): assert isinstance(tvm.max(x, y), tvm.expr.Max) assert isinstance(tvm.min(x, y), tvm.expr.Min) + def test_ir(): x = tvm.const(1, "int32") y = tvm.make.IntImm('int32', 1) @@ -53,6 +56,7 @@ def test_ir(): stmt = tvm.make.Evaluate(z) assert isinstance(stmt, tvm.stmt.Evaluate) + def test_ir2(): x = tvm.var("n") a = tvm.var("array", tvm.handle) @@ -60,12 +64,14 @@ def test_ir2(): assert isinstance(st, tvm.stmt.Store) assert(st.buffer_var == a) + def test_let(): x = tvm.var('x') y = tvm.var('y') stmt = tvm.make.LetStmt( x, 10, tvm.make.Evaluate(x + 1)); + def test_cast(): x = tvm.var('x', dtype="float32") y = x.astype("int32") @@ -104,10 +110,12 @@ def test_stmt(): 
tvm.stmt.For.Serial, 0, x) + def test_dir(): x = tvm.var('x') dir(x) + def test_dtype(): x = tvm.var('x') assert x.dtype == 'int32' @@ -158,6 +166,7 @@ def test_all(): '(((%s < %s) && (%s > (%s + 1))) && (%s < (%s*2)))' % ( x.name, y.name, y.name, z.name, x.name, z.name) + def test_bitwise(): x = tvm.var('x') y = tvm.var('y') @@ -172,6 +181,18 @@ def test_bitwise(): assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2" +def test_isnan(): + x = tvm.var('x', 'float32') + assert str(tvm.isnan(x)) == 'isnan(x)' + assert str(tvm.isnan(x).dtype) == 'bool' + y = tvm.var('y', 'float16') + assert str(tvm.isnan(y)) == 'isnan(float32(y))' + z = tvm.var('z', 'int32') + assert str(tvm.isnan(z)) == '(bool)0' + k = tvm.var('k', 'int8x2') + assert str(tvm.isnan(k).dtype) == 'uint1x2' + + def test_equality(): a = tvm.var('a') b = tvm.var('b') @@ -203,5 +224,6 @@ def test_equality_string_imm(): test_any() test_all() test_bitwise() + test_isnan() test_equality() test_equality_string_imm() diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 41112da3ff79..0cfc299c130f 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(abs); TOPI_DECLARE_UNARY_OP(cos); TOPI_DECLARE_UNARY_OP(sin); TOPI_DECLARE_UNARY_OP(atan); +TOPI_DECLARE_UNARY_OP(isnan); /* * \brief Fast_tanh_float implementation from Eigen diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 1611914d0461..84c9da37ff7a 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -243,6 +243,23 @@ def abs(x): return tvm.compute(x.shape, lambda *i: tvm.abs(x(*i))) +@tvm.tag_scope(tag=tag.ELEMWISE) +def isnan(x): + """Check if value of x is NaN, element-wise. + + Parameters + ---------- + x : tvm.Tensor + Input argument. + + Returns + ------- + y : tvm.Tensor + The result. 
+ """ + return tvm.compute(x.shape, lambda *i: tvm.isnan(x(*i))) + + @tvm.tag_scope(tag=tag.ELEMWISE) def round(x): """Round elements of x to nearest integer. diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index f5162e06d47a..660d22ccf2bc 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -72,6 +72,46 @@ def check_device(device): for device in get_all_backend(): check_device(device) + def test_isnan( + low, + high, + shape=(20, 3), + dtype=tvm.float32, + check_round=False, + skip_name_check=False, + ): + m = tvm.var("m") + l = tvm.var("l") + A = tvm.placeholder((m, l), dtype=dtype, name="A") + + B = topi.isnan(A) + assert tuple(B.shape) == tuple(A.shape) + if not skip_name_check: + assert B.op.body[0].name == "isnan" + a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10 + a_np.ravel()[np.random.choice(a_np.size, int(a_np.size * 0.5), replace=False)] = np.nan + # avoid round check too close to boundary + if check_round: + a_np += ((np.fmod(a_np, 1) - 0.5) < 1e-6) * 1e-5 + b_np = np.isnan(a_np) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + foo = tvm.build(s, [A, B], device, name="isnan") + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.zeros_like(b_np), ctx) + foo(a, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) + + for device in get_all_backend(): + check_device(device) + test_apply(topi.floor, "floor", np.floor, -100, 100) test_apply(topi.ceil, "ceil", np.ceil, -100, 100) test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True) @@ -88,6 +128,7 @@ def check_device(device): test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi) test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi) 
test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32") + test_isnan(-100, 100) def test_cast():