Skip to content

Commit

Permalink
Don't introduce reinterprets in find/lower intrinsics (halide#7776)
Browse files Browse the repository at this point in the history
  • Loading branch information
rootjalex authored and ardier committed Mar 3, 2024
1 parent 6ec7cf0 commit da36e88
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 24 deletions.
51 changes: 27 additions & 24 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,19 @@ Expr to_rounding_shift(const Call *c) {
return rounding_shift(cast(add->type, add->args[0]), b);
}
}
// Also need to handle the annoying case of a reinterpret wrapping a widen_right_add

// Also need to handle the annoying case of a reinterpret cast wrapping a widen_right_add
// TODO: this pattern makes me want to change the semantics of this op.
if (const Reinterpret *reinterp = a.as<Reinterpret>()) {
if (reinterp->type.bits() == reinterp->value.type().bits()) {
if (const Call *add = Call::as_intrinsic(reinterp->value, {Call::widen_right_add})) {
if (const Cast *cast = a.as<Cast>()) {
if (cast->is_reinterpret()) {
if (const Call *add = Call::as_intrinsic(cast->value, {Call::widen_right_add})) {
if (can_prove(lower_intrinsics(add->args[1] == round))) {
// We expect the first operand to be a reinterpet.
const Reinterpret *reinterp_a = add->args[0].as<Reinterpret>();
internal_assert(reinterp_a) << "Failed: " << add->args[0] << "\n";
return rounding_shift(reinterp_a->value, b);
// We expect the first operand to be a reinterpet cast.
if (const Cast *cast_a = add->args[0].as<Cast>()) {
if (cast_a->is_reinterpret()) {
return rounding_shift(cast_a->value, b);
}
}
}
}
}
Expand Down Expand Up @@ -245,9 +248,9 @@ class FindIntrinsics : public IRMutator {
if (b.type().code() != narrow_a.type().code()) {
// Need to do a safe reinterpret.
Type t = b.type().with_code(code);
result = widen_right_add(reinterpret(t, b), narrow_a);
result = widen_right_add(cast(t, b), narrow_a);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_add(b, narrow_a);
}
Expand All @@ -258,9 +261,9 @@ class FindIntrinsics : public IRMutator {
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_add(reinterpret(t, a), narrow_b);
result = widen_right_add(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_add(a, narrow_b);
}
Expand Down Expand Up @@ -328,9 +331,9 @@ class FindIntrinsics : public IRMutator {
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_sub(reinterpret(t, a), narrow_b);
result = widen_right_sub(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_sub(a, narrow_b);
}
Expand Down Expand Up @@ -410,9 +413,9 @@ class FindIntrinsics : public IRMutator {
if (b.type().code() != narrow_a.type().code()) {
// Need to do a safe reinterpret.
Type t = b.type().with_code(code);
result = widen_right_mul(reinterpret(t, b), narrow_a);
result = widen_right_mul(cast(t, b), narrow_a);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_mul(b, narrow_a);
}
Expand All @@ -423,9 +426,9 @@ class FindIntrinsics : public IRMutator {
if (a.type().code() != narrow_b.type().code()) {
// Need to do a safe reinterpret.
Type t = a.type().with_code(code);
result = widen_right_mul(reinterpret(t, a), narrow_b);
result = widen_right_mul(cast(t, a), narrow_b);
internal_assert(result.type() != op->type);
result = reinterpret(op->type, result);
result = cast(op->type, result);
} else {
result = widen_right_mul(a, narrow_b);
}
Expand Down Expand Up @@ -1261,8 +1264,8 @@ Expr lower_saturating_add(const Expr &a, const Expr &b) {
return select(sum < a, a.type().max(), sum);
} else if (a.type().is_int()) {
Type u = a.type().with_code(halide_type_uint);
Expr ua = reinterpret(u, a);
Expr ub = reinterpret(u, b);
Expr ua = cast(u, a);
Expr ub = cast(u, b);
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
Expr sum = ua + ub;
Expand All @@ -1272,7 +1275,7 @@ Expr lower_saturating_add(const Expr &a, const Expr &b) {
// a + b >= 0 === a >= -b === a >= ~b + 1 === a > ~b
Expr pos_result = min(sum, upper);
Expr neg_result = max(sum, lower);
return simplify(reinterpret(a.type(), select(~b < a, pos_result, neg_result)));
return simplify(cast(a.type(), select(~b < a, pos_result, neg_result)));
} else {
internal_error << "Bad type for saturating_add: " << a.type() << "\n";
return Expr();
Expand All @@ -1288,8 +1291,8 @@ Expr lower_saturating_sub(const Expr &a, const Expr &b) {
} else if (a.type().is_int()) {
// Do the math in unsigned, to avoid overflow in the simplifier.
Type u = a.type().with_code(halide_type_uint);
Expr ua = reinterpret(u, a);
Expr ub = reinterpret(u, b);
Expr ua = cast(u, a);
Expr ub = cast(u, b);
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
Expr diff = ua - ub;
Expand All @@ -1300,7 +1303,7 @@ Expr lower_saturating_sub(const Expr &a, const Expr &b) {
// and saturate the negative difference to be at least -2^31 + 2^32 = 2^31
Expr neg_diff = max(lower, diff);
// Then select between them, and cast back to the signed type.
return simplify(reinterpret(a.type(), select(b <= a, pos_diff, neg_diff)));
return simplify(cast(a.type(), select(b <= a, pos_diff, neg_diff)));
} else if (a.type().is_uint()) {
return simplify(select(b < a, a - b, make_zero(a.type())));
} else {
Expand Down
7 changes: 7 additions & 0 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ struct Cast : public ExprNode<Cast> {
static Expr make(Type t, Expr v);

static const IRNodeType _node_type = IRNodeType::Cast;

/** Check if the cast is equivalent to a reinterpret. */
bool is_reinterpret() const {
return (type.is_int_or_uint() &&
value.type().is_int_or_uint() &&
type.bits() == value.type().bits());
}
};

/** Reinterpret value as another type, without affecting any of the bits
Expand Down

0 comments on commit da36e88

Please sign in to comment.