Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Improved error messages for PrimExpr operator overloads #12638

Merged
merged 1 commit into from
Aug 30, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions src/tir/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -520,27 +520,53 @@ PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) {
return tir::NE(a, b, span);
}

namespace {
void type_check_boolean_args(const PrimExpr& arg, const char* op) {
ICHECK(arg.dtype().is_bool()) << "Expected boolean argument for " << op << ", but received "
<< arg << " of type " << arg.dtype();
}
void type_check_boolean_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
ICHECK(lhs.dtype().is_bool()) << "Expected boolean argument as LHS of " << op << ", but received "
<< lhs << " of type " << lhs.dtype();
ICHECK(rhs.dtype().is_bool()) << "Expected boolean argument as RHS of " << op << ", but received "
<< rhs << " of type " << rhs.dtype();
}

void type_check_integer_args(const PrimExpr& arg, const char* op) {
ICHECK(arg.dtype().is_int() || arg.dtype().is_uint())
<< "Expected integer argument for " << op << ", but received " << arg << " of type "
<< arg.dtype();
}

void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint())
<< "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type "
<< lhs.dtype();
ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint())
<< "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type "
<< rhs.dtype();
}
} // namespace

PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_bool());
ICHECK(b.dtype().is_bool());
type_check_boolean_args(a, b, "&& operator (logical AND)");
PrimExpr ret = arith::TryConstFold<tir::And>(a, b);
if (ret.defined()) return ret;
return tir::And(a, b, span);
}

PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); }
PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_bool());
ICHECK(b.dtype().is_bool());
type_check_boolean_args(a, b, "|| operator (logical OR)");
PrimExpr ret = arith::TryConstFold<tir::Or>(a, b);
if (ret.defined()) return ret;
return tir::Or(a, b, span);
}

PrimExpr operator!(PrimExpr a) { return logical_not(a); }
PrimExpr logical_not(PrimExpr a, Span span) {
ICHECK(a.dtype().is_bool());
type_check_boolean_args(a, "! operator (logical NOT)");
PrimExpr ret = arith::TryConstFold<tir::Not>(a);
if (ret.defined()) return ret;
return tir::Not(a, span);
Expand All @@ -550,8 +576,8 @@ PrimExpr logical_not(PrimExpr a, Span span) {
PrimExpr operator>>(PrimExpr a, PrimExpr b) { return right_shift(a, b); }

PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, ">> operator (right shift)");

BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -573,8 +599,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) {
// shift left
PrimExpr operator<<(PrimExpr a, PrimExpr b) { return left_shift(a, b); }
PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, "<< operator (left shift)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -593,8 +618,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
// bitwise and
PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); }
PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, "& operator (bitwise AND)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -606,8 +630,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
// bitwise_or
PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); }
PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, "| operator (bitwise OR)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -619,8 +642,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
// bitwise_xor
PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
ICHECK(b.dtype().is_int() || b.dtype().is_uint());
type_check_integer_args(a, b, "^ operator (bitwise XOR)");
BinaryOpMatchTypes(a, b, span);
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
Expand All @@ -633,7 +655,7 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }

PrimExpr bitwise_neg(PrimExpr a, Span span) {
ICHECK(a.dtype().is_int() || a.dtype().is_uint());
type_check_integer_args(a, "~ operator (bitwise NOT)");
return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span);
}

Expand Down Expand Up @@ -728,7 +750,7 @@ PrimExpr sum(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span sp
}

PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
ICHECK(source.dtype().is_bool());
type_check_boolean_args(source, "tvm::all");
Var x("x", source.dtype(), span), y("y", source.dtype());
PrimExpr result = tir::And(x, y, span);
PrimExpr identity_element = make_const(source.dtype(), true, span);
Expand All @@ -737,7 +759,7 @@ PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span sp
}

PrimExpr any(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init, Span span) {
ICHECK(source.dtype().is_bool());
type_check_boolean_args(source, "tvm::any");
Var x("x", source.dtype(), span), y("y", source.dtype(), span);
PrimExpr result = tir::Or(x, y, span);
PrimExpr identity_element = make_const(source.dtype(), false, span);
Expand Down