diff --git a/src/IRMatch.h b/src/IRMatch.h
index 394858b1a958..ad045e3789d0 100644
--- a/src/IRMatch.h
+++ b/src/IRMatch.h
@@ -2083,3 +2083,82 @@ HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp
+// Pattern for a Shuffle that is a strided slice of a single vector:
+// `lanes` elements of `vec`, starting at lane `base` and stepping by
+// `stride`. base, stride and lanes must be patterns that constant-fold
+// (enforced by the static_asserts in the constructor).
+template<typename Vec, typename Base, typename Stride, typename Lanes>
+struct SliceOp {
+    struct pattern_tag {};
+    Vec vec;
+    Base base;
+    Stride stride;
+    Lanes lanes;
+
+    static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
+
+    constexpr static IRNodeType min_node_type = IRNodeType::Shuffle;
+    constexpr static IRNodeType max_node_type = IRNodeType::Shuffle;
+    constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
+
+    template<uint32_t bound>
+    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
+        if (e.node_type != IRNodeType::Shuffle) {
+            return false;
+        }
+        const Shuffle &v = (const Shuffle &)e;
+        // Require a genuine slice. Without the is_slice() check any
+        // single-vector shuffle (e.g. a permutation) would match, using
+        // only slice_begin()/slice_stride()/lanes to describe it, and
+        // make() would then rebuild a different shuffle. Later patterns
+        // see the wildcards bound by earlier ones via the bound mask.
+        return v.vectors.size() == 1 &&
+               v.is_slice() &&
+               vec.template match<bound>(*v.vectors[0].get(), state) &&
+               base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
+               stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
+               lanes.template match<bound | bindings<Vec>::mask | bindings<Base>::mask | bindings<Stride>::mask>(v.type.lanes(), state);
+    }
+
+    HALIDE_ALWAYS_INLINE
+    Expr make(MatcherState &state, halide_type_t type_hint) const {
+        // Fold base, stride and lanes to constants (the constructor
+        // static_asserts that they are foldable), then rebuild the
+        // slice around the vector produced by the vec pattern.
+        halide_scalar_value_t base_val, stride_val, lanes_val;
+        halide_type_t ty;
+        base.make_folded_const(base_val, ty, state);
+        int b = (int)base_val.u.i64;
+        stride.make_folded_const(stride_val, ty, state);
+        int s = (int)stride_val.u.i64;
+        lanes.make_folded_const(lanes_val, ty, state);
+        int l = (int)lanes_val.u.i64;
+        return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
+    }
+
+    constexpr static bool foldable = false;
+
+    HALIDE_ALWAYS_INLINE
+    SliceOp(Vec v, Base b, Stride s, Lanes l)
+        : vec(v), base(b), stride(s), lanes(l) {
+        static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
+        static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
+        static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
+    }
+};
+
+// Prints the pattern as slice(vec, base, stride, lanes).
+template<typename Vec, typename Base, typename Stride, typename Lanes>
+std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
+    s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
+    return s;
+}
+
+// Helper to build a SliceOp pattern; every argument is wrapped with
+// pattern_arg before being stored in the returned SliceOp.
+template<typename Vec, typename Base, typename Stride, typename Lanes>
+HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
+    -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
+    return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
+}
+
 template<typename A>
 struct Fold {
     struct pattern_tag {};
@@ -2551,7 +2630,7 @@ std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
 }
 
 template<typename A>
-struct HasEvenLanes {
+struct LanesOf {
     struct pattern_tag {};
     A a;
 
@@ -2568,22 +2647,22 @@ struct HasEvenLanes {
     void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
         // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
         Type t = a.make(state, {}).type();
-        val.u.u64 = (t.lanes() % 2 == 0);
+        val.u.u64 = t.lanes();
         ty.code = halide_type_uint;
-        ty.bits = 1;
-        ty.lanes = t.lanes();
+        ty.bits = 32;
+        ty.lanes = 1;
     }
 };
 
 template<typename A>
-HALIDE_ALWAYS_INLINE auto has_even_lanes(A &&a) noexcept -> HasEvenLanes<decltype(pattern_arg(a))> {
+HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
     assert_is_lvalue_if_expr<A>();
     return {pattern_arg(a)};
 }
 
 template<typename A>
-std::ostream &operator<<(std::ostream &s, const HasEvenLanes<A> &op) {
-    s << "has_even_lanes(" << op.a << ")";
+std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
+    s << "lanes_of(" << op.a << ")";
     return s;
 }
 
diff --git a/src/Simplify_Add.cpp b/src/Simplify_Add.cpp
index 2c3a045abf78..fb9238dd9a6a 100644
--- a/src/Simplify_Add.cpp
+++ b/src/Simplify_Add.cpp
@@ -112,6 +112,16 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
              rewrite(y*x + y*z, y*(x + z)) ||
              rewrite(x*c0 + y*c1, (x + y*fold(c1/c0)) * c0, c1 % c0 == 0) ||
              rewrite(x*c0 + y*c1, (x*fold(c0/c1) + y) *
c1, c0 % c1 == 0) || + + // Hoist shuffles. The Shuffle visitor wants to sink + // extract_elements to the leaves, and those count as degenerate + // slices, so only hoist shuffles that grab more than one lane. + rewrite(slice(x, c0, c1, c2) + slice(y, c0, c1, c2), slice(x + y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(slice(x, c0, c1, c2) + (z + slice(y, c0, c1, c2)), slice(x + y, c0, c1, c2) + z, c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(slice(x, c0, c1, c2) + (slice(y, c0, c1, c2) + z), slice(x + y, c0, c1, c2) + z, c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(slice(x, c0, c1, c2) + (z - slice(y, c0, c1, c2)), slice(x - y, c0, c1, c2) + z, c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(slice(x, c0, c1, c2) + (slice(y, c0, c1, c2) - z), slice(x + y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) || + (no_overflow(op->type) && (rewrite(x + x*y, x * (y + 1)) || rewrite(x + y*x, (y + 1) * x) || @@ -187,18 +197,6 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) { return mutate(rewrite.result, bounds); } // clang-format on - - const Shuffle *shuffle_a = a.as(); - const Shuffle *shuffle_b = b.as(); - if (shuffle_a && shuffle_b && - shuffle_a->is_slice() && - shuffle_b->is_slice()) { - if (a.same_as(op->a) && b.same_as(op->b)) { - return hoist_slice_vector(op); - } else { - return hoist_slice_vector(Add::make(a, b)); - } - } } if (a.same_as(op->a) && b.same_as(op->b)) { diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index cba41d444f4b..7a53769fd55c 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -289,9 +289,6 @@ class Simplify : public VariadicVisitor { return f; } - template - Expr hoist_slice_vector(Expr e); - Stmt mutate_let_body(const Stmt &s, ExprInfo *) { return mutate(s); } diff --git a/src/Simplify_Max.cpp b/src/Simplify_Max.cpp index d3dbb97631d6..1a79aef962fa 100644 --- a/src/Simplify_Max.cpp +++ b/src/Simplify_Max.cpp @@ -189,6 +189,13 @@ Expr Simplify::visit(const 
Max *op, ExprInfo *bounds) { rewrite(max(select(x, y, z), select(x, w, u)), select(x, max(y, w), max(z, u))) || + // Hoist shuffles. The Shuffle visitor wants to sink + // extract_elements to the leaves, and those count as degenerate + // slices, so only hoist shuffles that grab more than one lane. + rewrite(max(slice(x, c0, c1, c2), slice(y, c0, c1, c2)), slice(max(x, y), c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(max(slice(x, c0, c1, c2), max(slice(y, c0, c1, c2), z)), max(slice(max(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(max(slice(x, c0, c1, c2), max(z, slice(y, c0, c1, c2))), max(slice(max(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) || + (no_overflow(op->type) && (rewrite(max(max(x, y) + c0, x), max(x, y + c0), c0 < 0) || rewrite(max(max(x, y) + c0, x), max(x, y) + c0, c0 > 0) || @@ -299,18 +306,6 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) { // clang-format on } - const Shuffle *shuffle_a = a.as(); - const Shuffle *shuffle_b = b.as(); - if (shuffle_a && shuffle_b && - shuffle_a->is_slice() && - shuffle_b->is_slice()) { - if (a.same_as(op->a) && b.same_as(op->b)) { - return hoist_slice_vector(op); - } else { - return hoist_slice_vector(Max::make(a, b)); - } - } - if (a.same_as(op->a) && b.same_as(op->b)) { return op; } else { diff --git a/src/Simplify_Min.cpp b/src/Simplify_Min.cpp index 7cc2949557d7..214ed09374d3 100644 --- a/src/Simplify_Min.cpp +++ b/src/Simplify_Min.cpp @@ -193,6 +193,12 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) { rewrite(min(select(x, y, z), select(x, w, u)), select(x, min(y, w), min(z, u))) || + // Hoist shuffles. The Shuffle visitor wants to sink + // extract_elements to the leaves, and those count as degenerate + // slices, so only hoist shuffles that grab more than one lane. 
+ rewrite(min(slice(x, c0, c1, c2), slice(y, c0, c1, c2)), slice(min(x, y), c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(min(slice(x, c0, c1, c2), min(slice(y, c0, c1, c2), z)), min(slice(min(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(min(slice(x, c0, c1, c2), min(z, slice(y, c0, c1, c2))), min(slice(min(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) || (no_overflow(op->type) && (rewrite(min(min(x, y) + c0, x), min(x, y + c0), c0 > 0) || rewrite(min(min(x, y) + c0, x), min(x, y) + c0, c0 < 0) || @@ -311,18 +317,6 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) { // clang-format on } - const Shuffle *shuffle_a = a.as(); - const Shuffle *shuffle_b = b.as(); - if (shuffle_a && shuffle_b && - shuffle_a->is_slice() && - shuffle_b->is_slice()) { - if (a.same_as(op->a) && b.same_as(op->b)) { - return hoist_slice_vector(op); - } else { - return hoist_slice_vector(Min::make(a, b)); - } - } - if (a.same_as(op->a) && b.same_as(op->b)) { return op; } else { diff --git a/src/Simplify_Mul.cpp b/src/Simplify_Mul.cpp index 54e07df14879..881d09112f7d 100644 --- a/src/Simplify_Mul.cpp +++ b/src/Simplify_Mul.cpp @@ -94,23 +94,19 @@ Expr Simplify::visit(const Mul *op, ExprInfo *bounds) { rewrite(ramp(x, y, c0) * broadcast(z, c0), ramp(x * z, y * z, c0)) || rewrite(ramp(broadcast(x, c0), broadcast(y, c0), c1) * broadcast(z, c2), ramp(broadcast(x * z, c0), broadcast(y * z, c0), c1), c2 == c0 * c1) || + + // Hoist shuffles. The Shuffle visitor wants to sink + // extract_elements to the leaves, and those count as degenerate + // slices, so only hoist shuffles that grab more than one lane. 
+ rewrite(slice(x, c0, c1, c2) * slice(y, c0, c1, c2), slice(x * y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(slice(x, c0, c1, c2) * (slice(y, c0, c1, c2) * z), slice(x * y, c0, c1, c2) * z, c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(slice(x, c0, c1, c2) * (z * slice(y, c0, c1, c2)), slice(x * y, c0, c1, c2) * z, c2 > 1 && lanes_of(x) == lanes_of(y)) || + false) { return mutate(rewrite.result, bounds); } } - const Shuffle *shuffle_a = a.as(); - const Shuffle *shuffle_b = b.as(); - if (shuffle_a && shuffle_b && - shuffle_a->is_slice() && - shuffle_b->is_slice()) { - if (a.same_as(op->a) && b.same_as(op->b)) { - return hoist_slice_vector(op); - } else { - return hoist_slice_vector(Mul::make(a, b)); - } - } - if (a.same_as(op->a) && b.same_as(op->b)) { return op; } else { diff --git a/src/Simplify_Shuffle.cpp b/src/Simplify_Shuffle.cpp index 35622aee9c4e..7da4f6699ab7 100644 --- a/src/Simplify_Shuffle.cpp +++ b/src/Simplify_Shuffle.cpp @@ -321,47 +321,5 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) { } } -template -Expr Simplify::hoist_slice_vector(Expr e) { - const T *op = e.as(); - internal_assert(op); - - const Shuffle *shuffle_a = op->a.template as(); - const Shuffle *shuffle_b = op->b.template as(); - - internal_assert(shuffle_a && shuffle_b && - shuffle_a->is_slice() && - shuffle_b->is_slice()); - - if (shuffle_a->indices != shuffle_b->indices) { - return e; - } - - const std::vector &slices_a = shuffle_a->vectors; - const std::vector &slices_b = shuffle_b->vectors; - if (slices_a.size() != slices_b.size()) { - return e; - } - - for (size_t i = 0; i < slices_a.size(); i++) { - if (slices_a[i].type() != slices_b[i].type()) { - return e; - } - } - - vector new_slices; - for (size_t i = 0; i < slices_a.size(); i++) { - new_slices.push_back(T::make(slices_a[i], slices_b[i])); - } - - return Shuffle::make(new_slices, shuffle_a->indices); -} - -template Expr Simplify::hoist_slice_vector(Expr); -template Expr 
Simplify::hoist_slice_vector(Expr); -template Expr Simplify::hoist_slice_vector(Expr); -template Expr Simplify::hoist_slice_vector(Expr); -template Expr Simplify::hoist_slice_vector(Expr); - } // namespace Internal } // namespace Halide diff --git a/src/Simplify_Sub.cpp b/src/Simplify_Sub.cpp index ed18ca7c8209..1ab53b2dea90 100644 --- a/src/Simplify_Sub.cpp +++ b/src/Simplify_Sub.cpp @@ -175,6 +175,15 @@ Expr Simplify::visit(const Sub *op, ExprInfo *bounds) { rewrite(x - x%c0, (x/c0)*c0) || rewrite(x - ((x + c0)/c1)*c1, (x + c0)%c1 - c0, c1 > 0) || + // Hoist shuffles. The Shuffle visitor wants to sink + // extract_elements to the leaves, and those count as degenerate + // slices, so only hoist shuffles that grab more than one lane. + rewrite(slice(x, c0, c1, c2) - slice(y, c0, c1, c2), slice(x - y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(slice(x, c0, c1, c2) - (z + slice(y, c0, c1, c2)), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite(slice(x, c0, c1, c2) - (slice(y, c0, c1, c2) + z), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite((slice(x, c0, c1, c2) - z) - slice(y, c0, c1, c2), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) || + rewrite((z - slice(x, c0, c1, c2)) - slice(y, c0, c1, c2), z - slice(x + y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) || + (no_overflow(op->type) && (rewrite(max(x, y) - x, max(y - x, 0)) || rewrite(min(x, y) - x, min(y - x, 0)) || @@ -442,18 +451,6 @@ Expr Simplify::visit(const Sub *op, ExprInfo *bounds) { } // clang-format on - const Shuffle *shuffle_a = a.as(); - const Shuffle *shuffle_b = b.as(); - if (shuffle_a && shuffle_b && - shuffle_a->is_slice() && - shuffle_b->is_slice()) { - if (a.same_as(op->a) && b.same_as(op->b)) { - return hoist_slice_vector(op); - } else { - return hoist_slice_vector(Sub::make(a, b)); - } - } - if (a.same_as(op->a) && b.same_as(op->b)) { return op; } else { diff --git 
a/test/correctness/simplify.cpp b/test/correctness/simplify.cpp index aa10815663d2..bcf706f6b9f4 100644 --- a/test/correctness/simplify.cpp +++ b/test/correctness/simplify.cpp @@ -711,6 +711,23 @@ void check_vectors() { check(concat_vectors(loads), Load::make(Float(32, lanes * vectors), "buf", ramp(0, 1, lanes * vectors), Buffer<>(), Parameter(), const_true(vectors * lanes), ModulusRemainder(0, 0))); } + { + Expr vx = Variable::make(Int(32, 32), "x"); + Expr vy = Variable::make(Int(32, 32), "y"); + Expr vz = Variable::make(Int(32, 8), "z"); + Expr vw = Variable::make(Int(32, 16), "w"); + // Check that vector slices are hoisted. + check(slice(vx, 0, 2, 8) + slice(vy, 0, 2, 8), slice(vx + vy, 0, 2, 8)); + check(slice(vx, 0, 2, 8) + (slice(vy, 0, 2, 8) + vz), slice(vx + vy, 0, 2, 8) + vz); + check(slice(vx, 0, 2, 8) + (vz + slice(vy, 0, 2, 8)), slice(vx + vy, 0, 2, 8) + vz); + // Check that degenerate vector slices are not hoisted. + check(slice(vx, 0, 2, 1) + slice(vy, 0, 2, 1), slice(vx, 0, 2, 1) + slice(vy, 0, 2, 1)); + check(slice(vx, 0, 2, 1) + (slice(vy, 0, 2, 1) + z), slice(vx, 0, 2, 1) + (slice(vy, 0, 2, 1) + z)); + // Check slices are only hoisted when the lanes of the sliced vectors match. + check(slice(vx, 0, 2, 8) + slice(vw, 0, 2, 8), slice(vx, 0, 2, 8) + slice(vw, 0, 2, 8)); + check(slice(vx, 0, 2, 8) + (slice(vw, 0, 2, 8) + vz), slice(vx, 0, 2, 8) + (slice(vw, 0, 2, 8) + vz)); + } + { // A predicated store with a provably-false predicate. Expr pred = ramp(x * y + x * z, 2, 8) > 2;