diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 69d1da5e8c1c..b9e0c3c37068 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -520,10 +520,37 @@ PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) { return tir::NE(a, b, span); } +namespace { +void type_check_boolean_args(const PrimExpr& arg, const char* op) { + ICHECK(arg.dtype().is_bool()) << "Expected boolean argument for " << op << ", but received " + << arg << " of type " << arg.dtype(); +} +void type_check_boolean_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) { + ICHECK(lhs.dtype().is_bool()) << "Expected boolean argument as LHS of " << op << ", but received " + << lhs << " of type " << lhs.dtype(); + ICHECK(rhs.dtype().is_bool()) << "Expected boolean argument as RHS of " << op << ", but received " + << rhs << " of type " << rhs.dtype(); +} + +void type_check_integer_args(const PrimExpr& arg, const char* op) { + ICHECK(arg.dtype().is_int() || arg.dtype().is_uint()) + << "Expected integer argument for " << op << ", but received " << arg << " of type " + << arg.dtype(); +} + +void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) { + ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint()) + << "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type " + << lhs.dtype(); + ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint()) + << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type " + << rhs.dtype(); +} +} // namespace + PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); } PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_bool()); - ICHECK(b.dtype().is_bool()); + type_check_boolean_args(a, b, "&& operator (logical AND)"); PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return tir::And(a, b, span); @@ -531,8 +558,7 @@ PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) { PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); } PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_bool()); - ICHECK(b.dtype().is_bool()); + type_check_boolean_args(a, b, "|| operator (logical OR)"); PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; return tir::Or(a, b, span); @@ -540,7 +566,7 @@ PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) { PrimExpr operator!(PrimExpr a) { return logical_not(a); } PrimExpr logical_not(PrimExpr a, Span span) { - ICHECK(a.dtype().is_bool()); + type_check_boolean_args(a, "! operator (logical NOT)"); PrimExpr ret = arith::TryConstFold(a); if (ret.defined()) return ret; return tir::Not(a, span); @@ -550,8 +576,8 @@ PrimExpr logical_not(PrimExpr a, Span span) { PrimExpr operator>>(PrimExpr a, PrimExpr b) { return right_shift(a, b); } PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()); - ICHECK(b.dtype().is_int() || b.dtype().is_uint()); + type_check_integer_args(a, b, ">> operator (right shift)"); + BinaryOpMatchTypes(a, b, span); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -573,8 +599,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) { // shift left PrimExpr operator<<(PrimExpr a, PrimExpr b) { return left_shift(a, b); } PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()); - ICHECK(b.dtype().is_int() || b.dtype().is_uint()); + type_check_integer_args(a, b, "<< operator (left shift)"); BinaryOpMatchTypes(a, b, span); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -593,8 +618,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) { // bitwise and PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); } PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()); - ICHECK(b.dtype().is_int() || b.dtype().is_uint()); + type_check_integer_args(a, b, "& operator (bitwise AND)"); BinaryOpMatchTypes(a, b, span); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -606,8 +630,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) { // bitwise_or PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); } PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()); - ICHECK(b.dtype().is_int() || b.dtype().is_uint()); + type_check_integer_args(a, b, "| operator (bitwise OR)"); BinaryOpMatchTypes(a, b, span); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -619,8 +642,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) { // bitwise_xor PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); } PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()); - ICHECK(b.dtype().is_int() || b.dtype().is_uint()); + type_check_integer_args(a, b, "^ operator (bitwise XOR)"); BinaryOpMatchTypes(a, b, span); TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); @@ -633,7 +655,7 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) { PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); } PrimExpr bitwise_neg(PrimExpr a, Span span) { - ICHECK(a.dtype().is_int() || a.dtype().is_uint()); + type_check_integer_args(a, "~ operator (bitwise NOT)"); return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span); } @@ -728,7 +750,7 @@ PrimExpr sum(PrimExpr source, Array rdom, Array init, Span sp } PrimExpr all(PrimExpr source, Array rdom, Array init, Span span) { - ICHECK(source.dtype().is_bool()); + type_check_boolean_args(source, "tvm::all"); Var x("x", source.dtype(), span), y("y", source.dtype()); PrimExpr result = tir::And(x, y, span); PrimExpr identity_element = make_const(source.dtype(), true, span); @@ -737,7 +759,7 @@ PrimExpr all(PrimExpr source, Array rdom, Array init, Span sp } PrimExpr any(PrimExpr source, Array rdom, Array init, Span span) { - ICHECK(source.dtype().is_bool()); + type_check_boolean_args(source, "tvm::any"); Var x("x", source.dtype(), span), y("y", source.dtype(), span); PrimExpr result = tir::Or(x, y, span); PrimExpr identity_element = make_const(source.dtype(), false, span);