Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix rfactor adding too many pure loops #8086

Merged
merged 1 commit into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,17 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {
vector<Expr> &args = definition.args();
vector<Expr> &values = definition.values();

// Figure out which pure vars were used in this update definition.
std::set<string> pure_vars_used;
internal_assert(args.size() == dim_vars.size());
for (size_t i = 0; i < args.size(); i++) {
if (const Internal::Variable *var = args[i].as<Variable>()) {
if (var->name == dim_vars[i].name()) {
pure_vars_used.insert(var->name);
}
}
}

// Check whether the operator is associative and determine the operator and
// its identity for each value in the definition if it is a Tuple
const auto &prover_result = prove_associativity(func_name, args, values);
Expand Down Expand Up @@ -1012,16 +1023,20 @@ Func Stage::rfactor(vector<pair<RVar, Var>> preserved) {

// Determine the dims of the new update definition

// The new update definition needs all the pure vars of the Func, but the
// one we're rfactoring may not have used them all. Add any missing ones to
// the dims list.

// Add pure Vars from the original init definition to the dims list
// if they are not already in the list
for (const Var &v : dim_vars) {
const auto &iter = std::find_if(dims.begin(), dims.end(),
[&v](const Dim &dim) { return var_name_match(dim.var, v.name()); });
if (iter == dims.end()) {
if (!pure_vars_used.count(v.name())) {
Dim d = {v.name(), ForType::Serial, DeviceAPI::None, DimType::PureVar, Partition::Auto};
// Insert it just before Var::outermost
dims.insert(dims.end() - 1, d);
}
}

// Then, we need to remove lifted RVars from the dims list
for (const string &rv : rvars_removed) {
remove(rv);
Expand Down Expand Up @@ -1888,6 +1903,11 @@ Stage &Stage::reorder(const std::vector<VarOrRVar> &vars) {

dims_old.swap(dims);

// Var::outermost must remain the last entry in dims. Reordering it inside
// any other var is not allowed, because rfactor assumes it is the last one.
user_assert(dims.back().var == Var::outermost().name())
<< "Var::outermost() may not be reordered inside any other var.\n";

return *this;
}

Expand Down
25 changes: 25 additions & 0 deletions test/correctness/fuzz_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,31 @@ int main(int argc, char **argv) {
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/7890
{
Func input("input");
Func local_sum("local_sum");
Func blurry("blurry");
Var x("x"), y("y");
RVar yryf;
input(x, y) = 2 * x + 5 * y;
RDom r(-2, 5, -2, 5, "rdom_r");
local_sum(x, y) = 0;
local_sum(x, y) += input(x + r.x, y + r.y);
blurry(x, y) = cast<int32_t>(local_sum(x, y) / 25);

Var yo, yi, xo, xi, u;
blurry.split(y, yo, yi, 2, TailStrategy::Auto);
local_sum.split(x, xo, xi, 4, TailStrategy::Auto);
local_sum.update(0).split(x, xo, xi, 1, TailStrategy::Auto);
local_sum.update(0).rfactor(r.x, u);
blurry.store_root();
local_sum.compute_root();
Pipeline p({blurry});
auto buf = p.realize({32, 32});
check_blur_output(buf, correct);
}

// https://github.com/halide/Halide/issues/8054
{
ImageParam input(Float(32), 2, "input");
Expand Down
Loading