Skip to content

Commit

Permalink
Hoist vector slices using rewrite rules (halide#7243)
Browse files Browse the repository at this point in the history
* Hoist slices using rewrite rules

This lets us add associative variants more easily, which are helpful in
the work on staging strided loads.

* Don't hoist extract_element shuffles

The Shuffle visitor wants to sink them

* Add some static asserts

* Add explanatory comment on shuffle hoisting

* Fix comment

* add lanes predicate to slice hoisting

* add vector slice hoisting test cases

Co-authored-by: Steven Johnson <srj@google.com>
Co-authored-by: Alexander <ajroot@stanford.edu>
  • Loading branch information
3 people authored and ardier committed Mar 3, 2024
1 parent 3ed06f7 commit 38d41f1
Show file tree
Hide file tree
Showing 9 changed files with 127 additions and 112 deletions.
77 changes: 70 additions & 7 deletions src/IRMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -2080,6 +2080,69 @@ HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<declty
return {t, pattern_arg(a)};
}

// Pattern for a Shuffle node that slices a single vector:
// slice(vec, base, stride, lanes).
//
// When matching, `vec` is matched against the sole shuffled vector, and
// `base`, `stride` and `lanes` are matched against the shuffle's
// slice_begin(), slice_stride() and result lane count respectively.
// When used on the right-hand side of a rule, `base`, `stride` and `lanes`
// are constant-folded and passed to Shuffle::make_slice.
template<typename Vec, typename Base, typename Stride, typename Lanes>
struct SliceOp {
    struct pattern_tag {};
    Vec vec;
    Base base;
    Stride stride;
    Lanes lanes;

    // Union of the wildcards bound by all four sub-patterns.
    static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;

    // This pattern can only ever match a Shuffle node.
    constexpr static IRNodeType min_node_type = IRNodeType::Shuffle;
    constexpr static IRNodeType max_node_type = IRNodeType::Shuffle;
    constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;

    // Returns true iff `e` is a slice of exactly one vector whose source,
    // begin, stride and lane count match the four sub-patterns. `bound` is
    // the mask of wildcards already bound by patterns matched earlier; each
    // sub-pattern is additionally told about the wildcards bound by the
    // sub-patterns that precede it here.
    template<uint32_t bound>
    HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
        if (e.node_type != IRNodeType::Shuffle) {
            return false;
        }
        const Shuffle &v = (const Shuffle &)e;
        return v.vectors.size() == 1 &&
               // slice_begin()/slice_stride() only describe the shuffle if it
               // really is a slice. Without this guard any single-vector
               // shuffle (e.g. a permutation) would be matched from its
               // leading indices and then rebuilt as a different shuffle.
               v.is_slice() &&
               vec.template match<bound>(*v.vectors[0].get(), state) &&
               base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
               stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
               lanes.template match<bound | bindings<Vec>::mask | bindings<Base>::mask | bindings<Stride>::mask>(v.type.lanes(), state);
    }

    // Builds the slice expression. base, stride and lanes are folded to
    // constants and narrowed from the 64-bit field to int, which is what
    // Shuffle::make_slice is called with.
    HALIDE_ALWAYS_INLINE
    Expr make(MatcherState &state, halide_type_t type_hint) const {
        halide_scalar_value_t base_val, stride_val, lanes_val;
        // The folded type is required by make_folded_const but not used here.
        halide_type_t ty;
        base.make_folded_const(base_val, ty, state);
        int b = (int)base_val.u.i64;
        stride.make_folded_const(stride_val, ty, state);
        int s = (int)stride_val.u.i64;
        lanes.make_folded_const(lanes_val, ty, state);
        int l = (int)lanes_val.u.i64;
        return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
    }

    // A slice produces a vector expression, so it never folds to a constant.
    constexpr static bool foldable = false;

    HALIDE_ALWAYS_INLINE
    SliceOp(Vec v, Base b, Stride s, Lanes l)
        : vec(v), base(b), stride(s), lanes(l) {
        // make() relies on make_folded_const for these three operands.
        static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
        static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
        static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
    }
};

// Prints a SliceOp pattern as "slice(vec, base, stride, lanes)".
template<typename Vec, typename Base, typename Stride, typename Lanes>
std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
    s << "slice(";
    s << op.vec;
    s << ", ";
    s << op.base;
    s << ", ";
    s << op.stride;
    s << ", ";
    s << op.lanes;
    s << ")";
    return s;
}

// Builds a SliceOp pattern from four operands, passing each one through
// pattern_arg first.
template<typename Vec, typename Base, typename Stride, typename Lanes>
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
    -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
    using Pattern = SliceOp<decltype(pattern_arg(vec)),
                            decltype(pattern_arg(base)),
                            decltype(pattern_arg(stride)),
                            decltype(pattern_arg(lanes))>;
    return Pattern(pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes));
}

template<typename A>
struct Fold {
struct pattern_tag {};
Expand Down Expand Up @@ -2551,7 +2614,7 @@ std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
}

template<typename A>
struct HasEvenLanes {
struct LanesOf {
struct pattern_tag {};
A a;

Expand All @@ -2568,22 +2631,22 @@ struct HasEvenLanes {
void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const {
// a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
Type t = a.make(state, {}).type();
val.u.u64 = (t.lanes() % 2 == 0);
val.u.u64 = t.lanes();
ty.code = halide_type_uint;
ty.bits = 1;
ty.lanes = t.lanes();
ty.bits = 32;
ty.lanes = 1;
}
};

template<typename A>
HALIDE_ALWAYS_INLINE auto has_even_lanes(A &&a) noexcept -> HasEvenLanes<decltype(pattern_arg(a))> {
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
assert_is_lvalue_if_expr<A>();
return {pattern_arg(a)};
}

template<typename A>
std::ostream &operator<<(std::ostream &s, const HasEvenLanes<A> &op) {
s << "has_even_lanes(" << op.a << ")";
std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
s << "lanes_of(" << op.a << ")";
return s;
}

Expand Down
22 changes: 10 additions & 12 deletions src/Simplify_Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
rewrite(y*x + y*z, y*(x + z)) ||
rewrite(x*c0 + y*c1, (x + y*fold(c1/c0)) * c0, c1 % c0 == 0) ||
rewrite(x*c0 + y*c1, (x*fold(c0/c1) + y) * c1, c0 % c1 == 0) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(slice(x, c0, c1, c2) + slice(y, c0, c1, c2), slice(x + y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) + (z + slice(y, c0, c1, c2)), slice(x + y, c0, c1, c2) + z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) + (slice(y, c0, c1, c2) + z), slice(x + y, c0, c1, c2) + z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) + (z - slice(y, c0, c1, c2)), slice(x - y, c0, c1, c2) + z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) + (slice(y, c0, c1, c2) - z), slice(x + y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||

(no_overflow(op->type) &&
(rewrite(x + x*y, x * (y + 1)) ||
rewrite(x + y*x, (y + 1) * x) ||
Expand Down Expand Up @@ -187,18 +197,6 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
return mutate(rewrite.result, bounds);
}
// clang-format on

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Add>(op);
} else {
return hoist_slice_vector<Add>(Add::make(a, b));
}
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
Expand Down
3 changes: 0 additions & 3 deletions src/Simplify_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,6 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
return f;
}

template<typename T>
Expr hoist_slice_vector(Expr e);

Stmt mutate_let_body(const Stmt &s, ExprInfo *) {
return mutate(s);
}
Expand Down
19 changes: 7 additions & 12 deletions src/Simplify_Max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) {

rewrite(max(select(x, y, z), select(x, w, u)), select(x, max(y, w), max(z, u))) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(max(slice(x, c0, c1, c2), slice(y, c0, c1, c2)), slice(max(x, y), c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(max(slice(x, c0, c1, c2), max(slice(y, c0, c1, c2), z)), max(slice(max(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(max(slice(x, c0, c1, c2), max(z, slice(y, c0, c1, c2))), max(slice(max(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) ||

(no_overflow(op->type) &&
(rewrite(max(max(x, y) + c0, x), max(x, y + c0), c0 < 0) ||
rewrite(max(max(x, y) + c0, x), max(x, y) + c0, c0 > 0) ||
Expand Down Expand Up @@ -299,18 +306,6 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) {
// clang-format on
}

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Max>(op);
} else {
return hoist_slice_vector<Max>(Max::make(a, b));
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
return op;
} else {
Expand Down
18 changes: 6 additions & 12 deletions src/Simplify_Min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) {

rewrite(min(select(x, y, z), select(x, w, u)), select(x, min(y, w), min(z, u))) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(min(slice(x, c0, c1, c2), slice(y, c0, c1, c2)), slice(min(x, y), c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(min(slice(x, c0, c1, c2), min(slice(y, c0, c1, c2), z)), min(slice(min(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(min(slice(x, c0, c1, c2), min(z, slice(y, c0, c1, c2))), min(slice(min(x, y), c0, c1, c2), z), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
(no_overflow(op->type) &&
(rewrite(min(min(x, y) + c0, x), min(x, y + c0), c0 > 0) ||
rewrite(min(min(x, y) + c0, x), min(x, y) + c0, c0 < 0) ||
Expand Down Expand Up @@ -311,18 +317,6 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) {
// clang-format on
}

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Min>(op);
} else {
return hoist_slice_vector<Min>(Min::make(a, b));
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
return op;
} else {
Expand Down
20 changes: 8 additions & 12 deletions src/Simplify_Mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,19 @@ Expr Simplify::visit(const Mul *op, ExprInfo *bounds) {
rewrite(ramp(x, y, c0) * broadcast(z, c0), ramp(x * z, y * z, c0)) ||
rewrite(ramp(broadcast(x, c0), broadcast(y, c0), c1) * broadcast(z, c2),
ramp(broadcast(x * z, c0), broadcast(y * z, c0), c1), c2 == c0 * c1) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(slice(x, c0, c1, c2) * slice(y, c0, c1, c2), slice(x * y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) * (slice(y, c0, c1, c2) * z), slice(x * y, c0, c1, c2) * z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) * (z * slice(y, c0, c1, c2)), slice(x * y, c0, c1, c2) * z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||

false) {
return mutate(rewrite.result, bounds);
}
}

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Mul>(op);
} else {
return hoist_slice_vector<Mul>(Mul::make(a, b));
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
return op;
} else {
Expand Down
42 changes: 0 additions & 42 deletions src/Simplify_Shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,47 +321,5 @@ Expr Simplify::visit(const Shuffle *op, ExprInfo *bounds) {
}
}

// Given a binary op of type T whose two operands are both slice shuffles,
// moves the op inside the shuffle: T(shuffle(a...), shuffle(b...)) becomes
// shuffle(T(a0, b0), T(a1, b1), ...). The expression is returned unchanged
// if the two shuffles use different indices, shuffle a different number of
// vectors, or have any pair of corresponding vectors with differing types.
template<typename T>
Expr Simplify::hoist_slice_vector(Expr e) {
    const T *binary = e.as<T>();
    internal_assert(binary);

    const Shuffle *lhs = binary->a.template as<Shuffle>();
    const Shuffle *rhs = binary->b.template as<Shuffle>();

    // Callers are expected to have checked this already.
    internal_assert(lhs && rhs &&
                    lhs->is_slice() &&
                    rhs->is_slice());

    // Both sides must select the same lanes...
    if (lhs->indices != rhs->indices) {
        return e;
    }

    // ...from the same number of source vectors...
    const std::vector<Expr> &lhs_vecs = lhs->vectors;
    const std::vector<Expr> &rhs_vecs = rhs->vectors;
    const size_t count = lhs_vecs.size();
    if (count != rhs_vecs.size()) {
        return e;
    }

    // ...and every pair of source vectors must agree on type.
    for (size_t k = 0; k != count; ++k) {
        if (lhs_vecs[k].type() != rhs_vecs[k].type()) {
            return e;
        }
    }

    // Apply the op pairwise to the source vectors, then shuffle the results.
    std::vector<Expr> combined;
    for (size_t k = 0; k != count; ++k) {
        combined.push_back(T::make(lhs_vecs[k], rhs_vecs[k]));
    }

    return Shuffle::make(combined, lhs->indices);
}

// Explicit instantiations of hoist_slice_vector, one per binary op type
// (Add, Sub, Mul, Min, Max). The template body lives in this .cpp file, so
// these are presumably what make it linkable from the other Simplify_*.cpp
// translation units that call it.
template Expr Simplify::hoist_slice_vector<Add>(Expr);
template Expr Simplify::hoist_slice_vector<Sub>(Expr);
template Expr Simplify::hoist_slice_vector<Mul>(Expr);
template Expr Simplify::hoist_slice_vector<Min>(Expr);
template Expr Simplify::hoist_slice_vector<Max>(Expr);

} // namespace Internal
} // namespace Halide
21 changes: 9 additions & 12 deletions src/Simplify_Sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ Expr Simplify::visit(const Sub *op, ExprInfo *bounds) {
rewrite(x - x%c0, (x/c0)*c0) ||
rewrite(x - ((x + c0)/c1)*c1, (x + c0)%c1 - c0, c1 > 0) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// slices, so only hoist shuffles that grab more than one lane.
rewrite(slice(x, c0, c1, c2) - slice(y, c0, c1, c2), slice(x - y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) - (z + slice(y, c0, c1, c2)), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite(slice(x, c0, c1, c2) - (slice(y, c0, c1, c2) + z), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite((slice(x, c0, c1, c2) - z) - slice(y, c0, c1, c2), slice(x - y, c0, c1, c2) - z, c2 > 1 && lanes_of(x) == lanes_of(y)) ||
rewrite((z - slice(x, c0, c1, c2)) - slice(y, c0, c1, c2), z - slice(x + y, c0, c1, c2), c2 > 1 && lanes_of(x) == lanes_of(y)) ||

(no_overflow(op->type) &&
(rewrite(max(x, y) - x, max(y - x, 0)) ||
rewrite(min(x, y) - x, min(y - x, 0)) ||
Expand Down Expand Up @@ -442,18 +451,6 @@ Expr Simplify::visit(const Sub *op, ExprInfo *bounds) {
}
// clang-format on

const Shuffle *shuffle_a = a.as<Shuffle>();
const Shuffle *shuffle_b = b.as<Shuffle>();
if (shuffle_a && shuffle_b &&
shuffle_a->is_slice() &&
shuffle_b->is_slice()) {
if (a.same_as(op->a) && b.same_as(op->b)) {
return hoist_slice_vector<Sub>(op);
} else {
return hoist_slice_vector<Sub>(Sub::make(a, b));
}
}

if (a.same_as(op->a) && b.same_as(op->b)) {
return op;
} else {
Expand Down
17 changes: 17 additions & 0 deletions test/correctness/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,23 @@ void check_vectors() {
check(concat_vectors(loads), Load::make(Float(32, lanes * vectors), "buf", ramp(0, 1, lanes * vectors), Buffer<>(), Parameter(), const_true(vectors * lanes), ModulusRemainder(0, 0)));
}

{
// Slice-hoisting cases for Add. vx and vy are 32-lane sources, vw is a
// 16-lane source, and vz is an 8-lane vector that is used unsliced.
Expr vx = Variable::make(Int(32, 32), "x");
Expr vy = Variable::make(Int(32, 32), "y");
Expr vz = Variable::make(Int(32, 8), "z");
Expr vw = Variable::make(Int(32, 16), "w");
// Check that vector slices are hoisted.
check(slice(vx, 0, 2, 8) + slice(vy, 0, 2, 8), slice(vx + vy, 0, 2, 8));
check(slice(vx, 0, 2, 8) + (slice(vy, 0, 2, 8) + vz), slice(vx + vy, 0, 2, 8) + vz);
check(slice(vx, 0, 2, 8) + (vz + slice(vy, 0, 2, 8)), slice(vx + vy, 0, 2, 8) + vz);
// Check that degenerate vector slices are not hoisted.
check(slice(vx, 0, 2, 1) + slice(vy, 0, 2, 1), slice(vx, 0, 2, 1) + slice(vy, 0, 2, 1));
// NOTE(review): `z` here is not declared in this scope; presumably it is
// the scalar `z` declared earlier in this file (a 1-lane slice is scalar,
// so the 8-lane vz would not type-check). Confirm.
check(slice(vx, 0, 2, 1) + (slice(vy, 0, 2, 1) + z), slice(vx, 0, 2, 1) + (slice(vy, 0, 2, 1) + z));
// Check slices are only hoisted when the lanes of the sliced vectors match.
check(slice(vx, 0, 2, 8) + slice(vw, 0, 2, 8), slice(vx, 0, 2, 8) + slice(vw, 0, 2, 8));
check(slice(vx, 0, 2, 8) + (slice(vw, 0, 2, 8) + vz), slice(vx, 0, 2, 8) + (slice(vw, 0, 2, 8) + vz));
}

{
// A predicated store with a provably-false predicate.
Expr pred = ramp(x * y + x * z, 2, 8) > 2;
Expand Down

0 comments on commit 38d41f1

Please sign in to comment.