From 2565ad5f0cb9a073872ce9bc991397a6a86fa1f2 Mon Sep 17 00:00:00 2001 From: pankratz <35379668+dpankratz@users.noreply.github.com> Date: Mon, 17 Feb 2020 19:48:05 -0700 Subject: [PATCH] Fixed bugs that occured when using bitwise operators on floating point type expressions. Further crash when using ops <<, >>, %. Finally added regression tests for both types of bug. (#4892) --- python/tvm/tir/expr.py | 16 ++++++++++++++++ src/tir/ir/op.cc | 10 ++++++++++ tests/python/unittest/test_lang_basic.py | 22 ++++++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index e36ca2c1dede..aeda603e19aa 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -52,6 +52,11 @@ def _dtype_is_int(value): return (isinstance(value, ExprOp) and DataType(value.dtype).type_code == TypeCode.INT) +def _dtype_is_float(value): + if isinstance(value, float): + return True + return (isinstance(value, ExprOp) and + DataType(value.dtype).type_code == TypeCode.FLOAT) class ExprOp(object): """Operator overloading for Expr like expressions.""" @@ -102,6 +107,9 @@ def __rfloordiv__(self, other): def __mod__(self, other): return _ffi_api._OpFloorMod(self, other) + def __rmod__(self, other): + return _ffi_api._OpFloorMod(other, self) + def __neg__(self): neg_one = const(-1, self.dtype) return self.__mul__(neg_one) @@ -109,9 +117,15 @@ def __neg__(self): def __lshift__(self, other): return _ffi_api.left_shift(self, other) + def __rlshift__(self, other): + return _ffi_api.left_shift(other, self) + def __rshift__(self, other): return _ffi_api.right_shift(self, other) + def __rrshift__(self, other): + return _ffi_api.right_shift(other, self) + def __and__(self, other): return _ffi_api.bitwise_and(self, other) @@ -131,6 +145,8 @@ def __rxor__(self, other): return _ffi_api.bitwise_xor(other, self) def __invert__(self): + if _dtype_is_float(self): + raise RuntimeError("Cannot use ~ operator on float type Expr.") return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) def __lt__(self, other): diff --git a/src/tir/ir/op.cc b/src/tir/ir/op.cc index d046f5d9df4e..58f8b6b76da8 100644 --- a/src/tir/ir/op.cc +++ b/src/tir/ir/op.cc @@ -417,6 +417,8 @@ PrimExpr operator!(PrimExpr a) { } PrimExpr operator>>(PrimExpr a, PrimExpr b) { + CHECK(a.dtype().is_int() || a.dtype().is_uint()); + CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -430,6 +432,8 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) { } PrimExpr operator<<(PrimExpr a, PrimExpr b) { + CHECK(a.dtype().is_int() || a.dtype().is_uint()); + CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -443,6 +447,8 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) { } PrimExpr operator&(PrimExpr a, PrimExpr b) { + CHECK(a.dtype().is_int() || a.dtype().is_uint()); + CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -453,6 +459,8 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) { } PrimExpr operator|(PrimExpr a, PrimExpr b) { + CHECK(a.dtype().is_int() || a.dtype().is_uint()); + CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -463,6 +471,8 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) { } PrimExpr operator^(PrimExpr a, PrimExpr b) { + CHECK(a.dtype().is_int() || a.dtype().is_uint()); + CHECK(b.dtype().is_int() || b.dtype().is_uint()); BinaryOpMatchTypes(a, b); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index 733992595562..3b1431a54d36 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -178,11 +178,32 @@ def test_bitwise(): assert str(10 & x) == 'bitwise_and(10, x)' assert str(10 | x) == 'bitwise_or(10, x)' assert str(10 ^ x) == 'bitwise_xor(10, x)' + assert str(10 >> x) == 'shift_right(10, x)' + assert str(10 << x) == 'shift_left(10, x)' + assert str(10 % x) == 'floormod(10, x)' assert str(~x) == 'bitwise_not(x)' assert(tvm.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.const(1, "int32x2")).dtype == "int32x2" assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2" +def test_float_bitwise(): + t = tvm.const(1.5,dtype='float32') + for test in [lambda lhs, rhs : lhs << rhs, + lambda lhs, rhs : lhs >> rhs, + lambda lhs, rhs : lhs | rhs, + lambda lhs, rhs : lhs ^ rhs, + lambda lhs, rhs : lhs & rhs + ]: + try: + test(t,10.0) + assert False + except tvm.TVMError: + pass + try: + ~t + assert False + except RuntimeError: + pass def test_isnan(): x = tvm.var('x', 'float32') @@ -227,6 +248,7 @@ def test_equality_string_imm(): test_any() test_all() test_bitwise() + test_float_bitwise() test_isnan() test_equality() test_equality_string_imm()