Rewrite the skip stages lowering pass (#8115)
* Avoid redundant scope lookups

This pattern has been bugging me for a long time:

```
if (scope.contains(key)) {
  Foo f = scope.get(key);
}
```

This redundantly looks up the key in the scope twice. I've finally
gotten around to fixing it. I've introduced a find method that either
returns a const pointer to the value, if it exists, or null. It also
searches any containing scopes, which are held by const pointer, so the
method has to return a const pointer.

```
if (const Foo *f = scope.find(key)) {
}
```

For cases where you want to get and then mutate, I added shallow_find,
which doesn't search enclosing scopes, but returns a mutable pointer.
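As a rough illustration of the find / shallow_find split described above, here is a minimal sketch. `SimpleScope`, `set_containing_scope`, and `lookup_or_default` are hypothetical names made up for this example, and the class is a simplified stand-in rather than the real Scope implementation (it keeps one value per name instead of a stack).

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>

// Hypothetical, simplified stand-in for a scope class (one value per name).
template<typename T>
class SimpleScope {
    std::map<std::string, T> table;
    // Enclosing scope, held by const pointer, so anything found in it
    // can only be handed out as const.
    const SimpleScope<T> *containing = nullptr;

public:
    void set_containing_scope(const SimpleScope<T> *s) {
        containing = s;
    }

    void push(const std::string &name, T value) {
        table[name] = std::move(value);
    }

    // Searches this scope, then any enclosing scopes.
    // Returns nullptr if the name is not bound anywhere.
    const T *find(const std::string &name) const {
        auto it = table.find(name);
        if (it != table.end()) {
            return &it->second;
        }
        return containing ? containing->find(name) : nullptr;
    }

    // Searches only this scope, so it can return a mutable pointer.
    T *shallow_find(const std::string &name) {
        auto it = table.find(name);
        return it == table.end() ? nullptr : &it->second;
    }
};

// One lookup instead of contains() followed by get().
inline int lookup_or_default(const SimpleScope<int> &scope,
                             const std::string &key, int fallback) {
    if (const int *v = scope.find(key)) {
        return *v;
    }
    return fallback;
}
```

The const return type on find falls out of the const pointer to the enclosing scope: the sketch could not return a mutable pointer from it without a cast.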

We were also doing redundant scope lookups in ScopedBinding. We stored
the key in the helper object, and then did a pop on that key in the
ScopedBinding destructor. This commit changes Scope so that Scope::push
returns an opaque token that you can pass to Scope::pop to have it
remove that element without doing a fresh lookup. ScopedBinding now uses
this. Under the hood it's just an iterator on the underlying map (map
iterators are not invalidated on inserting or removing other stuff).
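The token idea can be sketched roughly as follows. This is an assumption-laden toy, not the actual Scope or ScopedBinding code: `TokenScope` and `Binding` are hypothetical names, and a `std::vector` stands in for the per-name stack.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sketch of a scope whose push returns a token for pop.
template<typename T>
class TokenScope {
    using Table = std::map<std::string, std::vector<T>>;
    Table table;

public:
    // Opaque handle: just an iterator to the map entry that was pushed to.
    struct PushToken {
        typename Table::iterator iter;
    };

    PushToken push(const std::string &name, T value) {
        // emplace returns the existing entry if the name is already bound.
        auto it = table.emplace(name, std::vector<T>{}).first;
        it->second.push_back(std::move(value));
        return PushToken{it};
    }

    // Pops via the token, with no string lookup. Inserting or erasing
    // other entries never invalidates a std::map iterator.
    void pop(PushToken token) {
        token.iter->second.pop_back();
        if (token.iter->second.empty()) {
            table.erase(token.iter);
        }
    }

    const T *find(const std::string &name) const {
        auto it = table.find(name);
        return it == table.end() ? nullptr : &it->second.back();
    }

    std::size_t size() const {
        return table.size();
    }
};

// RAII helper in the spirit of ScopedBinding: stores the token, not the key.
template<typename T>
class Binding {
    TokenScope<T> &scope;
    typename TokenScope<T>::PushToken token;

public:
    Binding(TokenScope<T> &s, const std::string &name, T value)
        : scope(s), token(s.push(name, std::move(value))) {
    }
    ~Binding() {
        scope.pop(token);
    }
    Binding(const Binding &) = delete;
    Binding &operator=(const Binding &) = delete;
};
```

The sketch relies on bindings being released in the reverse order they were made, which is what nested RAII objects give you; a token is only erased from the map once its stack is empty, so no live token is left dangling.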

The net effect is to speed up local laplacian lowering by about 5%.

I also considered making it look more like an STL class, and having find
return an iterator, but it doesn't really work. The iterator it returns
might point to an entry in an enclosing scope, in which case you can't
compare it to the .end() method of the scope you have. Scopes are
different enough from maps that the interface really needs to be
distinct.

* Pacify clang-tidy

* Fix unintentional mutation of interval in scope

* Fix accidental Scope::get

* Rewrite the skip stages lowering pass

Skip stages was slow due to crappy computational complexity (quadratic?)

I reworked it into a two-pass linear-time algorithm. The first pass
records which pieces of IR are actually relevant to the task, and the
second pass performs the task using a bounds-inference-like algorithm.

On main resnet50 spends 519 ms in this pass. This commit reduces it to
40 ms. Local laplacian with 100 pyramid levels spends 7.4 seconds in
this pass. This commit reduces it to ~3 ms.

This commit also moves the cache store for memoized Funcs into the
produce node, instead of at the top of the consume node, because it
naturally places it inside a condition you inject into the produce node.

* clang-tidy fixes

* Fix skip stages interaction with compute_with

* Unify let visitors, and use fewer stack frames for them

* Fix accidental leakage of .used into .loaded

* Visit the bodies of uninteresting let chains

* Another used -> loaded

* Fix hoist_storage not handling condition correctly.

---------

Co-authored-by: Steven Johnson <srj@google.com>
abadams and steven-johnson committed Feb 27, 2024
1 parent 2b5beb3 commit 36d74a8
Showing 10 changed files with 721 additions and 409 deletions.
7 changes: 6 additions & 1 deletion src/BoundsInference.cpp
@@ -1383,9 +1383,14 @@ Stmt bounds_inference(Stmt s,
         fused_pairs_in_groups.push_back(pairs);
     }

+    // Add a note in the IR for where the outermost dynamic-stage skipping
+    // checks should go. These are injected in a later pass.
+    Expr marker = Call::make(Int(32), Call::skip_stages_marker, {}, Call::Intrinsic);
+    s = Block::make(Evaluate::make(marker), s);
+
     // Add a note in the IR for where assertions on input images
     // should go. Those are handled by a later lowering pass.
-    Expr marker = Call::make(Int(32), Call::add_image_checks_marker, {}, Call::Intrinsic);
+    marker = Call::make(Int(32), Call::add_image_checks_marker, {}, Call::Intrinsic);
     s = Block::make(Evaluate::make(marker), s);

     // Add a synthetic outermost loop to act as 'root'.
1 change: 1 addition & 0 deletions src/IR.cpp
@@ -674,6 +674,7 @@ const char *const intrinsic_op_names[] = {
     "shift_right",
     "signed_integer_overflow",
     "size_of_halide_buffer_t",
+    "skip_stages_marker",
     "sliding_window_marker",
     "sorted_avg",
     "strict_float",
4 changes: 4 additions & 0 deletions src/IR.h
@@ -594,6 +594,10 @@ struct Call : public ExprNode<Call> {
         signed_integer_overflow,
         size_of_halide_buffer_t,

+        // Marks the point in lowering where the outermost skip stages checks
+        // should be introduced.
+        skip_stages_marker,
+
         // Takes a realization name and a loop variable. Declares that values of
         // the realization that were stored on earlier loop iterations of the
         // given loop are potentially loaded in this loop iteration somewhere
2 changes: 1 addition & 1 deletion src/Lower.cpp
@@ -269,7 +269,7 @@ void lower_impl(const vector<Function> &output_funcs,
     log("Lowering after discarding safe promises:", s);

     debug(1) << "Dynamically skipping stages...\n";
-    s = skip_stages(s, order);
+    s = skip_stages(s, outputs, fused_groups, env);
     log("Lowering after dynamically skipping stages:", s);

     debug(1) << "Forking asynchronous producers...\n";
14 changes: 6 additions & 8 deletions src/Memoization.cpp
@@ -425,13 +425,10 @@ class InjectMemoization : public IRMutator {

             Stmt body = mutate(op->body);

-            std::string cache_miss_name = op->name + ".cache_miss";
-            Expr cache_miss = Variable::make(Bool(), cache_miss_name);
-
             if (op->is_producer) {
-                Stmt mutated_body = IfThenElse::make(cache_miss, body);
-                return ProducerConsumer::make(op->name, op->is_producer, mutated_body);
-            } else {
+                std::string cache_miss_name = op->name + ".cache_miss";
+                Expr cache_miss = Variable::make(Bool(), cache_miss_name);
+
                 const Function f(iter->second);
                 KeyInfo key_info(f, top_level_name, memoize_instance);

@@ -447,9 +444,10 @@ class InjectMemoization : public IRMutator {
                     key_info.store_computation(cache_key_name, computed_bounds_name,
                                                eviction_key_name, f.outputs(), op->name));

-                Stmt mutated_body = Block::make(cache_store_back, body);
-                return ProducerConsumer::make(op->name, op->is_producer, mutated_body);
+                body = Block::make(body, cache_store_back);
+                body = IfThenElse::make(cache_miss, body);
             }
+            return ProducerConsumer::make(op->name, op->is_producer, body);
         } else {
             return IRMutator::visit(op);
         }
5 changes: 5 additions & 0 deletions src/Scope.h
@@ -205,6 +205,11 @@ class Scope {
         }
     }

+    /** How many distinct names exist (does not count nested definitions of the same name) */
+    size_t size() const {
+        return table.size();
+    }
+
     struct PushToken {
         typename std::map<std::string, SmallStack<T>>::iterator iter;
     };