-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Compute comparison masks in narrower types if possible #7392
Changes from all commits
7d4aa37
90a2652
99d8d4f
00fecd2
5f04a1f
7a2e7f0
a7aa8cc
30144e6
6f97a80
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -897,6 +897,130 @@ class FindIntrinsics : public IRMutator { | |
} | ||
return op; | ||
} | ||
|
||
// Narrow any comparisons between ramps and broadcasts to produce masks that | ||
// match the bit-width of the type being selected between or loaded or | ||
// stored. We do this late in lowering in this pass instead of in the | ||
// simplifier because it messes up the reasoning done by loop partitioning. | ||
// | ||
// For example if we're selecting between uint8s, and we have the condition: | ||
// ramp(x, 1, 16) < broadcast(y) | ||
// where x is an Int(32), we can rewrite this to: | ||
// ramp(0, 1, 16) < broadcast(y - x) | ||
// which can be safely narrowed to: | ||
// cast<int8_t>(ramp(0, 1, 16)) < saturating_cast<int8_t>(broadcast(y - x)) | ||
//
// p is the predicate to narrow; t is the type of the value the predicate
// guards. Or/And/Not nodes are narrowed recursively. Anything that is not an
// int32 LT/LE comparison matching one of the ramp/broadcast patterns below is
// returned unchanged.
Expr narrow_predicate(const Expr &p, Type t) { | ||
// Nothing to narrow when the guarded type is already 32 bits or wider.
if (t.bits() >= 32) { | ||
return p; | ||
} | ||
|
||
int lanes = t.lanes(); | ||
|
||
// Recurse through logical connectives, narrowing each operand separately.
if (const Or *op = p.as<Or>()) { | ||
return narrow_predicate(op->a, t) || narrow_predicate(op->b, t); | ||
} else if (const And *op = p.as<And>()) { | ||
return narrow_predicate(op->a, t) && narrow_predicate(op->b, t); | ||
} else if (const Not *op = p.as<Not>()) { | ||
return !narrow_predicate(op->a, t); | ||
} | ||
|
||
const LT *lt = p.as<LT>(); | ||
const LE *le = p.as<LE>(); | ||
// Check it's a comparison | ||
if (!(le || lt)) { | ||
return p; | ||
} | ||
// Check it's an int32 comparison | ||
if ((lt ? lt->a.type() : le->a.type()) != Int(32, lanes)) { | ||
return p; | ||
} | ||
|
||
// Pattern rewriter over p; when one of the rewrite(...) calls below
// succeeds, the rewritten expression is read back from rewrite.result.
auto rewrite = IRMatcher::rewriter(p, Int(32, lanes)); | ||
|
||
// Construct predicates which state the ramp can't hit the extreme | ||
// values of an int8 or an int16, so that the saturated broadcast has a | ||
// value to take on that leaves it clear of the bounds of the ramp. This | ||
// is an overconservative condition, but it's hard to imagine cases | ||
// where a more precise condition would be necessary. | ||
auto min_ramp_lane = min(c0, c0 * (lanes - 1)); | ||
auto max_ramp_lane = max(c0, c0 * (lanes - 1)); | ||
auto ramp_fits_in_i8 = min_ramp_lane > -128 && max_ramp_lane < 127; | ||
auto ramp_fits_in_i16 = min_ramp_lane > -32768 && max_ramp_lane < 32767; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might want a comment here explaining why these inequalities are strict. It took me a minute to work through why. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the comment immediately above not sufficient? I'm checking they can't hit the extreme values of the narrower type. I think in some cases it's fine, but it's quite hard to think about. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think an explanation of why this condition is necessary would be helpful. Also I think the inequality only needs to be strict for a strict There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comment updated. I found it very hard to think through the cases, and just ended up using the most conservative condition for all of them. |
||
// In the patterns below x is the broadcast value and y is the ramp base, so
// the narrowed broadcast operand is the saturating difference x - y,
// saturated again into the narrow type.
auto saturated_diff_i8 = saturating_cast(Int(8), saturating_sub(x, y)); | ||
auto saturated_diff_i16 = saturating_cast(Int(16), saturating_sub(x, y)); | ||
|
||
// The 8-bit rewrites are only attempted when t has at most 8 bits. The
// 16-bit rewrites are attempted for any t that reached this point (fewer
// than 32 bits), including when the 8-bit rewrites did not match.
if ((t.bits() <= 8 && | ||
// Try to narrow to 8-bit comparisons | ||
(rewrite(broadcast(x, lanes) < ramp(y, c0, lanes), | ||
broadcast(saturated_diff_i8, lanes) < cast(Int(8, lanes), ramp(0, c0, lanes)), | ||
ramp_fits_in_i8) || | ||
|
||
rewrite(ramp(y, c0, lanes) < broadcast(x, lanes), | ||
cast(Int(8, lanes), ramp(0, c0, lanes)) < broadcast(saturated_diff_i8, lanes), | ||
ramp_fits_in_i8) || | ||
|
||
rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes), | ||
broadcast(saturated_diff_i8, lanes) <= cast(Int(8, lanes), ramp(0, c0, lanes)), | ||
ramp_fits_in_i8) || | ||
|
||
rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes), | ||
cast(Int(8, lanes), ramp(0, c0, lanes)) <= broadcast(saturated_diff_i8, lanes), | ||
ramp_fits_in_i8))) || | ||
|
||
// Try to narrow to 16-bit comparisons | ||
rewrite(broadcast(x, lanes) < ramp(y, c0, lanes), | ||
broadcast(saturated_diff_i16, lanes) < cast(Int(16, lanes), ramp(0, c0, lanes)), | ||
ramp_fits_in_i16) || | ||
|
||
rewrite(ramp(y, c0, lanes) < broadcast(x, lanes), | ||
cast(Int(16, lanes), ramp(0, c0, lanes)) < broadcast(saturated_diff_i16, lanes), | ||
ramp_fits_in_i16) || | ||
|
||
rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes), | ||
broadcast(saturated_diff_i16, lanes) <= cast(Int(16, lanes), ramp(0, c0, lanes)), | ||
ramp_fits_in_i16) || | ||
|
||
rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes), | ||
cast(Int(16, lanes), ramp(0, c0, lanes)) <= broadcast(saturated_diff_i16, lanes), | ||
ramp_fits_in_i16)) { | ||
return rewrite.result; | ||
} else { | ||
// No pattern matched: leave the predicate as it was.
return p; | ||
} | ||
} | ||
|
||
// Mutates the operands of a Select, then narrows the mutated condition
// against the select's result type. A new Select node is always constructed,
// even when no operand changed.
Expr visit(const Select *op) override { | ||
Expr condition = mutate(op->condition); | ||
Expr true_value = mutate(op->true_value); | ||
Expr false_value = mutate(op->false_value); | ||
condition = narrow_predicate(condition, op->type); | ||
return Select::make(condition, true_value, false_value); | ||
} | ||
|
||
// Mutates the predicate and index of a Load and narrows the predicate against
// the loaded type. The original node is returned when neither changed.
Expr visit(const Load *op) override { | ||
Expr predicate = mutate(op->predicate); | ||
Expr index = mutate(op->index); | ||
predicate = narrow_predicate(predicate, op->type); | ||
// Reuse the existing node if neither the predicate nor the index changed.
if (predicate.same_as(op->predicate) && index.same_as(op->index)) { | ||
return op; | ||
} else { | ||
return Load::make(op->type, op->name, std::move(index), | ||
op->image, op->param, std::move(predicate), | ||
op->alignment); | ||
} | ||
} | ||
|
||
// Mutates the predicate, value and index of a Store and narrows the predicate
// against the type of the (mutated) stored value. The original node is
// returned when none of the three changed.
Stmt visit(const Store *op) override { | ||
Expr predicate = mutate(op->predicate); | ||
Expr value = mutate(op->value); | ||
Expr index = mutate(op->index); | ||
predicate = narrow_predicate(predicate, value.type()); | ||
// Reuse the existing node if nothing changed.
if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) { | ||
return op; | ||
} else { | ||
return Store::make(op->name, std::move(value), std::move(index), op->param, std::move(predicate), op->alignment); | ||
} | ||
} | ||
}; | ||
|
||
// Substitute in let values that have an output vector | ||
|
@@ -1119,17 +1243,62 @@ Expr lower_rounding_shift_right(const Expr &a, const Expr &b) { | |
} | ||
|
||
// Lowers a saturating add of a and b, which must have the same type.
// NOTE(review): this span is a rendered diff hunk; the first assert and the
// clamp-based return directly below look like the removed lines, with the
// type-dispatching implementation after them being the replacement - confirm
// against the actual file.
Expr lower_saturating_add(const Expr &a, const Expr &b) { | ||
internal_assert(a.type() == b.type()); | ||
// Lower saturating add without using widening arithmetic, which may require | ||
// types that aren't supported. | ||
return simplify(clamp(a, a.type().min() - min(b, 0), a.type().max() - max(b, 0))) + b; | ||
internal_assert(a.type() == b.type()); | ||
if (a.type().is_float()) { | ||
// Floats get a plain add.
return a + b; | ||
} else if (a.type().is_uint()) { | ||
// A sum smaller than a is treated as having wrapped around, and is
// replaced by the type's maximum.
Expr sum = a + b; | ||
return select(sum < a, a.type().max(), sum); | ||
} else if (a.type().is_int()) { | ||
// Do the arithmetic on the operands reinterpreted as unsigned.
Type u = a.type().with_code(halide_type_uint); | ||
Expr ua = reinterpret(u, a); | ||
Expr ub = reinterpret(u, b); | ||
// upper = 2^(bits-1) - 1 and lower = 2^(bits-1), as unsigned constants.
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1); | ||
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1))); | ||
Expr sum = ua + ub; | ||
// For a 32-bit input, 'sum' is the low 32 bits of the true 33-bit | ||
// sum. So it's the true sum, possibly plus 2^32 in the case where the | ||
// true sum is supposed to be negative. The true sum is positive when: | ||
// a + b >= 0 === a >= -b === a >= ~b + 1 === a > ~b | ||
Expr pos_result = min(sum, upper); | ||
Expr neg_result = max(sum, lower); | ||
// Pick the clamped result by the sign test above, then reinterpret back
// to the original signed type.
return simplify(reinterpret(a.type(), select(~b < a, pos_result, neg_result))); | ||
} else { | ||
internal_error << "Bad type for saturating_add: " << a.type() << "\n"; | ||
return Expr(); | ||
} | ||
} | ||
|
||
// Lowers a saturating subtract of b from a; both must have the same type.
// NOTE(review): this span is a rendered diff hunk; the first assert, the
// "saturating add" comment line and the clamp-based return below look like
// the removed lines, with the type-dispatching implementation after them
// being the replacement - confirm against the actual file.
Expr lower_saturating_sub(const Expr &a, const Expr &b) { | ||
internal_assert(a.type() == b.type()); | ||
// Lower saturating add without using widening arithmetic, which may require | ||
// Lower saturating sub without using widening arithmetic, which may require | ||
// types that aren't supported. | ||
return simplify(clamp(a, a.type().min() + max(b, 0), a.type().max() + min(b, 0))) - b; | ||
internal_assert(a.type() == b.type()); | ||
if (a.type().is_float()) { | ||
// Floats get a plain subtract.
return a - b; | ||
} else if (a.type().is_int()) { | ||
// Do the math in unsigned, to avoid overflow in the simplifier. | ||
Type u = a.type().with_code(halide_type_uint); | ||
Expr ua = reinterpret(u, a); | ||
Expr ub = reinterpret(u, b); | ||
// upper = 2^(bits-1) - 1 and lower = 2^(bits-1), as unsigned constants.
Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1); | ||
Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1))); | ||
Expr diff = ua - ub; | ||
// If a >= b, then diff is the (positive) difference. If a < b then diff | ||
// is the (negative) difference plus 2^32 due to wraparound. | ||
// We saturate the positive difference to be at most 2^31 - 1 | ||
Expr pos_diff = min(upper, diff); | ||
// and saturate the negative difference to be at least -2^31 + 2^32 = 2^31 | ||
Expr neg_diff = max(lower, diff); | ||
// Then select between them, and cast back to the signed type. | ||
return simplify(reinterpret(a.type(), select(b <= a, pos_diff, neg_diff))); | ||
} else if (a.type().is_uint()) { | ||
// Unsigned: the difference when b < a, otherwise zero.
return simplify(select(b < a, a - b, make_zero(a.type()))); | ||
} else { | ||
internal_error << "Bad type for saturating_sub: " << a.type() << "\n"; | ||
return Expr(); | ||
} | ||
} | ||
|
||
Expr lower_saturating_cast(const Type &t, const Expr &a) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#include "Halide.h" | ||
|
||
using namespace Halide; | ||
|
||
Var x; | ||
|
||
// Realizes the expression e twice over x - once through a Func vectorized by
// 64 with TailStrategy::GuardWithIf (g1) and once through an unscheduled Func
// (g2) - and checks that the two 1024-element outputs agree. On the first
// mismatch it prints the offending element and exits with status -1.
template<typename T> | ||
void check(const Expr &e) { | ||
Func g1, g2; | ||
g1(x) = e; | ||
g2(x) = e; | ||
|
||
// Introduce some vector predicates to g1 | ||
g1.vectorize(x, 64, TailStrategy::GuardWithIf); | ||
|
||
Buffer<T> b1(1024), b2(1024); | ||
g1.realize(b1); | ||
g2.realize(b2); | ||
|
||
// Compare element by element; g2's output (b2) serves as the reference.
for (int i = 0; i < b1.width(); i++) { | ||
if (b1(i) != b2(i)) { | ||
printf("b1(%d) = %d instead of %d\n", | ||
i, b1(i), b2(i)); | ||
exit(-1); | ||
} | ||
} | ||
} | ||
|
||
// Runs check<T> on a set of expressions built from a compute_root Func
// f(x) = cast<T>(x): a direct call, two selects on x, and each of the
// BoundaryConditions wrappers listed below applied over the range {10, 100}.
template<typename T> | ||
void check_all() { | ||
Func f; | ||
f(x) = cast<T>(x); | ||
f.compute_root(); | ||
|
||
// This will have a predicated instruction in the loop tail: | ||
check<T>(f(x)); | ||
|
||
// These will also have a comparison mask in the loop body: | ||
check<T>(select(x < 50, f(x), cast<T>(17))); | ||
check<T>(select(x > 50, f(x), cast<T>(17))); | ||
|
||
// Also test boundary conditions, which introduce all sorts of coordinate | ||
// comparisons: | ||
check<T>(BoundaryConditions::repeat_edge(f, {{10, 100}})(x)); | ||
check<T>(BoundaryConditions::repeat_image(f, {{10, 100}})(x)); | ||
check<T>(BoundaryConditions::constant_exterior(f, cast<T>(17), {{10, 100}})(x)); | ||
check<T>(BoundaryConditions::mirror_image(f, {{10, 100}})(x)); | ||
check<T>(BoundaryConditions::mirror_interior(f, {{10, 100}})(x)); | ||
} | ||
|
||
// Test entry point: runs every check for 8-bit and 16-bit unsigned element
// types. check() exits the process on a mismatch, so reaching the final
// printf means all comparisons passed.
int main(int argc, char **argv) { | ||
check_all<uint8_t>(); | ||
check_all<uint16_t>(); | ||
|
||
printf("Success!\n"); | ||
return 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Presumably we've normalized away the GT and GE at this point?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, those should be long-gone