diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index bae34bdd6b05..a440af994202 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -181,6 +181,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; return IntImm(rtype, pa->value % pb->value); } if (pa) { @@ -226,6 +227,7 @@ inline PrimExpr TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { + CHECK_NE(pb->value, 0) << "Divide by zero"; return IntImm(rtype, floormod(pa->value, pb->value)); } if (pa) { diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc index 452c3bbc68a2..2882fea0693b 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/ir/op.cc @@ -606,6 +606,9 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) { } PrimExpr floor(PrimExpr x) { + if (x.dtype().is_int() || x.dtype().is_uint()) { + return x; + } using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value)); @@ -613,6 +616,9 @@ PrimExpr floor(PrimExpr x) { } PrimExpr ceil(PrimExpr x) { + if (x.dtype().is_int() || x.dtype().is_uint()) { + return x; + } using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value)); @@ -620,6 +626,9 @@ PrimExpr ceil(PrimExpr x) { } PrimExpr round(PrimExpr x) { + if (x.dtype().is_int() || x.dtype().is_uint()) { + return x; + } using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); @@ -627,6 +636,9 @@ PrimExpr round(PrimExpr x) { } PrimExpr nearbyint(PrimExpr x) { + if (x.dtype().is_int() || x.dtype().is_uint()) { + return x; + } using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value)); @@ -634,6 +646,9 @@ PrimExpr nearbyint(PrimExpr x) { } PrimExpr trunc(PrimExpr x) { + if (x.dtype().is_int() || x.dtype().is_uint()) { + return x; + } using tir::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index cd532a0db77f..c279194ce522 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -187,14 +187,14 @@ def test_bitwise(): assert(x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" assert(te.var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" + def test_float_bitwise(): t = tvm.tir.const(1.5,dtype='float32') for test in [lambda lhs, rhs : lhs << rhs, lambda lhs, rhs : lhs >> rhs, lambda lhs, rhs : lhs | rhs, lambda lhs, rhs : lhs ^ rhs, - lambda lhs, rhs : lhs & rhs - ]: + lambda lhs, rhs : lhs & rhs]: try: test(t,10.0) assert False @@ -206,6 +206,20 @@ def test_float_bitwise(): except RuntimeError: pass + +def test_divide_by_zero(): + for test in [lambda lhs, rhs : tvm.tir.floormod(lhs,rhs), + lambda lhs, rhs : tvm.tir.floordiv(lhs,rhs), + lambda lhs, rhs : tvm.tir.truncmod(lhs,rhs), + lambda lhs, rhs : tvm.tir.truncdiv(lhs,rhs), + lambda lhs, rhs : tvm.tir.div(lhs,rhs)]: + try: + test(tvm.tir.const(5,'int32'), tvm.tir.const(0,'int32')) + assert False + except tvm.TVMError: + pass + + def test_isnan(): x = te.var('x', 'float32') assert str(tvm.tir.isnan(x)) == 'isnan(x)' @@ -250,6 +264,7 @@ def test_equality_string_imm(): test_all() test_bitwise() test_float_bitwise() + test_divide_by_zero() test_isnan() test_equality() test_equality_string_imm() diff --git a/tests/python/unittest/test_tvm_intrin.py b/tests/python/unittest/test_tvm_intrin.py index 0054273e6210..52ae4408eda6 100644 --- a/tests/python/unittest/test_tvm_intrin.py +++ b/tests/python/unittest/test_tvm_intrin.py @@ -44,6 +44,16 @@ def test_nearbyint(): tvm.testing.assert_allclose( a_rounded.asnumpy(), np.rint(a.asnumpy())) +def test_round_intrinsics_on_int(): + i = tvm.te.var("i", 'int32') + for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil, + tvm.tir.floor, tvm.tir.nearbyint]: + assert op(tvm.tir.const(10,'int32')).value == 10 + assert op(tvm.tir.const(True,'bool')).value == True + assert op(i).same_as(i) + + assert tvm.tir.isnan(tvm.tir.const(10, 'int32')).value == False + def test_unary_intrin(): test_funcs = [ @@ -75,3 +85,4 @@ def run_test(tvm_intrin, np_func): if __name__ == "__main__": test_nearbyint() test_unary_intrin() + test_round_intrinsics_on_int()