Skip to content

Commit

Permalink
[TIR] Change Integer Implicit Conversion Rule to C Standard Way
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
Johnson9009 and junrushao committed Aug 15, 2021
1 parent 994a151 commit b065b7b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 38 deletions.
63 changes: 28 additions & 35 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,6 @@ Type GetType(const PrimExpr& expr) {
return PrimType(dtype);
}

// simple cast that only checks if type matches and cast
inline PrimExpr SimpleCast(const DataType& t, PrimExpr value, Span span) {
if (value.dtype() == t) return value;
return tir::Cast(t, value, span);
}

// LargeUIntImm
PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) {
return tir::Call(
Expand Down Expand Up @@ -113,48 +107,47 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*)
}
if (lhs.dtype() == rhs.dtype()) return;

ltype = lhs.dtype();
rtype = rhs.dtype();
// We keep dtypes conversion to be relatively consistent to reduce the amount code generated by
// operators. This can be helpful for users to find potential type conversion problems. The
// following are exceptions:
if (lhs.dtype().is_float() && rhs.dtype().is_float()) {
if (ltype.is_float() && rtype.is_float()) {
// Given two dissimilar floats, cast the lower bit version to the higher bit version.
// E.g. fp16 + fp32 --> fp32 + fp32
if (lhs.dtype().bits() < rhs.dtype().bits()) {
lhs = cast(rhs.dtype(), lhs);
} else if (lhs.dtype().bits() > rhs.dtype().bits()) {
rhs = cast(lhs.dtype(), rhs);
if (ltype.bits() < rtype.bits()) {
lhs = cast(rtype, lhs);
} else {
rhs = cast(ltype, rhs);
}
} else if (!lhs.dtype().is_float() &&
(rhs.dtype().is_float() ||
datatype::Registry::Global()->GetTypeRegistered(rhs.dtype().code()))) {
} else if (!ltype.is_float() &&
(rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
// Cast int->float when the other operand is a float
lhs = cast(rhs.dtype(), lhs);
} else if ((lhs.dtype().is_float() ||
datatype::Registry::Global()->GetTypeRegistered(lhs.dtype().code())) &&
!rhs.dtype().is_float()) {
lhs = cast(rtype, lhs);
} else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
!rtype.is_float()) {
// Cast int->float when the other operand is a float
rhs = cast(lhs.dtype(), rhs);
} else if ((lhs.dtype().is_int() && rhs.dtype().is_int()) ||
(lhs.dtype().is_uint() && rhs.dtype().is_uint())) {
rhs = cast(ltype, rhs);
} else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
// Promote int to higher bits e.g. int8 + int16 --> int16 + int16
if (lhs.dtype().bits() < rhs.dtype().bits()) {
lhs = cast(rhs.dtype(), lhs);
if (ltype.bits() < rtype.bits()) {
lhs = cast(rtype, lhs);
} else {
rhs = cast(lhs.dtype(), rhs);
rhs = cast(ltype, rhs);
}
} else if ((lhs.dtype().is_int() && rhs.dtype().is_uint()) ||
(lhs.dtype().is_uint() && rhs.dtype().is_int())) {
} else if ((ltype.is_int() && rtype.is_uint()) || (ltype.is_uint() && rtype.is_int())) {
// Handle mixing signed and unsigned integers
int bits = std::max(lhs.dtype().bits(), rhs.dtype().bits());
// if the signed int range is bigger than that of uint, try uint->int
if (lhs.dtype().is_int() && rhs.dtype().bits() <= bits - 1) {
rhs = cast(lhs.dtype(), rhs);
} else if (rhs.dtype().is_int() && lhs.dtype().bits() <= bits - 1) {
lhs = cast(rhs.dtype(), lhs);
if (ltype.bits() < rtype.bits()) {
lhs = cast(rtype, lhs);
} else if (ltype.bits() > rtype.bits()) {
rhs = cast(ltype, rhs);
} else {
// the ranges of uint and int types conflit, try SimpleCast
lhs = SimpleCast(DataType::Int(bits, lhs.dtype().lanes()), lhs, span);
rhs = SimpleCast(DataType::Int(bits, rhs.dtype().lanes()), rhs, span);
// The width of signed and unsigned integers is same.
if (ltype.is_uint()) {
rhs = cast(ltype, rhs);
} else {
lhs = cast(rtype, lhs);
}
}
} else {
LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
Expand Down
7 changes: 4 additions & 3 deletions tests/python/unittest/test_tir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def verify_general_dtype_support(f, is_conditional=False):
[("bool", "int32"), "int32"],
[("int32", "float32"), "float32"],
[("int32", "int64"), "int64"],
[("uint32", "int32"), "int32"],
[("uint32", "int8"), "uint32"],
[("uint32", "int32"), "uint32"],
]
for (lhs_dtype, rhs_dtype), out_dtype in rules:
lhs = te.var("lhs", dtype=lhs_dtype)
Expand Down Expand Up @@ -184,8 +185,8 @@ def test_if_then_else():
[(te.var("cond", dtype="bool"), "bool", "int32"), "int32"],
[(True, "int32", "float32"), "float32"],
[(False, "int32", "int64"), "int64"],
[(te.var("cond", dtype="bool"), "uint32", "int32"), "int32"],
[(te.var("cond", dtype="int32"), "uint32", "int32"), "int32"],
[(te.var("cond", dtype="bool"), "uint32", "int32"), "uint32"],
[(te.var("cond", dtype="int32"), "uint32", "int32"), "uint32"],
]
for (cond, lhs_dtype, rhs_dtype), out_dtype in cases:
lhs = te.var("lhs", dtype=lhs_dtype)
Expand Down

0 comments on commit b065b7b

Please sign in to comment.