Skip to content

Commit

Permalink
Fixed bugs that occured when using bitwise operators on floating poin…
Browse files Browse the repository at this point in the history
…t type expressions. Further crash when using ops <<, >>, %. Finally added regression tests for both types of bug. (apache#4892)
  • Loading branch information
dpankratz authored and zhiics committed Mar 2, 2020
1 parent f4a6b65 commit 2565ad5
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
16 changes: 16 additions & 0 deletions python/tvm/tir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ def _dtype_is_int(value):
return (isinstance(value, ExprOp) and
DataType(value.dtype).type_code == TypeCode.INT)

def _dtype_is_float(value):
if isinstance(value, float):
return True
return (isinstance(value, ExprOp) and
DataType(value.dtype).type_code == TypeCode.FLOAT)

class ExprOp(object):
"""Operator overloading for Expr like expressions."""
Expand Down Expand Up @@ -102,16 +107,25 @@ def __rfloordiv__(self, other):
def __mod__(self, other):
return _ffi_api._OpFloorMod(self, other)

def __rmod__(self, other):
return _ffi_api._OpFloorMod(other, self)

def __neg__(self):
neg_one = const(-1, self.dtype)
return self.__mul__(neg_one)

def __lshift__(self, other):
return _ffi_api.left_shift(self, other)

def __rlshift__(self, other):
return _ffi_api.left_shift(other, self)

def __rshift__(self, other):
return _ffi_api.right_shift(self, other)

def __rrshift__(self, other):
return _ffi_api.right_shift(other, self)

def __and__(self, other):
return _ffi_api.bitwise_and(self, other)

Expand All @@ -131,6 +145,8 @@ def __rxor__(self, other):
return _ffi_api.bitwise_xor(other, self)

def __invert__(self):
if _dtype_is_float(self):
raise RuntimeError("Cannot use ~ operator on float type Expr.")
return _ffi_api.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0)

def __lt__(self, other):
Expand Down
10 changes: 10 additions & 0 deletions src/tir/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ PrimExpr operator!(PrimExpr a) {
}

PrimExpr operator>>(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -430,6 +432,8 @@ PrimExpr operator>>(PrimExpr a, PrimExpr b) {
}

PrimExpr operator<<(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -443,6 +447,8 @@ PrimExpr operator<<(PrimExpr a, PrimExpr b) {
}

PrimExpr operator&(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -453,6 +459,8 @@ PrimExpr operator&(PrimExpr a, PrimExpr b) {
}

PrimExpr operator|(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -463,6 +471,8 @@ PrimExpr operator|(PrimExpr a, PrimExpr b) {
}

PrimExpr operator^(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand Down
22 changes: 22 additions & 0 deletions tests/python/unittest/test_lang_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,32 @@ def test_bitwise():
assert str(10 & x) == 'bitwise_and(10, x)'
assert str(10 | x) == 'bitwise_or(10, x)'
assert str(10 ^ x) == 'bitwise_xor(10, x)'
assert str(10 >> x) == 'shift_right(10, x)'
assert str(10 << x) == 'shift_left(10, x)'
assert str(10 % x) == 'floormod(10, x)'
assert str(~x) == 'bitwise_not(x)'
assert(tvm.const(1, "int8x2") >> 1).dtype == "int8x2"
assert(x >> tvm.const(1, "int32x2")).dtype == "int32x2"
assert(tvm.var("z", "int8x2") << tvm.const(1, "int8x2")).dtype == "int8x2"

def test_float_bitwise():
t = tvm.const(1.5,dtype='float32')
for test in [lambda lhs, rhs : lhs << rhs,
lambda lhs, rhs : lhs >> rhs,
lambda lhs, rhs : lhs | rhs,
lambda lhs, rhs : lhs ^ rhs,
lambda lhs, rhs : lhs & rhs
]:
try:
test(t,10.0)
assert False
except tvm.TVMError:
pass
try:
~t
assert False
except RuntimeError:
pass

def test_isnan():
x = tvm.var('x', 'float32')
Expand Down Expand Up @@ -227,6 +248,7 @@ def test_equality_string_imm():
test_any()
test_all()
test_bitwise()
test_float_bitwise()
test_isnan()
test_equality()
test_equality_string_imm()

0 comments on commit 2565ad5

Please sign in to comment.