Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TIR] Improved error messages for PrimExpr operator overloads (apache…
Browse files Browse the repository at this point in the history
…#12638)

Previously, type-checks in boolean operators on `PrimExpr` would
state that the type is incorrect, but further investigation would be
required in order to determine what expression caused the error.
After this commit, error messages for these type checks include the
expression that was used, and the dtype of that expression.
  • Loading branch information
Lunderberg authored and xinetzone committed Nov 25, 2022
1 parent 4e20021 commit cb1617d
Showing 1 changed file with 40 additions and 18 deletions.
58 changes: 40 additions & 18 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -520,27 +520,53 @@ PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) {
return tir::NE(a, b, span);
}

namespace {
void type_check_boolean_args(const PrimExpr& arg, const char* op) {
ICHECK(arg.dtype().is_bool()) << "Expected boolean argument for " << op << ", but received "
<< arg << " of type " << arg.dtype();
}
void type_check_boolean_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
ICHECK(lhs.dtype().is_bool()) << "Expected boolean argument as LHS of " << op << ", but received "
<< lhs << " of type " << lhs.dtype();
ICHECK(rhs.dtype().is_bool()) << "Expected boolean argument as RHS of " << op << ", but received "
<< rhs << " of type " << rhs.dtype();
}

void type_check_integer_args(const PrimExpr& arg, const char* op) {
ICHECK(arg.dtype().is_int() || arg.dtype().is_uint())
<< "Expected integer argument for " << op << ", but received " << arg << " of type "
<< arg.dtype();
}

void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint())
<< "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type "
<< lhs.dtype();
ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint())
<< "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type "
<< rhs.dtype();
}
} // namespace

PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_bool());
ICHECK(b.dtype().is_bool());
type_check_boolean_args(a, b, "&& operator (logical AND)");
PrimExpr ret = arith::TryConstFold<tir::And>(a, b);
if (ret.defined()) return ret;
return tir::And(a, b, span);
}

PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); }
PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_bool());
ICHECK(b.dtype().is_bool());
type_check_boolean_args(a, b, "|| operator (logical OR)");
PrimExpr ret = arith::TryConstFold<tir::Or>(a, b);
if (ret.defined()) return ret;
return tir::Or(a, b, span);
}

PrimExpr operator!(PrimExpr a) { return logical_not(a); }
PrimExpr logical_not(PrimExpr a, Span span) {
ICHECK(a.dtype().is_bool());
type_check_boolean_args(a, "! operator (logical NOT)");
PrimExpr ret = arith::TryConstFold<tir::Not>(a);
if (ret.defined()) return ret;
return tir::Not(a, span);
Expand All @@ -550,8 +576,8 @@ PrimExpr logical_not(PrimExpr a, Span span) {
PrimExpr operator>>(PrimExpr a, PrimExpr b) { return right_shift(a, b); }

PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, ">> operator (right shift)");

BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -573,8 +599,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) {
// shift left
PrimExpr operator<<(PrimExpr a, PrimExpr b) { return left_shift(a, b); }
PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, "<< operator (left shift)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -593,8 +618,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
// bitwise and
PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); }
PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, "& operator (bitwise AND)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -606,8 +630,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
// bitwise_or
PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); }
PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, "| operator (bitwise OR)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -619,8 +642,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
// bitwise_xor
PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, "^ operator (bitwise XOR)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -633,7 +655,7 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }

PrimExpr bitwise_neg(PrimExpr a, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
type_check_integer_args(a, "~ operator (bitwise NOT)");
return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span);
}

Expand Down Expand Up @@ -728,7 +750,7 @@ PrimExpr sum(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span sp
}

PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
ICHECK(source.dtype().is_bool());
type_check_boolean_args(source, "tvm::all");
Var x("x", source.dtype(), span), y("y", source.dtype());
PrimExpr result = tir::And(x, y, span);
PrimExpr identity_element = make_const(source.dtype(), true, span);
Expand All @@ -737,7 +759,7 @@ PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span sp
}

PrimExpr any(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
ICHECK(source.dtype().is_bool());
type_check_boolean_args(source, "tvm::any");
Var x("x", source.dtype(), span), y("y", source.dtype(), span);
PrimExpr result = tir::Or(x, y, span);
PrimExpr identity_element = make_const(source.dtype(), false, span);
Expand Down

0 comments on commit cb1617d

Please sign in to comment.