Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster vars used tracking in simplify let visitor #8205

Merged
merged 18 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/IntrusivePtr.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class RefCount {
bool is_const_zero() const {
return count == 0;
}
int atomic_get() const {
return count;
}
};

/**
Expand Down Expand Up @@ -173,6 +176,11 @@ struct IntrusivePtr {
bool operator<(const IntrusivePtr<T> &other) const {
return ptr < other.ptr;
}

HALIDE_ALWAYS_INLINE
bool is_sole_reference() const {
return ptr && ref_count(ptr).atomic_get() == 1;
}
};

} // namespace Internal
Expand Down
85 changes: 62 additions & 23 deletions src/Simplify_Let.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#include "Simplify_Internal.h"
#include "Substitute.h"

#include <unordered_set>

namespace Halide {
namespace Internal {

Expand All @@ -9,34 +11,50 @@ using std::vector;

namespace {

class CountVarUses : public IRVisitor {
std::map<std::string, int> &var_uses;
class FindVarUses : public IRVisitor {
std::unordered_set<std::string> &unused_vars;

void visit(const Variable *var) override {
var_uses[var->name]++;
unused_vars.erase(var->name);
}

void visit(const Load *op) override {
var_uses[op->name]++;
IRVisitor::visit(op);
if (!unused_vars.empty()) {
unused_vars.erase(op->name);
IRVisitor::visit(op);
}
}

void visit(const Store *op) override {
var_uses[op->name]++;
IRVisitor::visit(op);
if (!unused_vars.empty()) {
unused_vars.erase(op->name);
IRVisitor::visit(op);
}
}

void visit(const Block *op) override {
// Early out at Block nodes if we've already seen every name we're
// interested in. In principal we could early-out at every node, but
// blocks, loads, and stores seem to be enough.
if (!unused_vars.empty()) {
rootjalex marked this conversation as resolved.
Show resolved Hide resolved
op->first.accept(this);
if (!unused_vars.empty()) {
op->rest.accept(this);
}
}
}

using IRVisitor::visit;

public:
CountVarUses(std::map<std::string, int> &var_uses)
: var_uses(var_uses) {
FindVarUses(std::unordered_set<std::string> &unused_vars)
: unused_vars(unused_vars) {
}
};

template<typename StmtOrExpr>
void count_var_uses(StmtOrExpr x, std::map<std::string, int> &var_uses) {
CountVarUses counter(var_uses);
void find_var_uses(StmtOrExpr x, std::unordered_set<std::string> &unused_vars) {
FindVarUses counter(unused_vars);
x.accept(&counter);
}

Expand All @@ -49,10 +67,11 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
// the call stack where it could overflow onto an explicit stack.
struct Frame {
const LetOrLetStmt *op;
Expr value, new_value;
Expr value, new_value, new_var;
string new_name;
bool new_value_alignment_tracked = false, new_value_bounds_tracked = false;
bool value_alignment_tracked = false, value_bounds_tracked = false;
VarInfo info;
Frame(const LetOrLetStmt *op)
: op(op) {
}
Expand Down Expand Up @@ -189,6 +208,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
// Nothing to substitute
f.new_value = Expr();
replacement = Expr();
new_var = Expr();
} else {
debug(4) << "new let " << f.new_name << " = " << f.new_value << " in ... " << replacement << " ...\n";
}
Expand All @@ -197,6 +217,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
info.old_uses = 0;
info.new_uses = 0;
info.replacement = replacement;
f.new_var = new_var;

var_info.push(op->name, info);

Expand Down Expand Up @@ -226,14 +247,35 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {

result = mutate_let_body(result, bounds);

// TODO: var_info and vars_used are pretty redundant; however, at the time
// TODO: var_info and unused_vars are pretty redundant; however, at the time
// of writing, both cover cases that the other does not:
// - var_info prevents duplicate lets from being generated, even
// from different Frame objects.
// - vars_used avoids dead lets being generated in cases where vars are
// - unused_vars avoids dead lets being generated in cases where vars are
// seen as used by var_info, and then later removed.
std::map<std::string, int> vars_used;
count_var_uses(result, vars_used);

std::unordered_set<std::string> unused_vars(frames.size());
// Insert everything we think *might* be used, and then visit the body,
// removing things from the set as we find uses of them.
for (auto &f : frames) {
f.info = var_info.get(f.op->name);
// Drop any reference to new_var held by the replacement expression so
// that the only references are either f.new_var, or ones in the body or
// new_values of other lets.
f.info.replacement = Expr();
if (f.new_var.is_sole_reference()) {
// Any new_uses must have been eliminated by later mutations.
f.info.new_uses = 0;
}
var_info.pop(f.op->name);
if (f.info.old_uses) {
internal_assert(f.info.new_uses == 0);
unused_vars.insert(f.op->name);
} else if (f.info.new_uses && f.new_value.defined()) {
unused_vars.insert(f.new_name);
}
}
find_var_uses(result, unused_vars);

for (auto it = frames.rbegin(); it != frames.rend(); it++) {
if (it->value_bounds_tracked) {
Expand All @@ -243,20 +285,17 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *bounds) {
bounds_and_alignment_info.pop(it->new_name);
}

VarInfo info = var_info.get(it->op->name);
var_info.pop(it->op->name);

if (it->new_value.defined() && (info.new_uses > 0 && vars_used.count(it->new_name) > 0)) {
if (it->new_value.defined() && (it->info.new_uses > 0 && !unused_vars.count(it->new_name))) {
// The new name/value may be used
result = LetOrLetStmt::make(it->new_name, it->new_value, result);
count_var_uses(it->new_value, vars_used);
find_var_uses(it->new_value, unused_vars);
}

if ((!remove_dead_code && std::is_same<LetOrLetStmt, LetStmt>::value) ||
(info.old_uses > 0 && vars_used.count(it->op->name) > 0)) {
(it->info.old_uses > 0 && !unused_vars.count(it->op->name))) {
// The old name is still in use. We'd better keep it as well.
result = LetOrLetStmt::make(it->op->name, it->value, result);
count_var_uses(it->value, vars_used);
find_var_uses(it->value, unused_vars);
}

const LetOrLetStmt *new_op = result.template as<LetOrLetStmt>();
Expand Down
Loading