Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster vars used tracking in simplify let visitor #8205

Merged
merged 18 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 23 additions & 32 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ int static_sign(const Expr &x) {
return 0;
}

Interval simplify(const Interval &i) {
Interval result;
result.min = simplify(i.min);
if (i.is_single_point()) {
result.max = result.min;
} else {
result.max = simplify(i.max);
}
return result;
}

} // anonymous namespace

const FuncValueBounds &empty_func_value_bounds() {
Expand All @@ -109,8 +120,7 @@ Expr find_constant_bound(const Expr &e, Direction d, const Scope<Interval> &scop
Interval find_constant_bounds(const Expr &e, const Scope<Interval> &scope) {
Expr expr = bound_correlated_differences(simplify(remove_likelies(e)));
Interval interval = bounds_of_expr_in_scope(expr, scope, FuncValueBounds(), true);
interval.min = simplify(interval.min);
interval.max = simplify(interval.max);
interval = simplify(interval);

// Note that we can get non-const but well-defined results (e.g. signed_integer_overflow);
// for our purposes here, treat anything non-const as no-bound.
Expand Down Expand Up @@ -158,16 +168,6 @@ class Bounds : public IRVisitor {
Bounds(const Scope<Interval> *s, const FuncValueBounds &fb, bool const_bound)
: func_bounds(fb), const_bound(const_bound) {
scope.set_containing_scope(s);

// Find any points that are single_points but fail is_single_point due to
// pointer equality checks and replace with single_points.
for (auto item = s->cbegin(); item != s->cend(); ++item) {
const Interval &item_interval = item.value();
if (!item_interval.is_single_point() &&
equal(item_interval.min, item_interval.max)) {
scope.push(item.name(), Interval::single_point(item_interval.min));
}
}
}

#if DO_TRACK_BOUNDS_INTERVALS
Expand Down Expand Up @@ -325,8 +325,7 @@ class Bounds : public IRVisitor {
// constants, so try to make the constants first.

// First constant-fold
a.min = simplify(a.min);
a.max = simplify(a.max);
a = simplify(a);

// Then try to strip off junk mins and maxes.
bool old_constant_bound = const_bound;
Expand Down Expand Up @@ -355,8 +354,7 @@ class Bounds : public IRVisitor {
// a is bounded, but from and to can't necessarily represent
// each other; however, if the bounds can be simplified to
// constants, they might fit regardless of types.
a.min = simplify(a.min);
a.max = simplify(a.max);
a = simplify(a);
const auto *umin = as_const_uint(a.min);
const auto *umax = as_const_uint(a.max);
if (umin && umax && to.can_represent(*umin) && to.can_represent(*umax)) {
Expand Down Expand Up @@ -2573,13 +2571,11 @@ class BoxesTouched : public IRGraphVisitor {
op->value.accept(this);

f.value_bounds = bounds_of_expr_in_scope(op->value, scope, func_bounds);

bool fixed = f.value_bounds.min.same_as(f.value_bounds.max);
f.value_bounds.min = simplify(f.value_bounds.min);
f.value_bounds.max = fixed ? f.value_bounds.min : simplify(f.value_bounds.max);
f.value_bounds = simplify(f.value_bounds);

if (is_small_enough_to_substitute(f.value_bounds.min) &&
(fixed || is_small_enough_to_substitute(f.value_bounds.max))) {
(f.value_bounds.is_single_point() ||
is_small_enough_to_substitute(f.value_bounds.max))) {
scope.push(op->name, f.value_bounds);
} else {
f.max_name = unique_name('t');
Expand Down Expand Up @@ -2769,9 +2765,7 @@ class BoxesTouched : public IRGraphVisitor {
const Expr *val = let_stmts.find(l.var);
internal_assert(val);
v_bound = bounds_of_expr_in_scope(*val, scope, func_bounds);
bool fixed = v_bound.min.same_as(v_bound.max);
v_bound.min = simplify(v_bound.min);
v_bound.max = fixed ? v_bound.min : simplify(v_bound.max);
v_bound = simplify(v_bound);

const Interval *old_bound = scope.find(l.var);
internal_assert(old_bound);
Expand Down Expand Up @@ -3368,12 +3362,12 @@ FuncValueBounds compute_function_value_bounds(const vector<string> &order,
result = compute_pure_function_definition_value_bounds(f.definition(), arg_scope, fb, j);
// These can expand combinatorially as we go down the
// pipeline if we don't run CSE on them.
bool fixed = result.is_single_point();
if (result.has_lower_bound()) {
result.min = simplify(common_subexpression_elimination(result.min));
}

if (result.has_upper_bound()) {
result.max = simplify(common_subexpression_elimination(result.max));
result.max = fixed ? result.min : simplify(common_subexpression_elimination(result.max));
}

fb[key] = result;
Expand Down Expand Up @@ -3431,8 +3425,7 @@ namespace {
void check(const Scope<Interval> &scope, const Expr &e, const Expr &correct_min, const Expr &correct_max) {
FuncValueBounds fb;
Interval result = bounds_of_expr_in_scope(e, scope, fb);
result.min = simplify(result.min);
result.max = simplify(result.max);
result = simplify(result);
if (!equal(result.min, correct_min)) {
internal_error << "In bounds of " << e << ":\n"
<< "Incorrect min: " << result.min << "\n"
Expand All @@ -3448,8 +3441,7 @@ void check(const Scope<Interval> &scope, const Expr &e, const Expr &correct_min,
void check_constant_bound(const Scope<Interval> &scope, const Expr &e, const Expr &correct_min, const Expr &correct_max) {
FuncValueBounds fb;
Interval result = bounds_of_expr_in_scope(e, scope, fb, true);
result.min = simplify(result.min);
result.max = simplify(result.max);
result = simplify(result);
if (!equal(result.min, correct_min)) {
internal_error << "In find constant bound of " << e << ":\n"
<< "Incorrect min constant bound: " << result.min << "\n"
Expand Down Expand Up @@ -3603,8 +3595,7 @@ void boxes_touched_test() {
for (size_t i = 0; i < result.size(); ++i) {
const Interval &correct = expected[i];
Interval b = result[i];
b.min = simplify(b.min);
b.max = simplify(b.max);
b = simplify(b);
if (!equal(correct.min, b.min)) {
internal_error << "In bounds of dim " << i << ":\n"
<< "Incorrect min: " << b.min << "\n"
Expand Down
4 changes: 2 additions & 2 deletions src/Interval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ bool Interval::is_everything() const {
}

bool Interval::is_single_point() const {
return min.same_as(max);
return is_bounded() && equal(min, max);
}

bool Interval::is_single_point(const Expr &e) const {
return min.same_as(e) && max.same_as(e);
return is_bounded() && equal(min, e) && equal(max, e);
}

bool Interval::has_upper_bound() const {
Expand Down
73 changes: 51 additions & 22 deletions src/Simplify_Let.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "Simplify_Internal.h"
#include "Substitute.h"

#include <unordered_set>

namespace Halide {
namespace Internal {

Expand All @@ -9,34 +11,50 @@ using std::vector;

namespace {

class CountVarUses : public IRVisitor {
std::map<std::string, int> &var_uses;
class FindVarUses : public IRVisitor {
std::unordered_set<std::string> &unused_vars;

void visit(const Variable *var) override {
var_uses[var->name]++;
unused_vars.erase(var->name);
}

void visit(const Load *op) override {
var_uses[op->name]++;
IRVisitor::visit(op);
if (!unused_vars.empty()) {
unused_vars.erase(op->name);
IRVisitor::visit(op);
}
}

void visit(const Store *op) override {
var_uses[op->name]++;
IRVisitor::visit(op);
if (!unused_vars.empty()) {
unused_vars.erase(op->name);
IRVisitor::visit(op);
}
}

void visit(const Block *op) override {
// Early out at Block nodes if we've already seen every name we're
// interested in. In principal we could early-out at every node, but
// blocks, loads, and stores seem to be enough.
if (!unused_vars.empty()) {
rootjalex marked this conversation as resolved.
Show resolved Hide resolved
op->first.accept(this);
if (!unused_vars.empty()) {
op->rest.accept(this);
}
}
}

using IRVisitor::visit;

public:
CountVarUses(std::map<std::string, int> &var_uses)
: var_uses(var_uses) {
FindVarUses(std::unordered_set<std::string> &unused_vars)
: unused_vars(unused_vars) {
}
};

template<typename StmtOrExpr>
void count_var_uses(StmtOrExpr x, std::map<std::string, int> &var_uses) {
CountVarUses counter(var_uses);
void find_var_uses(StmtOrExpr x, std::unordered_set<std::string> &unused_vars) {
FindVarUses counter(unused_vars);
x.accept(&counter);
}

Expand All @@ -53,6 +71,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
string new_name;
bool new_value_alignment_tracked = false, new_value_bounds_tracked = false;
bool value_alignment_tracked = false, value_bounds_tracked = false;
VarInfo info;
Frame(const LetOrLetStmt *op)
: op(op) {
}
Expand Down Expand Up @@ -226,14 +245,27 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {

result = mutate_let_body(result, bounds);

// TODO: var_info and vars_used are pretty redundant; however, at the time
// TODO: var_info and unused_vars are pretty redundant; however, at the time
// of writing, both cover cases that the other does not:
// - var_info prevents duplicate lets from being generated, even
// from different Frame objects.
// - vars_used avoids dead lets being generated in cases where vars are
// - unused_vars avoids dead lets being generated in cases where vars are
// seen as used by var_info, and then later removed.
std::map<std::string, int> vars_used;
count_var_uses(result, vars_used);

std::unordered_set<std::string> unused_vars(frames.size() * 2);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does frames.size() * 2 come from?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, each frame can maybe insert two names. Maybe add a comment above this definition, for explanation?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think old_uses and new_uses might be mutually exclusive. Will test.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. Fixed.

// Insert everything we think *might* be used, and then visit the body,
// removing things from the set as we find uses of them.
for (auto &f : frames) {
f.info = var_info.get(f.op->name);
var_info.pop(f.op->name);
if (f.info.old_uses) {
unused_vars.insert(f.op->name);
}
if (f.info.new_uses) {
unused_vars.insert(f.new_name);
}
}
find_var_uses(result, unused_vars);

for (auto it = frames.rbegin(); it != frames.rend(); it++) {
if (it->value_bounds_tracked) {
Expand All @@ -243,20 +275,17 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
bounds_and_alignment_info.pop(it->new_name);
}

VarInfo info = var_info.get(it->op->name);
var_info.pop(it->op->name);

if (it->new_value.defined() && (info.new_uses > 0 && vars_used.count(it->new_name) > 0)) {
if (it->new_value.defined() && (it->info.new_uses > 0 && !unused_vars.count(it->new_name))) {
// The new name/value may be used
result = LetOrLetStmt::make(it->new_name, it->new_value, result);
count_var_uses(it->new_value, vars_used);
find_var_uses(it->new_value, unused_vars);
}

if ((!remove_dead_code && std::is_same<LetOrLetStmt, LetStmt>::value) ||
(info.old_uses > 0 && vars_used.count(it->op->name) > 0)) {
(it->info.old_uses > 0 && !unused_vars.count(it->op->name))) {
// The old name is still in use. We'd better keep it as well.
result = LetOrLetStmt::make(it->op->name, it->value, result);
count_var_uses(it->value, vars_used);
find_var_uses(it->value, unused_vars);
}

const LetOrLetStmt *new_op = result.template as<LetOrLetStmt>();
Expand Down
4 changes: 2 additions & 2 deletions src/SkipStages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ class SkipStages : public IRMutator {

Stmt emit_defs(Stmt stmt) {
for (auto &p : func_info) {
stmt = LetStmt::make(used_var_name(p.first), p.second.used, stmt);
stmt = LetStmt::make(loaded_var_name(p.first), p.second.loaded, stmt);
stmt = LetStmt::make(used_var_name(p.first), simplify(p.second.used), stmt);
stmt = LetStmt::make(loaded_var_name(p.first), simplify(p.second.loaded), stmt);
need_uniquify |= !lets_emitted.insert(p.first).second;
}
return stmt;
Expand Down
Loading