Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler stack usage improvements #6239

Merged
merged 19 commits into from
Oct 1, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 70 additions & 25 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,28 +457,51 @@ class TightenProducerConsumerNodes : public IRMutator {

Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<int> &scope) {
if (const LetStmt *let = body.as<LetStmt>()) {
if (expr_uses_vars(let->value, scope)) {
return ProducerConsumer::make(name, is_producer, body);
Stmt orig = body; // Only used to keep a reference to the let chain in scope.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but why is that necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated comment. It's because 'body' may be the only reference-counted Stmt keeping the first LetStmt alive, but we're mutating body to point to its innards. We don't want that first LetStmt to have its refcount hit zero and make our pointer dangle.


// Peel off all lets that don't depend on any vars in scope.
vector<const LetStmt *> containing_lets;
while (let && !expr_uses_vars(let->value, scope)) {
containing_lets.push_back(let);
body = let->body;
let = body.as<LetStmt>();
}

if (let) {
// That's as far as we can go
body = ProducerConsumer::make(name, is_producer, body);
} else {
return LetStmt::make(let->name, let->value, make_producer_consumer(name, is_producer, let->body, scope));
// Recurse onto a non-let-node
body = make_producer_consumer(name, is_producer, body, scope);
}

for (auto it = containing_lets.rbegin(); it != containing_lets.rend(); it++) {
body = LetStmt::make((*it)->name, (*it)->value, body);
}

return body;
} else if (const Block *block = body.as<Block>()) {
// Check which sides it's used on
bool first = stmt_uses_vars(block->first, scope);
bool rest = stmt_uses_vars(block->rest, scope);
if (is_producer) {
// We don't push produce nodes into blocks
return ProducerConsumer::make(name, is_producer, body);
} else if (first && rest) {
return Block::make(make_producer_consumer(name, is_producer, block->first, scope),
make_producer_consumer(name, is_producer, block->rest, scope));
} else if (first) {
return Block::make(make_producer_consumer(name, is_producer, block->first, scope), block->rest);
} else if (rest) {
return Block::make(block->first, make_producer_consumer(name, is_producer, block->rest, scope));
} else {
// Used on neither side?!
return body;
}
vector<Stmt> sub_stmts;
Stmt rest;
do {
Stmt first = block->first;
sub_stmts.push_back(block->first);
rest = block->rest;
block = rest.as<Block>();
} while (block);
sub_stmts.push_back(rest);

for (Stmt &s : sub_stmts) {
if (stmt_uses_vars(s, scope)) {
s = make_producer_consumer(name, is_producer, s, scope);
}
}

return Block::make(sub_stmts);
} else if (const ProducerConsumer *pc = body.as<ProducerConsumer>()) {
return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope));
} else if (const Realize *r = body.as<Realize>()) {
Expand Down Expand Up @@ -561,16 +584,38 @@ class ExpandAcquireNodes : public IRMutator {
}

Stmt visit(const LetStmt *op) override {
Stmt body = mutate(op->body);
const Acquire *a = body.as<Acquire>();
if (a &&
!expr_uses_var(a->semaphore, op->name) &&
!expr_uses_var(a->count, op->name)) {
return Acquire::make(a->semaphore, a->count,
LetStmt::make(op->name, op->value, a->body));
} else {
return LetStmt::make(op->name, op->value, body);
Stmt orig = op;
Stmt body;
vector<const LetStmt *> frames;
do {
frames.push_back(op);
body = op->body;
op = body.as<LetStmt>();
} while (op);

Stmt s = mutate(body);

if (const Acquire *a = s.as<Acquire>()) {
// Pull the acquire node outside as many lets as possible,
// wrapping them around the Acquire node's original body.
body = a->body;
while (!frames.empty() &&
!expr_uses_var(a->semaphore, frames.back()->name) &&
!expr_uses_var(a->count, frames.back()->name)) {
body = LetStmt::make(frames.back()->name, frames.back()->value, body);
frames.pop_back();
}
s = Acquire::make(a->semaphore, a->count, body);
} else if (body.same_as(s)) {
return orig;
}

// Rewrap the rest of the lets
for (auto it = frames.rbegin(); it != frames.rend(); it++) {
s = LetStmt::make((*it)->name, (*it)->value, s);
}

return s;
}

Stmt visit(const ProducerConsumer *op) override {
Expand Down
63 changes: 59 additions & 4 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1909,10 +1909,30 @@ class SolveIfThenElse : public IRMutator {
}

Stmt visit(const LetStmt *op) override {
push_var(op->name);
Stmt stmt = IRMutator::visit(op);
pop_var(op->name);
return stmt;
Stmt orig = op;
vector<const LetStmt *> frames;
Stmt body;
do {
frames.push_back(op);
push_var(op->name);
body = op->body;
op = body.as<LetStmt>();
} while (op);

Stmt s = mutate(body);

if (s.same_as(body)) {
for (auto it = frames.rbegin(); it != frames.rend(); it++) {
pop_var((*it)->name);
}
return orig;
} else {
for (auto it = frames.rbegin(); it != frames.rend(); it++) {
pop_var((*it)->name);
s = LetStmt::make((*it)->name, (*it)->value, s);
}
return s;
}
}

Stmt visit(const For *op) override {
Expand Down Expand Up @@ -2903,6 +2923,41 @@ map<string, Box> boxes_touched(const Expr &e, Stmt s, bool consider_calls, bool
return op;
}

Stmt visit(const LetStmt *op) override {
// Walk eagerly through an entire let chain and either
// accept or reject all of them, not worrying about
// the case where some outer lets are relevant and
// some inner lets are not.
vector<const LetStmt *> frames;
Stmt orig = op;
Stmt body;
do {
// Visit the value just to check relevance. We
// don't expect Exprs to be mutated, so no need to
// keep the result.
mutate(op->value);
frames.push_back(op);
body = op->body;
op = body.as<LetStmt>();
} while (op);

Stmt s = mutate(body);

if (s.same_as(body)) {
return orig;
} else if (!relevant) {
// All the lets were irrelevant and so was the body
internal_assert(s.same_as(no_op));
return s;
} else {
// Rewrap the lets around the mutated body
for (auto it = frames.rbegin(); it != frames.rend(); it++) {
s = LetStmt::make((*it)->name, (*it)->value, s);
}
return s;
}
}

public:
Stmt mutate(const Stmt &s) override {
bool old = relevant;
Expand Down
5 changes: 3 additions & 2 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -413,12 +413,13 @@ struct Realize : public StmtNode<Realize> {
static const IRNodeType _node_type = IRNodeType::Realize;
};

/** A sequence of statements to be executed in-order. 'rest' may be
* undefined. Used rest.defined() to find out. */
/** A sequence of statements to be executed in-order. 'first' is never
a Block, so this can be treated as a linked list. */
struct Block : public StmtNode<Block> {
Stmt first, rest;

static Stmt make(Stmt first, Stmt rest);

/** Construct zero or more Blocks to invoke a list of statements in order.
* This method may not return a Block statement if stmts.size() <= 1. */
static Stmt make(const std::vector<Stmt> &stmts);
Expand Down
45 changes: 28 additions & 17 deletions src/SkipStages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,37 +78,48 @@ class PredicateFinder : public IRVisitor {
op->body.accept(this);
if (should_pop) {
varying.pop(op->name);
//internal_assert(!expr_uses_var(predicate, op->name));
// internal_assert(!expr_uses_var(predicate, op->name));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should either remove this line or add a comment about why we aren't doing this check.

} else if (expr_uses_var(predicate, op->name)) {
predicate = Let::make(op->name, op->min, predicate);
}
}

template<typename T>
void visit_let(const std::string &name, const Expr &value, T body) {
bool old_varies = varies;
varies = false;
value.accept(this);
bool value_varies = varies;
varies |= old_varies;
if (value_varies) {
varying.push(name);
}
void visit_let(const T *op) {
struct Frame {
const T *op;
ScopedBinding<> binding;
};
vector<Frame> frames;

decltype(op->body) body;
do {
bool old_varies = varies;
varies = false;
op->value.accept(this);

frames.push_back(Frame{op, ScopedBinding<>(varies, varying, op->name)});

varies |= old_varies;
body = op->body;
op = body.template as<T>();
} while (op);

body.accept(this);
if (value_varies) {
varying.pop(name);
}
if (expr_uses_var(predicate, name)) {
predicate = Let::make(name, value, predicate);

for (auto it = frames.rbegin(); it != frames.rend(); it++) {
if (expr_uses_var(predicate, it->op->name)) {
predicate = Let::make(it->op->name, it->op->value, predicate);
}
}
}

void visit(const LetStmt *op) override {
visit_let(op->name, op->value, op->body);
visit_let(op);
}

void visit(const Let *op) override {
visit_let(op->name, op->value, op->body);
visit_let(op);
}

void visit(const ProducerConsumer *op) override {
Expand Down
55 changes: 34 additions & 21 deletions src/Substitute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,45 @@ class Substitute : public IRMutator {
}
}

Expr visit(const Let *op) override {
Expr new_value = mutate(op->value);
hidden.push(op->name);
Expr new_body = mutate(op->body);
hidden.pop(op->name);

if (new_value.same_as(op->value) &&
new_body.same_as(op->body)) {
return op;
template<typename T>
auto visit_let(const T *op) -> decltype(op->body) {
decltype(op->body) orig = op;

struct Frame {
const T *op;
Expr new_value;
ScopedBinding<> bind;
};
std::vector<Frame> frames;
decltype(op->body) body;
bool values_unchanged = true;
do {
Expr new_value = mutate(op->value);
values_unchanged &= new_value.same_as(op->value);
frames.push_back(Frame{op, std::move(new_value), ScopedBinding<>(hidden, op->name)});
body = op->body;
op = body.template as<T>();
} while (op);

auto new_body = mutate(body);

if (values_unchanged &&
new_body.same_as(body)) {
return orig;
} else {
return Let::make(op->name, new_value, new_body);
for (auto it = frames.rbegin(); it != frames.rend(); it++) {
new_body = T::make(it->op->name, it->new_value, new_body);
}
return new_body;
}
}

Stmt visit(const LetStmt *op) override {
Expr new_value = mutate(op->value);
hidden.push(op->name);
Stmt new_body = mutate(op->body);
hidden.pop(op->name);
Expr visit(const Let *op) override {
return visit_let(op);
}

if (new_value.same_as(op->value) &&
new_body.same_as(op->body)) {
return op;
} else {
return LetStmt::make(op->name, new_value, new_body);
}
Stmt visit(const LetStmt *op) override {
return visit_let(op);
}

Stmt visit(const For *op) override {
Expand Down
25 changes: 15 additions & 10 deletions src/UniquifyVariableNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,25 @@ class FindFreeVars : public IRVisitor {
}
}

template<typename T>
void visit_let(const T *op) {
vector<ScopedBinding<>> frame;
decltype(op->body) body;
do {
op->value.accept(this);
frame.emplace_back(scope, op->name);
body = op->body;
op = body.template as<T>();
} while (op);
body.accept(this);
}

void visit(const Let *op) override {
op->value.accept(this);
{
ScopedBinding<> bind(scope, op->name);
op->body.accept(this);
}
visit_let(op);
}

void visit(const LetStmt *op) override {
op->value.accept(this);
{
ScopedBinding<> bind(scope, op->name);
op->body.accept(this);
}
visit_let(op);
}

void visit(const For *op) override {
Expand Down
7 changes: 6 additions & 1 deletion src/UnrollLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ class UnrollLoops : public IRMutator {
Stmt iters;
for (int i = e->value - 1; i >= 0; i--) {
Stmt iter = substitute(for_loop->name, for_loop->min + i, body);
// It's necessary to simplify eagerly this iteration
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: "eagerly simplify"

// here to resolve things like muxes down to a single
// item before we go and make N copies of something of
// size N.
iter = simplify(iter);
if (!iters.defined()) {
iters = iter;
} else {
Expand All @@ -93,7 +98,7 @@ class UnrollLoops : public IRMutator {
}
}

return simplify(iters);
return iters;

} else {
return IRMutator::visit(for_loop);
Expand Down
Loading