Skip to content

Commit

Permalink
Track whether or not let expressions failed to solve in solver (#7982)
Browse files Browse the repository at this point in the history
* Track whether or not let expressions failed to solve in solver

After mutating an expression, the solver needs to know two things:

1) Did the expression contain the variable we're solving for
2) Was the expression successfully "solved" for the variable. I.e. the
variable only appears once in the leftmost position. We need to know
this to know property 1 of any subexpressions (i.e. does the right child
of the expression contain the variable). This drives what
transformations we do in ways that are guaranteed to terminate and not
take exponential time.

We were tracking property 1 through lets but not property 2, and this
meant we were doing unhelpful transformations in some cases. I found a
case in the wild where this made a pipeline take > 1 hour to compile (I
killed it after an hour). It may have been in an infinite transformation
loop, or it might have just been exponential. Not sure.

* Remove surplus comma

* Fix use of uninitialized value that could cause bad transformation
  • Loading branch information
abadams authored and steven-johnson committed Feb 1, 2024
1 parent be6d6c6 commit 2111594
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
6 changes: 4 additions & 2 deletions src/ModulusRemainder.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include <cstdint>

#include "Util.h"

namespace Halide {

struct Expr;
Expand Down Expand Up @@ -83,8 +85,8 @@ ModulusRemainder modulus_remainder(const Expr &e, const Scope<ModulusRemainder>
/** Reduce an expression modulo some integer. Returns true and assigns
* to remainder if an answer could be found. */
///@{
bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder);
bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder, const Scope<ModulusRemainder> &scope);
HALIDE_MUST_USE_RESULT bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder);
HALIDE_MUST_USE_RESULT bool reduce_expr_modulo(const Expr &e, int64_t modulus, int64_t *remainder, const Scope<ModulusRemainder> &scope);
///@}

void modulus_remainder_test();
Expand Down
35 changes: 26 additions & 9 deletions src/Solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,22 @@ class SolveExpression : public IRMutator {
map<Expr, CacheEntry, ExprCompare>::iterator iter = cache.find(e);
if (iter == cache.end()) {
// Not in the cache, call the base class version.
debug(4) << "Mutating " << e << " (" << uses_var << ")\n";
debug(4) << "Mutating " << e << " (" << uses_var << ", " << failed << ")\n";
bool old_uses_var = uses_var;
uses_var = false;
bool old_failed = failed;
failed = false;
Expr new_e = IRMutator::mutate(e);
CacheEntry entry = {new_e, uses_var};
CacheEntry entry = {new_e, uses_var, failed};
uses_var = old_uses_var || uses_var;
failed = old_failed || failed;
cache[e] = entry;
debug(4) << "(Miss) Rewrote " << e << " -> " << new_e << " (" << uses_var << ")\n";
debug(4) << "(Miss) Rewrote " << e << " -> " << new_e << " (" << uses_var << ", " << failed << ")\n";
return new_e;
} else {
// Cache hit.
uses_var = uses_var || iter->second.uses_var;
failed = failed || iter->second.failed;
debug(4) << "(Hit) Rewrote " << e << " -> " << iter->second.expr << " (" << uses_var << ")\n";
return iter->second.expr;
}
Expand All @@ -75,7 +79,7 @@ class SolveExpression : public IRMutator {
// stateless, so we can cache everything.
struct CacheEntry {
Expr expr;
bool uses_var;
bool uses_var, failed;
};
map<Expr, CacheEntry, ExprCompare> cache;

Expand Down Expand Up @@ -388,16 +392,25 @@ class SolveExpression : public IRMutator {
const Mul *mul_a = a.as<Mul>();
Expr expr;
if (a_uses_var && !b_uses_var) {
const int64_t *ib = as_const_int(b);
auto is_multiple_of_b = [&](const Expr &e) {
if (ib) {
int64_t r = 0;
return reduce_expr_modulo(e, *ib, &r) && r == 0;
} else {
return can_prove(e / b * b == e);
}
};
if (add_a && !a_failed &&
can_prove(add_a->a / b * b == add_a->a)) {
is_multiple_of_b(add_a->a)) {
// (f(x) + a) / b -> f(x) / b + a / b
expr = mutate(simplify(add_a->a / b) + add_a->b / b);
} else if (sub_a && !a_failed &&
can_prove(sub_a->a / b * b == sub_a->a)) {
is_multiple_of_b(sub_a->a)) {
// (f(x) - a) / b -> f(x) / b - a / b
expr = mutate(simplify(sub_a->a / b) - sub_a->b / b);
} else if (mul_a && !a_failed && no_overflow_int(op->type) &&
can_prove(mul_a->b / b * b == mul_a->b)) {
is_multiple_of_b(mul_a->b)) {
// (f(x) * a) / b -> f(x) * (a / b)
expr = mutate(mul_a->a * (mul_a->b / b));
}
Expand Down Expand Up @@ -776,6 +789,7 @@ class SolveExpression : public IRMutator {
} else if (scope.contains(op->name)) {
CacheEntry e = scope.get(op->name);
uses_var = uses_var || e.uses_var;
failed = failed || e.failed;
return e.expr;
} else if (external_scope.contains(op->name)) {
Expr e = external_scope.get(op->name);
Expand All @@ -790,11 +804,14 @@ class SolveExpression : public IRMutator {

Expr visit(const Let *op) override {
bool old_uses_var = uses_var;
bool old_failed = failed;
uses_var = false;
failed = false;
Expr value = mutate(op->value);
CacheEntry e = {value, uses_var};

CacheEntry e = {value, uses_var, failed};
uses_var = old_uses_var;
failed = old_failed;

ScopedBinding<CacheEntry> bind(scope, op->name, e);
return mutate(op->body);
}
Expand Down

0 comments on commit 2111594

Please sign in to comment.