Skip to content

Commit

Permalink
Compute comparison masks in narrower types if possible (halide#7392)
Browse files Browse the repository at this point in the history
* Compute comparison masks in narrower types if possible

* Remove reliance on infinite precision int32s

* Further elaborate on comment

* Lower signed saturating_add and sub to unsigned math

The existing lowering was prone to overflow

* cast -> reinterpret
  • Loading branch information
abadams authored and ardier committed Mar 3, 2024
1 parent 08a4226 commit 4d97824
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 6 deletions.
179 changes: 174 additions & 5 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,130 @@ class FindIntrinsics : public IRMutator {
}
return op;
}

// Narrow and comparisons between ramps and broadcasts to produce masks that
// match the bit-width of the type being selected between or loaded or
// stored. We do this late in lowering in this pass instead of in the
// simplifier because it messes up the reasoning done by loop partitioning.
//
// For example if we're selecting between uint8s, and we have the condition:
// ramp(x, 1, 16) < broadcast(y)
// where x is an Int(32), we can rewrite this to:
// ramp(0, 1, 16) < broadcast(y - x)
// which can be safely narrowed to:
// cast<int8_t>(ramp(0, 1, 16)) < saturating_cast<int8_t>(broadcast(y - x))
Expr narrow_predicate(const Expr &p, Type t) {
if (t.bits() >= 32) {
return p;
}

int lanes = t.lanes();

if (const Or *op = p.as<Or>()) {
return narrow_predicate(op->a, t) || narrow_predicate(op->b, t);
} else if (const And *op = p.as<And>()) {
return narrow_predicate(op->a, t) && narrow_predicate(op->b, t);
} else if (const Not *op = p.as<Not>()) {
return !narrow_predicate(op->a, t);
}

const LT *lt = p.as<LT>();
const LE *le = p.as<LE>();
// Check it's a comparison
if (!(le || lt)) {
return p;
}
// Check it's an int32 comparison
if ((lt ? lt->a.type() : le->a.type()) != Int(32, lanes)) {
return p;
}

auto rewrite = IRMatcher::rewriter(p, Int(32, lanes));

// Construct predicates which state the ramp can't hit the extreme
// values of an int8 or an int16, so that the saturated broadcast has a
// value to take on that leaves it clear of the bounds of the ramp. This
// is an overconservative condition, but it's hard to imagine cases
// where a more precise condition would be necessary.
auto min_ramp_lane = min(c0, c0 * (lanes - 1));
auto max_ramp_lane = max(c0, c0 * (lanes - 1));
auto ramp_fits_in_i8 = min_ramp_lane > -128 && max_ramp_lane < 127;
auto ramp_fits_in_i16 = min_ramp_lane > -32768 && max_ramp_lane < 32767;
auto saturated_diff_i8 = saturating_cast(Int(8), saturating_sub(x, y));
auto saturated_diff_i16 = saturating_cast(Int(16), saturating_sub(x, y));

if ((t.bits() <= 8 &&
// Try to narrow to 8-bit comparisons
(rewrite(broadcast(x, lanes) < ramp(y, c0, lanes),
broadcast(saturated_diff_i8, lanes) < cast(Int(8, lanes), ramp(0, c0, lanes)),
ramp_fits_in_i8) ||

rewrite(ramp(y, c0, lanes) < broadcast(x, lanes),
cast(Int(8, lanes), ramp(0, c0, lanes)) < broadcast(saturated_diff_i8, lanes),
ramp_fits_in_i8) ||

rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes),
broadcast(saturated_diff_i8, lanes) <= cast(Int(8, lanes), ramp(0, c0, lanes)),
ramp_fits_in_i8) ||

rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes),
cast(Int(8, lanes), ramp(0, c0, lanes)) <= broadcast(saturated_diff_i8, lanes),
ramp_fits_in_i8))) ||

// Try to narrow to 16-bit comparisons
rewrite(broadcast(x, lanes) < ramp(y, c0, lanes),
broadcast(saturated_diff_i16, lanes) < cast(Int(16, lanes), ramp(0, c0, lanes)),
ramp_fits_in_i16) ||

rewrite(ramp(y, c0, lanes) < broadcast(x, lanes),
cast(Int(16, lanes), ramp(0, c0, lanes)) < broadcast(saturated_diff_i16, lanes),
ramp_fits_in_i16) ||

rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes),
broadcast(saturated_diff_i16, lanes) <= cast(Int(16, lanes), ramp(0, c0, lanes)),
ramp_fits_in_i16) ||

rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes),
cast(Int(16, lanes), ramp(0, c0, lanes)) <= broadcast(saturated_diff_i16, lanes),
ramp_fits_in_i16)) {
return rewrite.result;
} else {
return p;
}
}

Expr visit(const Select *op) override {
Expr condition = mutate(op->condition);
Expr true_value = mutate(op->true_value);
Expr false_value = mutate(op->false_value);
condition = narrow_predicate(condition, op->type);
return Select::make(condition, true_value, false_value);
}

Expr visit(const Load *op) override {
Expr predicate = mutate(op->predicate);
Expr index = mutate(op->index);
predicate = narrow_predicate(predicate, op->type);
if (predicate.same_as(op->predicate) && index.same_as(op->index)) {
return op;
} else {
return Load::make(op->type, op->name, std::move(index),
op->image, op->param, std::move(predicate),
op->alignment);
}
}

Stmt visit(const Store *op) override {
Expr predicate = mutate(op->predicate);
Expr value = mutate(op->value);
Expr index = mutate(op->index);
predicate = narrow_predicate(predicate, value.type());
if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) {
return op;
} else {
return Store::make(op->name, std::move(value), std::move(index), op->param, std::move(predicate), op->alignment);
}
}
};

// Substitute in let values than have an output vector
Expand Down Expand Up @@ -1127,17 +1251,62 @@ Expr lower_rounding_shift_right(const Expr &a, const Expr &b) {
}

Expr lower_saturating_add(const Expr &a, const Expr &b) {
internal_assert(a.type() == b.type());
// Lower saturating add without using widening arithmetic, which may require
// types that aren't supported.
return simplify(clamp(a, a.type().min() - min(b, 0), a.type().max() - max(b, 0))) + b;
internal_assert(a.type() == b.type());
if (a.type().is_float()) {
return a + b;
} else if (a.type().is_uint()) {
Expr sum = a + b;
return select(sum < a, a.type().max(), sum);
} else if (a.type().is_int()) {
Type u = a.type().with_code(halide_type_uint);
Expr ua = reinterpret(u, a);
Expr ub = reinterpret(u, b);
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
Expr sum = ua + ub;
// For a 32-bit input, 'sum' is the low 32 bits of the true 33-bit
// sum. So it's the true sum, possibly plus 2^32 in the case where the
// true sum is supposed to be negative. The true sum is positive when:
// a + b >= 0 === a >= -b === a >= ~b + 1 === a > ~b
Expr pos_result = min(sum, upper);
Expr neg_result = max(sum, lower);
return simplify(reinterpret(a.type(), select(~b < a, pos_result, neg_result)));
} else {
internal_error << "Bad type for saturating_add: " << a.type() << "\n";
return Expr();
}
}

Expr lower_saturating_sub(const Expr &a, const Expr &b) {
internal_assert(a.type() == b.type());
// Lower saturating add without using widening arithmetic, which may require
// Lower saturating sub without using widening arithmetic, which may require
// types that aren't supported.
return simplify(clamp(a, a.type().min() + max(b, 0), a.type().max() + min(b, 0))) - b;
internal_assert(a.type() == b.type());
if (a.type().is_float()) {
return a - b;
} else if (a.type().is_int()) {
// Do the math in unsigned, to avoid overflow in the simplifier.
Type u = a.type().with_code(halide_type_uint);
Expr ua = reinterpret(u, a);
Expr ub = reinterpret(u, b);
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
Expr diff = ua - ub;
// If a >= b, then diff is the (positive) difference. If a < b then diff
// is the (negative) difference plus 2^32 due to wraparound.
// We saturate the positive difference to be at most 2^31 - 1
Expr pos_diff = min(upper, diff);
// and saturate the negative difference to be at least -2^31 + 2^32 = 2^31
Expr neg_diff = max(lower, diff);
// Then select between them, and cast back to the signed type.
return simplify(reinterpret(a.type(), select(b <= a, pos_diff, neg_diff)));
} else if (a.type().is_uint()) {
return simplify(select(b < a, a - b, make_zero(a.type())));
} else {
internal_error << "Bad type for saturating_sub: " << a.type() << "\n";
return Expr();
}
}

Expr lower_saturating_cast(const Type &t, const Expr &a) {
Expand Down
2 changes: 1 addition & 1 deletion src/Simplify_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
}
}

#if (LOG_EXPR_MUTATORIONS || LOG_STMT_MUTATIONS)
#if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS)
static int debug_indent;
#endif

Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ tests(GROUPS correctness
multipass_constraints.cpp
multiple_outputs.cpp
mux.cpp
narrow_predicates.cpp
nested_tail_strategies.cpp
newtons_method.cpp
non_nesting_extern_bounds_query.cpp
Expand Down
57 changes: 57 additions & 0 deletions test/correctness/narrow_predicates.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include "Halide.h"

using namespace Halide;

Var x;

template<typename T>
void check(const Expr &e) {
Func g1, g2;
g1(x) = e;
g2(x) = e;

// Introduce some vector predicates to g1
g1.vectorize(x, 64, TailStrategy::GuardWithIf);

Buffer<T> b1(1024), b2(1024);
g1.realize(b1);
g2.realize(b2);

for (int i = 0; i < b1.width(); i++) {
if (b1(i) != b2(i)) {
printf("b1(%d) = %d instead of %d\n",
i, b1(i), b2(i));
exit(-1);
}
}
}

template<typename T>
void check_all() {
Func f;
f(x) = cast<T>(x);
f.compute_root();

// This will have a predicated instruction in the loop tail:
check<T>(f(x));

// These will also have a comparison mask in the loop body:
check<T>(select(x < 50, f(x), cast<T>(17)));
check<T>(select(x > 50, f(x), cast<T>(17)));

// Also test boundary conditions, which introduce all sorts of coordinate
// comparisons:
check<T>(BoundaryConditions::repeat_edge(f, {{10, 100}})(x));
check<T>(BoundaryConditions::repeat_image(f, {{10, 100}})(x));
check<T>(BoundaryConditions::constant_exterior(f, cast<T>(17), {{10, 100}})(x));
check<T>(BoundaryConditions::mirror_image(f, {{10, 100}})(x));
check<T>(BoundaryConditions::mirror_interior(f, {{10, 100}})(x));
}

int main(int argc, char **argv) {
check_all<uint8_t>();
check_all<uint16_t>();

printf("Success!\n");
return 0;
}

0 comments on commit 4d97824

Please sign in to comment.