Skip to content

Commit

Permalink
enhance tir signed-unsigned cast (apache#8706)
Browse files Browse the repository at this point in the history
  • Loading branch information
ganler authored and ylc committed Jan 13, 2022
1 parent 3aa38cf commit 2407c35
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,16 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
(lhs.dtype().is_uint() && rhs.dtype().is_int())) {
// Handle mixing signed and unsigned integers
int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits());
lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span);
rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span);
// if the signed int range is bigger than that of uint, try uint->int
if (lhs.dtype().is_int() && rhs.dtype().bits() <= bits - 1) {
rhs = cast(lhs.dtype(), rhs);
} else if (rhs.dtype().is_int() && lhs.dtype().bits() <= bits - 1) {
lhs = cast(rhs.dtype(), lhs);
} else {
// the ranges of uint and int types conflit, try SimpleCast
lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span);
rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span);
}
} else {
LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
}
Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_tir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def verify_callop_float_only(f):
verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True)
verify_callop_float_only(lambda a, b: te.power(a, b))

# verify bool & int32 constant folding
assert tvm.tir.const(1) == tvm.tir.const(True)
assert tvm.tir.const(2) != tvm.tir.const(True)


def test_if_then_else():
cases = [
Expand Down

0 comments on commit 2407c35

Please sign in to comment.