Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't introduce reinterprets in find/lower intrinsics #7776

Merged
merged 1 commit into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 27 additions & 24 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,19 @@ Expr to_rounding_shift(const Call *c) {
return rounding_shift(cast(add->type, add->args[0]), b);
}
}
// Also need to handle the annoying case of a reinterpret wrapping a widen_right_add

// Also need to handle the annoying case of a reinterpret cast wrapping a widen_right_add
// TODO: this pattern makes me want to change the semantics of this op.
if (const Reinterpret *reinterp = a.as<Reinterpret>()) {
if (reinterp->type.bits() == reinterp->value.type().bits()) {
if (const Call *add = Call::as_intrinsic(reinterp->value, {Call::widen_right_add})) {
if (const Cast *cast = a.as<Cast>()) {
if (cast->is_reinterpret()) {
if (const Call *add = Call::as_intrinsic(cast->value, {Call::widen_right_add})) {
if (can_prove(lower_intrinsics(add->args[1] == round))) {
// We expect the first operand to be a reinterpet.
const Reinterpret *reinterp_a = add->args[0].as<Reinterpret>();
internal_assert(reinterp_a) << "Failed: " << add->args[0] << "\n";
return rounding_shift(reinterp_a->value, b);
// We expect the first operand to be a reinterpet cast.
if (const Cast *cast_a = add->args[0].as<Cast>()) {
if (cast_a->is_reinterpret()) {
return rounding_shift(cast_a->value, b);
}
}
}
}
}
Expand Down Expand Up @@ -245,9 +248,9 @@ class FindIntrinsics : public IRMutator {
if (b.type().code() != narrow_a.type().code()) {
// Need to do a safe reinterpret.
Type t = b.type().with_code(code);
result = widen_right_add(reinterpret(t, b), narrow_a);
result = widen_right_add(cast(t, b), narrow_a);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_add(b, narrow_a);
}
Expand All @@ -258,9 +261,9 @@ class FindIntrinsics : public IRMutator {
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_add(reinterpret(t, a), narrow_b);
result = widen_right_add(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_add(a, narrow_b);
}
Expand Down Expand Up @@ -328,9 +331,9 @@ class FindIntrinsics : public IRMutator {
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_sub(reinterpret(t, a), narrow_b);
result = widen_right_sub(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_sub(a, narrow_b);
}
Expand Down Expand Up @@ -410,9 +413,9 @@ class FindIntrinsics : public IRMutator {
if (b.type().code() != narrow_a.type().code()) {
// Need to do a safe reinterpret.
Type t = b.type().with_code(code);
result = widen_right_mul(reinterpret(t, b), narrow_a);
result = widen_right_mul(cast(t, b), narrow_a);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_mul(b, narrow_a);
}
Expand All @@ -423,9 +426,9 @@ class FindIntrinsics : public IRMutator {
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_mul(reinterpret(t, a), narrow_b);
result = widen_right_mul(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_mul(a, narrow_b);
}
Expand Down Expand Up @@ -1261,8 +1264,8 @@ Expr lower_saturating_add(const Expr &a, const Expr &b) {
return select(sum < a, a.type().max(), sum);
} else if (a.type().is_int()) {
Type u = a.type().with_code(halide_type_uint);
Expr ua = reinterpret(u, a);
Expr ub = reinterpret(u, b);
Expr ua = cast(u, a);
Expr ub = cast(u, b);
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
Expr sum = ua + ub;
Expand All @@ -1272,7 +1275,7 @@ Expr lower_saturating_add(const Expr &a, const Expr &b) {
// a + b >= 0 === a >= -b === a >= ~b + 1 === a > ~b
Expr pos_result = min(sum, upper);
Expr neg_result = max(sum, lower);
return simplify(reinterpret(a.type(), select(~b < a, pos_result, neg_result)));
return simplify(cast(a.type(), select(~b < a, pos_result, neg_result)));
} else {
internal_error << "Bad type for saturating_add: " << a.type() << "\n";
return Expr();
Expand All @@ -1288,8 +1291,8 @@ Expr lower_saturating_sub(const Expr &a, const Expr &b) {
} else if (a.type().is_int()) {
// Do the math in unsigned, to avoid overflow in the simplifier.
Type u = a.type().with_code(halide_type_uint);
Expr ua = reinterpret(u, a);
Expr ub = reinterpret(u, b);
Expr ua = cast(u, a);
Expr ub = cast(u, b);
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
Expr diff = ua - ub;
Expand All @@ -1300,7 +1303,7 @@ Expr lower_saturating_sub(const Expr &a, const Expr &b) {
// and saturate the negative difference to be at least -2^31 + 2^32 = 2^31
Expr neg_diff = max(lower, diff);
// Then select between them, and cast back to the signed type.
return simplify(reinterpret(a.type(), select(b <= a, pos_diff, neg_diff)));
return simplify(cast(a.type(), select(b <= a, pos_diff, neg_diff)));
} else if (a.type().is_uint()) {
return simplify(select(b < a, a - b, make_zero(a.type())));
} else {
Expand Down
7 changes: 7 additions & 0 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ struct Cast : public ExprNode<Cast> {
static Expr make(Type t, Expr v);

static const IRNodeType _node_type = IRNodeType::Cast;

/** Check if the cast is equivalent to a reinterpret. */
bool is_reinterpret() const {
return (type.is_int_or_uint() &&
value.type().is_int_or_uint() &&
type.bits() == value.type().bits());
}
};

/** Reinterpret value as another type, without affecting any of the bits
Expand Down