Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hoist vector slices using rewrite rules #7243

Merged
merged 10 commits into from
Jan 21, 2023
3 changes: 3 additions & 0 deletions src/Simplify_Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ Expr Simplify::visit(const Add *op, ExprInfo *bounds) {
rewrite(x*c0 + y*c1, (x + y*fold(c1/c0)) * c0, c1 % c0 == 0) ||
rewrite(x*c0 + y*c1, (x*fold(c0/c1) + y) * c1, c0 % c1 == 0) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// shuffles, so only hoist shuffles that grab more than one lane.
abadams marked this conversation as resolved.
Show resolved Hide resolved
rewrite(slice(x, c0, c1, c2) + slice(y, c0, c1, c2), slice(x + y, c0, c1, c2), c2 > 1) ||
rewrite(slice(x, c0, c1, c2) + (z + slice(y, c0, c1, c2)), slice(x + y, c0, c1, c2) + z, c2 > 1) ||
rewrite(slice(x, c0, c1, c2) + (slice(y, c0, c1, c2) + z), slice(x + y, c0, c1, c2) + z, c2 > 1) ||
Expand Down
3 changes: 3 additions & 0 deletions src/Simplify_Max.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ Expr Simplify::visit(const Max *op, ExprInfo *bounds) {

rewrite(max(select(x, y, z), select(x, w, u)), select(x, max(y, w), max(z, u))) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// shuffles, so only hoist shuffles that grab more than one lane.
rewrite(max(slice(x, c0, c1, c2), slice(y, c0, c1, c2)), slice(max(x, y), c0, c1, c2), c2 > 1) ||
rewrite(max(slice(x, c0, c1, c2), max(slice(y, c0, c1, c2), z)), max(slice(max(x, y), c0, c1, c2), z), c2 > 1) ||
rewrite(max(slice(x, c0, c1, c2), max(z, slice(y, c0, c1, c2))), max(slice(max(x, y), c0, c1, c2), z), c2 > 1) ||
Expand Down
3 changes: 3 additions & 0 deletions src/Simplify_Min.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ Expr Simplify::visit(const Min *op, ExprInfo *bounds) {

rewrite(min(select(x, y, z), select(x, w, u)), select(x, min(y, w), min(z, u))) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// shuffles, so only hoist shuffles that grab more than one lane.
rewrite(min(slice(x, c0, c1, c2), slice(y, c0, c1, c2)), slice(min(x, y), c0, c1, c2), c2 > 1) ||
rewrite(min(slice(x, c0, c1, c2), min(slice(y, c0, c1, c2), z)), min(slice(min(x, y), c0, c1, c2), z), c2 > 1) ||
rewrite(min(slice(x, c0, c1, c2), min(z, slice(y, c0, c1, c2))), min(slice(min(x, y), c0, c1, c2), z), c2 > 1) ||
Expand Down
3 changes: 3 additions & 0 deletions src/Simplify_Mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ Expr Simplify::visit(const Mul *op, ExprInfo *bounds) {
rewrite(ramp(broadcast(x, c0), broadcast(y, c0), c1) * broadcast(z, c2),
ramp(broadcast(x * z, c0), broadcast(y * z, c0), c1), c2 == c0 * c1) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// shuffles, so only hoist shuffles that grab more than one lane.
rewrite(slice(x, c0, c1, c2) * slice(y, c0, c1, c2), slice(x * y, c0, c1, c2), c2 > 1) ||
rewrite(slice(x, c0, c1, c2) * (slice(y, c0, c1, c2) * z), slice(x * y, c0, c1, c2) * z, c2 > 1) ||
rewrite(slice(x, c0, c1, c2) * (z * slice(y, c0, c1, c2)), slice(x * y, c0, c1, c2) * z, c2 > 1) ||
Expand Down
3 changes: 3 additions & 0 deletions src/Simplify_Sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ Expr Simplify::visit(const Sub *op, ExprInfo *bounds) {
rewrite(x - x%c0, (x/c0)*c0) ||
rewrite(x - ((x + c0)/c1)*c1, (x + c0)%c1 - c0, c1 > 0) ||

// Hoist shuffles. The Shuffle visitor wants to sink
// extract_elements to the leaves, and those count as degenerate
// shuffles, so only hoist shuffles that grab more than one lane.
rewrite(slice(x, c0, c1, c2) - slice(y, c0, c1, c2), slice(x - y, c0, c1, c2), c2 > 1) ||
rewrite(slice(x, c0, c1, c2) - (z + slice(y, c0, c1, c2)), slice(x - y, c0, c1, c2) - z, c2 > 1) ||
rewrite(slice(x, c0, c1, c2) - (slice(y, c0, c1, c2) + z), slice(x - y, c0, c1, c2) - z, c2 > 1) ||
Expand Down