Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute comparison masks in narrower types if possible #7392

Merged
merged 9 commits into from
Mar 25, 2023
179 changes: 174 additions & 5 deletions src/FindIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,130 @@ class FindIntrinsics : public IRMutator {
}
return op;
}

// Narrow and comparisons between ramps and broadcasts to produce masks that
// match the bit-width of the type being selected between or loaded or
// stored. We do this late in lowering in this pass instead of in the
// simplifier because it messes up the reasoning done by loop partitioning.
//
// For example if we're selecting between uint8s, and we have the condition:
// ramp(x, 1, 16) < broadcast(y)
// where x is an Int(32), we can rewrite this to:
// ramp(0, 1, 16) < broadcast(y - x)
// which can be safely narrowed to:
// cast<int8_t>(ramp(0, 1, 16)) < saturating_cast<int8_t>(broadcast(y - x))
Expr narrow_predicate(const Expr &p, Type t) {
if (t.bits() >= 32) {
return p;
}

int lanes = t.lanes();

if (const Or *op = p.as<Or>()) {
return narrow_predicate(op->a, t) || narrow_predicate(op->b, t);
} else if (const And *op = p.as<And>()) {
return narrow_predicate(op->a, t) && narrow_predicate(op->b, t);
} else if (const Not *op = p.as<Not>()) {
return !narrow_predicate(op->a, t);
}

const LT *lt = p.as<LT>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably we've normalized away the GT and GE at this point?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, those should be long-gone

const LE *le = p.as<LE>();
// Check it's a comparison
if (!(le || lt)) {
return p;
}
// Check it's an int32 comparison
if ((lt ? lt->a.type() : le->a.type()) != Int(32, lanes)) {
return p;
}

auto rewrite = IRMatcher::rewriter(p, Int(32, lanes));

// Construct predicates which state the ramp can't hit the extreme
// values of an int8 or an int16, so that the saturated broadcast has a
// value to take on that leaves it clear of the bounds of the ramp. This
// is an overconservative condition, but it's hard to imagine cases
// where a more precise condition would be necessary.
auto min_ramp_lane = min(c0, c0 * (lanes - 1));
auto max_ramp_lane = max(c0, c0 * (lanes - 1));
auto ramp_fits_in_i8 = min_ramp_lane > -128 && max_ramp_lane < 127;
auto ramp_fits_in_i16 = min_ramp_lane > -32768 && max_ramp_lane < 32767;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want a comment here explaining why these inequalities are strict. It took me a minute to work through why.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment immediately above not sufficient? I'm checking they can't hit the extreme values of the narrower type. I think in some cases it's fine, but it's quite hard to think about.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think an explanation of why this condition is necessary would be helpful. Also I think the inequality only needs to be strict for a strict <, pretty sure the inequality can be non-strict (is there a word for this?) for a <=

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment updated. I found it very hard to think through the cases, and just ended up using the most conservative condition for all of them.

auto saturated_diff_i8 = saturating_cast(Int(8), saturating_sub(x, y));
auto saturated_diff_i16 = saturating_cast(Int(16), saturating_sub(x, y));

if ((t.bits() <= 8 &&
// Try to narrow to 8-bit comparisons
(rewrite(broadcast(x, lanes) < ramp(y, c0, lanes),
broadcast(saturated_diff_i8, lanes) < cast(Int(8, lanes), ramp(0, c0, lanes)),
ramp_fits_in_i8) ||

rewrite(ramp(y, c0, lanes) < broadcast(x, lanes),
cast(Int(8, lanes), ramp(0, c0, lanes)) < broadcast(saturated_diff_i8, lanes),
ramp_fits_in_i8) ||

rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes),
broadcast(saturated_diff_i8, lanes) <= cast(Int(8, lanes), ramp(0, c0, lanes)),
ramp_fits_in_i8) ||

rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes),
cast(Int(8, lanes), ramp(0, c0, lanes)) <= broadcast(saturated_diff_i8, lanes),
ramp_fits_in_i8))) ||

// Try to narrow to 16-bit comparisons
rewrite(broadcast(x, lanes) < ramp(y, c0, lanes),
broadcast(saturated_diff_i16, lanes) < cast(Int(16, lanes), ramp(0, c0, lanes)),
ramp_fits_in_i16) ||

rewrite(ramp(y, c0, lanes) < broadcast(x, lanes),
cast(Int(16, lanes), ramp(0, c0, lanes)) < broadcast(saturated_diff_i16, lanes),
ramp_fits_in_i16) ||

rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes),
broadcast(saturated_diff_i16, lanes) <= cast(Int(16, lanes), ramp(0, c0, lanes)),
ramp_fits_in_i16) ||

rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes),
cast(Int(16, lanes), ramp(0, c0, lanes)) <= broadcast(saturated_diff_i16, lanes),
ramp_fits_in_i16)) {
return rewrite.result;
} else {
return p;
}
}

Expr visit(const Select *op) override {
Expr condition = mutate(op->condition);
Expr true_value = mutate(op->true_value);
Expr false_value = mutate(op->false_value);
condition = narrow_predicate(condition, op->type);
return Select::make(condition, true_value, false_value);
}

Expr visit(const Load *op) override {
Expr predicate = mutate(op->predicate);
Expr index = mutate(op->index);
predicate = narrow_predicate(predicate, op->type);
if (predicate.same_as(op->predicate) && index.same_as(op->index)) {
return op;
} else {
return Load::make(op->type, op->name, std::move(index),
op->image, op->param, std::move(predicate),
op->alignment);
}
}

Stmt visit(const Store *op) override {
Expr predicate = mutate(op->predicate);
Expr value = mutate(op->value);
Expr index = mutate(op->index);
predicate = narrow_predicate(predicate, value.type());
if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) {
return op;
} else {
return Store::make(op->name, std::move(value), std::move(index), op->param, std::move(predicate), op->alignment);
}
}
};

// Substitute in let values than have an output vector
Expand Down Expand Up @@ -1119,17 +1243,62 @@ Expr lower_rounding_shift_right(const Expr &a, const Expr &b) {
}

Expr lower_saturating_add(const Expr &a, const Expr &b) {
internal_assert(a.type() == b.type());
// Lower saturating add without using widening arithmetic, which may require
// types that aren't supported.
return simplify(clamp(a, a.type().min() - min(b, 0), a.type().max() - max(b, 0))) + b;
internal_assert(a.type() == b.type());
if (a.type().is_float()) {
return a + b;
} else if (a.type().is_uint()) {
Expr sum = a + b;
return select(sum < a, a.type().max(), sum);
} else if (a.type().is_int()) {
Type u = a.type().with_code(halide_type_uint);
Expr ua = reinterpret(u, a);
Expr ub = reinterpret(u, b);
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
Expr sum = ua + ub;
// For a 32-bit input, 'sum' is the low 32 bits of the true 33-bit
// sum. So it's the true sum, possibly plus 2^32 in the case where the
// true sum is supposed to be negative. The true sum is positive when:
// a + b >= 0 === a >= -b === a >= ~b + 1 === a > ~b
Expr pos_result = min(sum, upper);
Expr neg_result = max(sum, lower);
return simplify(reinterpret(a.type(), select(~b < a, pos_result, neg_result)));
} else {
internal_error << "Bad type for saturating_add: " << a.type() << "\n";
return Expr();
}
}

Expr lower_saturating_sub(const Expr &a, const Expr &b) {
internal_assert(a.type() == b.type());
// Lower saturating add without using widening arithmetic, which may require
// Lower saturating sub without using widening arithmetic, which may require
// types that aren't supported.
return simplify(clamp(a, a.type().min() + max(b, 0), a.type().max() + min(b, 0))) - b;
internal_assert(a.type() == b.type());
if (a.type().is_float()) {
return a - b;
} else if (a.type().is_int()) {
// Do the math in unsigned, to avoid overflow in the simplifier.
Type u = a.type().with_code(halide_type_uint);
Expr ua = reinterpret(u, a);
Expr ub = reinterpret(u, b);
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
Expr diff = ua - ub;
// If a >= b, then diff is the (positive) difference. If a < b then diff
// is the (negative) difference plus 2^32 due to wraparound.
// We saturate the positive difference to be at most 2^31 - 1
Expr pos_diff = min(upper, diff);
// and saturate the negative difference to be at least -2^31 + 2^32 = 2^31
Expr neg_diff = max(lower, diff);
// Then select between them, and cast back to the signed type.
return simplify(reinterpret(a.type(), select(b <= a, pos_diff, neg_diff)));
} else if (a.type().is_uint()) {
return simplify(select(b < a, a - b, make_zero(a.type())));
} else {
internal_error << "Bad type for saturating_sub: " << a.type() << "\n";
return Expr();
}
}

Expr lower_saturating_cast(const Type &t, const Expr &a) {
Expand Down
2 changes: 1 addition & 1 deletion src/Simplify_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
}
}

#if (LOG_EXPR_MUTATORIONS || LOG_STMT_MUTATIONS)
#if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS)
static int debug_indent;
#endif

Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ tests(GROUPS correctness
multipass_constraints.cpp
multiple_outputs.cpp
mux.cpp
narrow_predicates.cpp
nested_tail_strategies.cpp
newtons_method.cpp
non_nesting_extern_bounds_query.cpp
Expand Down
57 changes: 57 additions & 0 deletions test/correctness/narrow_predicates.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#include "Halide.h"

using namespace Halide;

Var x;

template<typename T>
void check(const Expr &e) {
Func g1, g2;
g1(x) = e;
g2(x) = e;

// Introduce some vector predicates to g1
g1.vectorize(x, 64, TailStrategy::GuardWithIf);

Buffer<T> b1(1024), b2(1024);
g1.realize(b1);
g2.realize(b2);

for (int i = 0; i < b1.width(); i++) {
if (b1(i) != b2(i)) {
printf("b1(%d) = %d instead of %d\n",
i, b1(i), b2(i));
exit(-1);
}
}
}

template<typename T>
void check_all() {
Func f;
f(x) = cast<T>(x);
f.compute_root();

// This will have a predicated instruction in the loop tail:
check<T>(f(x));

// These will also have a comparison mask in the loop body:
check<T>(select(x < 50, f(x), cast<T>(17)));
check<T>(select(x > 50, f(x), cast<T>(17)));

// Also test boundary conditions, which introduce all sorts of coordinate
// comparisons:
check<T>(BoundaryConditions::repeat_edge(f, {{10, 100}})(x));
check<T>(BoundaryConditions::repeat_image(f, {{10, 100}})(x));
check<T>(BoundaryConditions::constant_exterior(f, cast<T>(17), {{10, 100}})(x));
check<T>(BoundaryConditions::mirror_image(f, {{10, 100}})(x));
check<T>(BoundaryConditions::mirror_interior(f, {{10, 100}})(x));
}

int main(int argc, char **argv) {
check_all<uint8_t>();
check_all<uint16_t>();

printf("Success!\n");
return 0;
}