diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index aca6d1b50b0e..d29132450227 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -146,8 +146,16 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) (lhs.dtype().is_uint() && rhs.dtype().is_int())) { // Handle mixing signed and unsigned integers int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits()); - lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span); - rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span); + // if the signed int range is bigger than that of uint, try uint->int + if (lhs.dtype().is_int() && rhs.dtype().bits() <= bits - 1) { + rhs = cast(lhs.dtype(), rhs); + } else if (rhs.dtype().is_int() && lhs.dtype().bits() <= bits - 1) { + lhs = cast(rhs.dtype(), lhs); + } else { + // the ranges of uint and int types conflit, try SimpleCast + lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span); + rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span); + } } else { LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype; } diff --git a/tests/python/unittest/test_tir_ops.py b/tests/python/unittest/test_tir_ops.py index 78eab6bdde9f..aeec63abba27 100644 --- a/tests/python/unittest/test_tir_ops.py +++ b/tests/python/unittest/test_tir_ops.py @@ -174,6 +174,10 @@ def verify_callop_float_only(f): verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True) verify_callop_float_only(lambda a, b: te.power(a, b)) + # verify bool & int32 constant folding + assert tvm.tir.const(1) == tvm.tir.const(True) + assert tvm.tir.const(2) != tvm.tir.const(True) + def test_if_then_else(): cases = [