Skip to content

Commit

Permalink
Fix hoist_storage not handling condition correctly. (#8123)
Browse files Browse the repository at this point in the history
The allocation condition wasn't getting relaxed over the scope and loop
vars like the extents were.
  • Loading branch information
abadams authored Feb 27, 2024
1 parent aae84f6 commit 2b5beb3
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 10 deletions.
32 changes: 23 additions & 9 deletions src/StorageFlattening.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,23 +293,37 @@ class FlattenDimensions : public IRMutator {
stmt = LetStmt::make(op->name + ".buffer", builder.build(), stmt);
if (hoisted_storages_map.count(op->name) > 0) {
HoistedStorageData &hoisted_storage_data = hoisted_storages[hoisted_storages_map[op->name]];
vector<Expr> bounded_extents;
for (const auto &e : allocation_extents) {
Expr expanded_extent = e;

auto expand_and_bound = [&](Expr e) {
// Iterate from innermost outwards
for (auto it = hoisted_storages.rbegin(); it != hoisted_storages.rend(); it++) {
expanded_extent = expand_expr(expanded_extent, it->scope);
e = expand_expr(e, it->scope);
if (it->name == op->name) {
break;
}
}
expanded_extent = simplify(common_subexpression_elimination(expanded_extent));
Interval bounds = bounds_of_expr_in_scope(expanded_extent, hoisted_storage_data.loop_vars);
user_assert(bounds.max.defined()) << "Couldn't infer the upper bound for the storage size of " << op->name << ", consider using bound_storage.\n";
bounded_extents.push_back(bounds.max);

e = simplify(common_subexpression_elimination(e));
Interval bounds = bounds_of_expr_in_scope(e, hoisted_storage_data.loop_vars);
return bounds.max;
};

vector<Expr> bounded_extents;
for (const auto &e : allocation_extents) {
Expr expanded_extent = expand_and_bound(e);
user_assert(expanded_extent.defined() &&
!expanded_extent.same_as(Interval::pos_inf()))
<< "Couldn't infer the upper bound for the storage size of " << op->name << ", consider using bound_storage.\n";
bounded_extents.push_back(expanded_extent);
}

Expr expanded_condition = expand_and_bound(condition);
if (!expanded_condition.defined() ||
expanded_condition.same_as(Interval::pos_inf())) {
expanded_condition = const_true();
}

HoistedAllocationInfo hoisted_alloc(op->name, op->types[0], op->memory_type, bounded_extents, condition);
HoistedAllocationInfo hoisted_alloc(op->name, op->types[0], op->memory_type, bounded_extents, expanded_condition);

hoisted_storage_data.hoisted_allocations.push_back(hoisted_alloc);
} else {
Expand Down
26 changes: 25 additions & 1 deletion test/correctness/skip_stages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void check_counts(int a = 0, int b = 0, int c = 0, int d = 0) {
}

int main(int argc, char **argv) {
Var x;
Var x, y;
Param<bool> toggle1, toggle2;

{
Expand Down Expand Up @@ -201,6 +201,30 @@ int main(int argc, char **argv) {
check_counts(11);
}

{
// Check the interation with storage hoisting

// This Func may or may not be loaded, depending on y
Func maybe_loaded("maybe_loaded");
maybe_loaded(x, y) = x + y;

// This Func may or may not be used, depending on y
Func maybe_used("maybe_used");
maybe_used(x, y) = maybe_loaded(x, y);

Func output("output");
output(x, y) = select(y % 100 == 37, 0, maybe_used(x, y));

// The allocation condition depends on y, but the actual allocation
// happens at the root level.
maybe_loaded.compute_at(output, y).hoist_storage_root();
maybe_used.compute_at(output, y).hoist_storage_root();

// This will fail to compile with an undefined symbol if we haven't
// handled the condition correctly.
output.realize({100, 100});
}

printf("Success!\n");
return 0;
}

0 comments on commit 2b5beb3

Please sign in to comment.