Skip to content

Commit

Permalink
make is_3d_tile_index robust to indexing changes
Browse files Browse the repository at this point in the history
  • Loading branch information
frengels committed Sep 8, 2021
1 parent 34557cb commit 5ad06e0
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions src/ExtractTileOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,28 @@ Tile<2> is_2d_tile_index(const Expr &e) {

Tile<3> is_3d_tile_index(const Expr &e) {
vector<Expr> matches;
auto add_sub_pattern = (wild_i32x + wild_i32x) - wild_i32x;
if (!expr_match(add_sub_pattern, e, matches)) {

// there could be a sub node
const Sub* sub = e.as<Sub>();

const Add* add = nullptr;

if (sub) {
add = sub->a.as<Add>();
}
else {
add = e.as<Add>();
}

if (!add) {
return {};
}
// ramp(x16(base), x16(stride), 4) + x16(ramp(idx, 1, 4)) y: 4, x: 4, r: 4
// ramp(x10(base), x10(stride), 3) + x6(ramp(idx, 1, 5)) y: 2, x: 3, r: 5
Expr first = std::move(matches[0]);
Expr second = std::move(matches[1]);
Expr adj = std::move(matches[2]);

auto& first = add->a;
auto& second = add->b;

// ramp(x[x*r](base), x[x*r](stride), x) + x[x*y](ramp(idx, 1, r))

const auto *r1 = first.as<Ramp>();
const auto *b2 = second.as<Broadcast>();
if (!r1 && !b2) {
Expand Down Expand Up @@ -105,11 +118,20 @@ Tile<3> is_3d_tile_index(const Expr &e) {
base += std::move(matches[0]);
Expr r_stride = std::move(matches[1]);

auto pattern3 = Broadcast::make(wild_i32, b1->lanes * r1->lanes);
if (!expr_match(pattern3, adj, matches)) {
return {};
if (sub) {
Expr adj = sub->b;
const Broadcast* bcast = adj.as<Broadcast>();

if (!bcast) {
return {};
}

if (bcast->lanes != b1->lanes * r1->lanes) {
return {};
}

base -= bcast->value;
}
base -= std::move(matches[0]);

return {true, base, {x_stride, 0, r_stride}, {x_tile, y_tile, r_tile}};
}
Expand Down

0 comments on commit 5ad06e0

Please sign in to comment.