From 7d4aa37b749207543067d0f556eb0ddf5bd58aee Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Fri, 3 Mar 2023 10:05:00 -0800
Subject: [PATCH 1/5] Compute comparison masks in narrower types if possible

---
 src/FindIntrinsics.cpp                 | 121 +++++++++++++++++++++++++
 test/correctness/CMakeLists.txt        |   1 +
 test/correctness/narrow_predicates.cpp |  57 ++++++++++++
 3 files changed, 179 insertions(+)
 create mode 100644 test/correctness/narrow_predicates.cpp

diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp
index 9700d06e319b..eb27134d0164 100644
--- a/src/FindIntrinsics.cpp
+++ b/src/FindIntrinsics.cpp
@@ -897,6 +897,127 @@ class FindIntrinsics : public IRMutator {
         }
         return op;
     }
+
+    // Narrow any comparisons between ramps and broadcasts to produce masks
+    // that match the bit-width of the type being selected between, loaded, or
+    // stored. We do this late in lowering, in this pass rather than in the
+    // simplifier, because the rewrite messes up the reasoning done by loop
+    // partitioning.
+    //
+    // For example, if we're selecting between uint8s and we have the condition:
+    //   ramp(x, 1, 16) < broadcast(y)
+    // where x is an Int(32), we can rewrite this to:
+    //   ramp(0, 1, 16) < broadcast(y - x)
+    // which can be safely narrowed to:
+    //   cast<int8_t>(ramp(0, 1, 16)) < saturating_cast<int8_t>(broadcast(y - x))
+    Expr narrow_predicate(const Expr &p, Type t) {
+        if (t.bits() >= 32) {
+            return p;
+        }
+
+        int lanes = t.lanes();
+
+        if (const Or *op = p.as<Or>()) {
+            return narrow_predicate(op->a, t) || narrow_predicate(op->b, t);
+        } else if (const And *op = p.as<And>()) {
+            return narrow_predicate(op->a, t) && narrow_predicate(op->b, t);
+        } else if (const Not *op = p.as<Not>()) {
+            return !narrow_predicate(op->a, t);
+        }
+
+        const LT *lt = p.as<LT>();
+        const LE *le = p.as<LE>();
+        // Check it's a comparison
+        if (!(le || lt)) {
+            return p;
+        }
+        // Check it's an int32 comparison
+        if ((lt ? lt->a.type() : le->a.type()) != Int(32, lanes)) {
+            return p;
+        }
+
+        auto rewrite = IRMatcher::rewriter(p, Int(32, lanes));
+
+        // Construct predicates which state the ramp can't hit the extreme
+        // values of an int8 or an int16. This is an overconservative condition,
+        // but it's hard to imagine cases where a more precise condition would
+        // be necessary.
+        auto min_ramp_lane = min(c0, c0 * (lanes - 1));
+        auto max_ramp_lane = max(c0, c0 * (lanes - 1));
+        auto ramp_fits_in_i8 = min_ramp_lane > -128 && max_ramp_lane < 127;
+        auto ramp_fits_in_i16 = min_ramp_lane > -32768 && max_ramp_lane < 32767;
+
+        if ((t.bits() <= 8 &&
+             // Try to narrow to 8-bit comparisons
+             (rewrite(broadcast(x, lanes) < ramp(y, c0, lanes),
+                      broadcast(saturating_cast(Int(8), x - y), lanes) < cast(Int(8, lanes), ramp(0, c0, lanes)),
+                      ramp_fits_in_i8) ||
+
+              rewrite(ramp(y, c0, lanes) < broadcast(x, lanes),
+                      cast(Int(8, lanes), ramp(0, c0, lanes)) < broadcast(saturating_cast(Int(8), x - y), lanes),
+                      ramp_fits_in_i8) ||
+
+              rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes),
+                      broadcast(saturating_cast(Int(8), x - y), lanes) <= cast(Int(8, lanes), ramp(0, c0, lanes)),
+                      ramp_fits_in_i8) ||
+
+              rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes),
+                      cast(Int(8, lanes), ramp(0, c0, lanes)) <= broadcast(saturating_cast(Int(8), x - y), lanes),
+                      ramp_fits_in_i8))) ||
+
+            // Try to narrow to 16-bit comparisons
+            rewrite(broadcast(x, lanes) < ramp(y, c0, lanes),
+                    broadcast(saturating_cast(Int(16), x - y), lanes) < cast(Int(16, lanes), ramp(0, c0, lanes)),
+                    ramp_fits_in_i16) ||
+
+            rewrite(ramp(y, c0, lanes) < broadcast(x, lanes),
+                    cast(Int(16, lanes), ramp(0, c0, lanes)) < broadcast(saturating_cast(Int(16), x - y), lanes),
+                    ramp_fits_in_i16) ||
+
+            rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes),
+                    broadcast(saturating_cast(Int(16), x - y), lanes) <= cast(Int(16, lanes), ramp(0, c0, lanes)),
+                    ramp_fits_in_i16) ||
+
+            rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes),
+                    cast(Int(16, lanes), ramp(0, c0, lanes)) <= broadcast(saturating_cast(Int(16), x - y), lanes),
+                    ramp_fits_in_i16)) {
+            return rewrite.result;
+        } else {
+            return p;
+        }
+    }
+
+    Expr visit(const Select *op) override {
+        Expr condition = mutate(op->condition);
+        Expr true_value = mutate(op->true_value);
+        Expr false_value = mutate(op->false_value);
+        condition = narrow_predicate(condition, op->type);
+        return Select::make(condition, true_value, false_value);
+    }
+
+    Expr visit(const Load *op) override {
+        Expr predicate = mutate(op->predicate);
+        Expr index = mutate(op->index);
+        predicate = narrow_predicate(predicate, op->type);
+        if (predicate.same_as(op->predicate) && index.same_as(op->index)) {
+            return op;
+        } else {
+            return Load::make(op->type, op->name, std::move(index),
+                              op->image, op->param, std::move(predicate),
+                              op->alignment);
+        }
+    }
+
+    Stmt visit(const Store *op) override {
+        Expr predicate = mutate(op->predicate);
+        Expr value = mutate(op->value);
+        Expr index = mutate(op->index);
+        predicate = narrow_predicate(predicate, value.type());
+        if (predicate.same_as(op->predicate) && value.same_as(op->value) && index.same_as(op->index)) {
+            return op;
+        } else {
+            return Store::make(op->name, std::move(value), std::move(index), op->param, std::move(predicate), op->alignment);
+        }
+    }
 };
 
 // Substitute in let values that have an output vector
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt
index a55f38988576..0eb3b72dc2bd 100644
--- a/test/correctness/CMakeLists.txt
+++ b/test/correctness/CMakeLists.txt
@@ -216,6 +216,7 @@ tests(GROUPS correctness
                multipass_constraints.cpp
                multiple_outputs.cpp
                mux.cpp
+               narrow_predicates.cpp
                nested_tail_strategies.cpp
                newtons_method.cpp
                non_nesting_extern_bounds_query.cpp
diff --git a/test/correctness/narrow_predicates.cpp b/test/correctness/narrow_predicates.cpp
new file mode 100644
index 000000000000..a290ca626c52
--- /dev/null
+++ b/test/correctness/narrow_predicates.cpp
@@ -0,0 +1,57 @@
+#include "Halide.h"
+
+using namespace Halide;
+
+Var x;
+
+template<typename T>
+void check(const Expr &e) {
+    Func g1, g2;
+    g1(x) = e;
+    g2(x) = e;
+
+    // Introduce some vector predicates to g1
+    g1.vectorize(x, 64, TailStrategy::GuardWithIf);
+
+    Buffer<T> b1(1024), b2(1024);
+    g1.realize(b1);
+    g2.realize(b2);
+
+    for (int i = 0; i < b1.width(); i++) {
+        if (b1(i) != b2(i)) {
+            printf("b1(%d) = %d instead of %d\n",
+                   i, b1(i), b2(i));
+            exit(-1);
+        }
+    }
+}
+
+template<typename T>
+void check_all() {
+    Func f;
+    f(x) = cast<T>(x);
+    f.compute_root();
+
+    // This will have a predicated instruction in the loop tail:
+    check<T>(f(x));
+
+    // These will also have a comparison mask in the loop body:
+    check<T>(select(x < 50, f(x), cast<T>(17)));
+    check<T>(select(x > 50, f(x), cast<T>(17)));
+
+    // Also test boundary conditions, which introduce all sorts of coordinate
+    // comparisons:
+    check<T>(BoundaryConditions::repeat_edge(f, {{10, 100}})(x));
+    check<T>(BoundaryConditions::repeat_image(f, {{10, 100}})(x));
+    check<T>(BoundaryConditions::constant_exterior(f, cast<T>(17), {{10, 100}})(x));
+    check<T>(BoundaryConditions::mirror_image(f, {{10, 100}})(x));
+    check<T>(BoundaryConditions::mirror_interior(f, {{10, 100}})(x));
+}
+
+int main(int argc, char **argv) {
+    check_all<uint8_t>();
+    check_all<uint16_t>();
+
+    printf("Success!\n");
+    return 0;
+}

From 90a2652dc616d901d0362252c402cb91703a62d0 Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Fri, 3 Mar 2023 10:25:32 -0800
Subject: [PATCH 2/5] Remove reliance on infinite precision int32s

---
 src/FindIntrinsics.cpp | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp
index eb27134d0164..033926c1f092 100644
--- a/src/FindIntrinsics.cpp
+++ b/src/FindIntrinsics.cpp
@@ -945,40 +945,42 @@ class FindIntrinsics : public IRMutator {
         auto max_ramp_lane = max(c0, c0 * (lanes - 1));
         auto ramp_fits_in_i8 = min_ramp_lane > -128 && max_ramp_lane < 127;
         auto ramp_fits_in_i16 = min_ramp_lane > -32768 && max_ramp_lane < 32767;
+        auto saturated_diff_i8 = saturating_cast(Int(8), saturating_sub(x, y));
+        auto saturated_diff_i16 = saturating_cast(Int(16), saturating_sub(x, y));
 
         if ((t.bits() <= 8 &&
             // Try to narrow to 8-bit comparisons
             (rewrite(broadcast(x, lanes) < ramp(y, c0, lanes),
-                     broadcast(saturating_cast(Int(8), x - y), lanes) < cast(Int(8, lanes), ramp(0, c0, lanes)),
+                     broadcast(saturated_diff_i8, lanes) < cast(Int(8, lanes), ramp(0, c0, lanes)),
                      ramp_fits_in_i8) ||
 
             rewrite(ramp(y, c0, lanes) < broadcast(x, lanes),
-                    cast(Int(8, lanes), ramp(0, c0, lanes)) < broadcast(saturating_cast(Int(8), x - y), lanes),
+                    cast(Int(8, lanes), ramp(0, c0, lanes)) < broadcast(saturated_diff_i8, lanes),
                     ramp_fits_in_i8) ||
 
             rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes),
-                    broadcast(saturating_cast(Int(8), x - y), lanes) <= cast(Int(8, lanes), ramp(0, c0, lanes)),
+                    broadcast(saturated_diff_i8, lanes) <= cast(Int(8, lanes), ramp(0, c0, lanes)),
                    ramp_fits_in_i8) ||
 
             rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes),
-                    cast(Int(8, lanes), ramp(0, c0, lanes)) <= broadcast(saturating_cast(Int(8), x - y), lanes),
+                    cast(Int(8, lanes), ramp(0, c0, lanes)) <= broadcast(saturated_diff_i8, lanes),
                     ramp_fits_in_i8))) ||
 
            // Try to narrow to 16-bit comparisons
            rewrite(broadcast(x, lanes) < ramp(y, c0, lanes),
-                   broadcast(saturating_cast(Int(16), x - y), lanes) < cast(Int(16, lanes), ramp(0, c0, lanes)),
+                   broadcast(saturated_diff_i16, lanes) < cast(Int(16, lanes), ramp(0, c0, lanes)),
                    ramp_fits_in_i16) ||
 
            rewrite(ramp(y, c0, lanes) < broadcast(x, lanes),
-                   cast(Int(16, lanes), ramp(0, c0, lanes)) < broadcast(saturating_cast(Int(16), x - y), lanes),
+                   cast(Int(16, lanes), ramp(0, c0, lanes)) < broadcast(saturated_diff_i16, lanes),
                    ramp_fits_in_i16) ||
 
            rewrite(broadcast(x, lanes) <= ramp(y, c0, lanes),
-                   broadcast(saturating_cast(Int(16), x - y), lanes) <= cast(Int(16, lanes), ramp(0, c0, lanes)),
+                   broadcast(saturated_diff_i16, lanes) <= cast(Int(16, lanes), ramp(0, c0, lanes)),
                    ramp_fits_in_i16) ||
 
            rewrite(ramp(y, c0, lanes) <= broadcast(x, lanes),
-                   cast(Int(16, lanes), ramp(0, c0, lanes)) <= broadcast(saturating_cast(Int(16), x - y), lanes),
+                   cast(Int(16, lanes), ramp(0, c0, lanes)) <= broadcast(saturated_diff_i16, lanes),
                    ramp_fits_in_i16)) {
             return rewrite.result;
         } else {

From 99d8d4f7b512305a657e9ce1c73da9419123e317 Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Fri, 3 Mar 2023 11:35:21 -0800
Subject: [PATCH 3/5] Further elaborate on comment

---
 src/FindIntrinsics.cpp | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp
index 033926c1f092..477f46446b9e 100644
--- a/src/FindIntrinsics.cpp
+++ b/src/FindIntrinsics.cpp
@@ -938,9 +938,10 @@ class FindIntrinsics : public IRMutator {
         auto rewrite = IRMatcher::rewriter(p, Int(32, lanes));
 
         // Construct predicates which state the ramp can't hit the extreme
-        // values of an int8 or an int16. This is an overconservative condition,
-        // but it's hard to imagine cases where a more precise condition would
-        // be necessary.
+        // values of an int8 or an int16, so that the saturated broadcast has a
+        // value to take on that leaves it clear of the bounds of the ramp. This
+        // is an overconservative condition, but it's hard to imagine cases
+        // where a more precise condition would be necessary.
         auto min_ramp_lane = min(c0, c0 * (lanes - 1));
         auto max_ramp_lane = max(c0, c0 * (lanes - 1));
         auto ramp_fits_in_i8 = min_ramp_lane > -128 && max_ramp_lane < 127;
         auto ramp_fits_in_i16 = min_ramp_lane > -32768 && max_ramp_lane < 32767;

From a7aa8cc3e82ff8521bf4dce820509f6ed0ce8c3f Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Sun, 19 Mar 2023 16:32:28 -0700
Subject: [PATCH 4/5] Lower signed saturating_add and sub to unsigned math

The existing lowering was prone to overflow

---
 src/FindIntrinsics.cpp  | 55 +++++++++++++++++++++++++++++++++++++----
 src/Simplify_Internal.h |  2 +-
 2 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp
index 477f46446b9e..4c0b70d06130 100644
--- a/src/FindIntrinsics.cpp
+++ b/src/FindIntrinsics.cpp
@@ -1243,17 +1243,62 @@ Expr lower_rounding_shift_right(const Expr &a, const Expr &b) {
 }
 
 Expr lower_saturating_add(const Expr &a, const Expr &b) {
-    internal_assert(a.type() == b.type());
     // Lower saturating add without using widening arithmetic, which may require
     // types that aren't supported.
-    return simplify(clamp(a, a.type().min() - min(b, 0), a.type().max() - max(b, 0))) + b;
+    internal_assert(a.type() == b.type());
+    if (a.type().is_float()) {
+        return a + b;
+    } else if (a.type().is_uint()) {
+        Expr sum = a + b;
+        return select(sum < a, a.type().max(), sum);
+    } else if (a.type().is_int()) {
+        Type u = a.type().with_code(halide_type_uint);
+        Expr ua = cast(u, a);
+        Expr ub = cast(u, b);
+        Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
+        Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
+        Expr sum = ua + ub;
+        // For a 32-bit input, 'sum' is the low 32 bits of the true 33-bit
+        // sum.
+        // So it's the true sum, possibly plus 2^32 in the case where the
+        // true sum is supposed to be negative. The true sum is non-negative when:
+        //   a + b >= 0  ===  a >= -b  ===  a >= ~b + 1  ===  a > ~b
+        Expr pos_result = min(sum, upper);
+        Expr neg_result = max(sum, lower);
+        return simplify(reinterpret(a.type(), select(~b < a, pos_result, neg_result)));
+    } else {
+        internal_error << "Bad type for saturating_add: " << a.type() << "\n";
+        return Expr();
+    }
 }
 
 Expr lower_saturating_sub(const Expr &a, const Expr &b) {
-    internal_assert(a.type() == b.type());
-    // Lower saturating add without using widening arithmetic, which may require
+    // Lower saturating sub without using widening arithmetic, which may require
     // types that aren't supported.
-    return simplify(clamp(a, a.type().min() + max(b, 0), a.type().max() + min(b, 0))) - b;
+    internal_assert(a.type() == b.type());
+    if (a.type().is_float()) {
+        return a - b;
+    } else if (a.type().is_int()) {
+        // Do the math in unsigned, to avoid overflow in the simplifier.
+        Type u = a.type().with_code(halide_type_uint);
+        Expr ua = cast(u, a);
+        Expr ub = cast(u, b);
+        Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
+        Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
+        Expr diff = ua - ub;
+        // If a >= b, then diff is the (positive) difference. If a < b, then diff
+        // is the (negative) difference plus 2^32 due to wraparound.
+        // We saturate the positive difference to be at most 2^31 - 1
+        Expr pos_diff = min(upper, diff);
+        // and saturate the negative difference to be at least -2^31 + 2^32 = 2^31
+        Expr neg_diff = max(lower, diff);
+        // Then select between them, and cast back to the signed type.
+        return simplify(reinterpret(a.type(), select(b <= a, pos_diff, neg_diff)));
+    } else if (a.type().is_uint()) {
+        return simplify(select(b < a, a - b, make_zero(a.type())));
+    } else {
+        internal_error << "Bad type for saturating_sub: " << a.type() << "\n";
+        return Expr();
+    }
 }
 
 Expr lower_saturating_cast(const Type &t, const Expr &a) {
diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h
index 7a53769fd55c..5b8405ab948b 100644
--- a/src/Simplify_Internal.h
+++ b/src/Simplify_Internal.h
@@ -109,7 +109,7 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
         }
     }
 
-#if (LOG_EXPR_MUTATORIONS || LOG_STMT_MUTATIONS)
+#if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS)
     static int debug_indent;
 #endif

From 6f97a802874081a89dad4087cf717f604e35a11b Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Fri, 24 Mar 2023 12:28:44 -0700
Subject: [PATCH 5/5] cast -> reinterpret

---
 src/FindIntrinsics.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp
index 4c0b70d06130..c55669c578e8 100644
--- a/src/FindIntrinsics.cpp
+++ b/src/FindIntrinsics.cpp
@@ -1253,8 +1253,8 @@ Expr lower_saturating_add(const Expr &a, const Expr &b) {
         return select(sum < a, a.type().max(), sum);
     } else if (a.type().is_int()) {
         Type u = a.type().with_code(halide_type_uint);
-        Expr ua = cast(u, a);
-        Expr ub = cast(u, b);
+        Expr ua = reinterpret(u, a);
+        Expr ub = reinterpret(u, b);
         Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
         Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
         Expr sum = ua + ub;
@@ -1280,8 +1280,8 @@ Expr lower_saturating_sub(const Expr &a, const Expr &b) {
     } else if (a.type().is_int()) {
         // Do the math in unsigned, to avoid overflow in the simplifier.
         Type u = a.type().with_code(halide_type_uint);
-        Expr ua = cast(u, a);
-        Expr ub = cast(u, b);
+        Expr ua = reinterpret(u, a);
+        Expr ub = reinterpret(u, b);
         Expr upper = make_const(u, (uint64_t(1) << (a.type().bits() - 1)) - 1);
         Expr lower = make_const(u, (uint64_t(1) << (a.type().bits() - 1)));
         Expr diff = ua - ub;
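
As a sanity check on the unsigned-wraparound reasoning used in lower_saturating_add and lower_saturating_sub above, here is a minimal standalone C++ sketch (not part of the patch series; the function names are illustrative, not Halide APIs) that applies the same trick at 8 bits and verifies it exhaustively against a widened reference:

    // Standalone sketch: the unsigned-wraparound saturating add/sub logic
    // from the patches, specialized to int8_t and checked exhaustively.
    #include <cassert>
    #include <cstdint>

    static int8_t sat_add_i8(int8_t a, int8_t b) {
        // Compute the wrapping sum in unsigned, then clamp based on the sign
        // of the true (9-bit) sum. upper/lower are INT8_MAX and INT8_MIN
        // reinterpreted as uint8_t.
        uint8_t sum = (uint8_t)((uint8_t)a + (uint8_t)b);
        const uint8_t upper = 0x7f, lower = 0x80;
        uint8_t pos_result = sum < upper ? sum : upper;  // min(sum, upper)
        uint8_t neg_result = sum > lower ? sum : lower;  // max(sum, lower)
        // True sum is non-negative iff a + b >= 0 iff a > ~b (two's complement).
        return (int8_t)((int8_t)~b < a ? pos_result : neg_result);
    }

    static int8_t sat_sub_i8(int8_t a, int8_t b) {
        // Same trick for subtraction: diff wraps around when a < b, and the
        // two clamps pick the saturated value on overflow in either direction.
        uint8_t diff = (uint8_t)((uint8_t)a - (uint8_t)b);
        const uint8_t upper = 0x7f, lower = 0x80;
        uint8_t pos_diff = diff < upper ? diff : upper;  // min(upper, diff)
        uint8_t neg_diff = diff > lower ? diff : lower;  // max(lower, diff)
        return (int8_t)(b <= a ? pos_diff : neg_diff);
    }

    int main() {
        for (int a = -128; a <= 127; a++) {
            for (int b = -128; b <= 127; b++) {
                int add_ref = a + b, sub_ref = a - b;
                add_ref = add_ref > 127 ? 127 : (add_ref < -128 ? -128 : add_ref);
                sub_ref = sub_ref > 127 ? 127 : (sub_ref < -128 ? -128 : sub_ref);
                assert(sat_add_i8((int8_t)a, (int8_t)b) == (int8_t)add_ref);
                assert(sat_sub_i8((int8_t)a, (int8_t)b) == (int8_t)sub_ref);
            }
        }
        return 0;
    }

At 8 bits the exhaustive loop covers all 65536 input pairs, so the check is cheap; the same argument carries over to 16- and 32-bit types, which is the case the comments in the patch describe.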