From bf0d61149dde511f39b950689c2a08af7078e88b Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 12 Mar 2024 09:49:26 -0700 Subject: [PATCH] Rewrite the pass that adds mutexes for atomic nodes (#8105) * Avoid redundant scope lookups This pattern has been bugging me for a long time: ``` if (scope.contains(key)) { Foo f = scope.get(key); } ``` This redundantly looks up the key in the scope twice. I've finally gotten around to fixing it. I've introduced a find method that either returns a const pointer to the value, if it exists, or null. It also searches any containing scopes, which are held by const pointer, so the method has to return a const pointer. ``` if (const Foo *f = scope.find(key)) { } ``` For cases where you want to get and then mutate, I added shallow_find, which doesn't search enclosing scopes, but returns a mutable pointer. We were also doing redundant scope lookups in ScopedBinding. We stored the key in the helper object, and then did a pop on that key in the ScopedBinding destructor. This commit changes Scope so that Scope::push returns an opaque token that you can pass to Scope::pop to have it remove that element without doing a fresh lookup. ScopedBinding now uses this. Under the hood it's just an iterator on the underlying map (map iterators are not invalidated on inserting or removing other stuff). The net effect is to speed up local laplacian lowering by about 5% I also considered making it look more like an stl class, and having find return an iterator, but it doesn't really work. The iterator it returns might point to an entry in an enclosing scope, in which case you can't compare it to the .end() method of the scope you have. Scopes are different enough from maps that the interface really needs to be distinct. * Pacify clang-tidy * Rewrite the pass that injects mutexes to support atomics For O(n) nested allocate nodes, this pass was quadratic in n, even if there was no use of atomics. This commit rewrites it to use a linear-time algorithm, and skips it entirely after the first validation pass if there aren't any atomic nodes. It also needlessly used IRGraphMutators, which slowed things down, didn't handle LargeBuffers (could overflow in the allocation), incorrectly thought every producer/consumer node was associated with an output buffer, and didn't print the realization name when printing the atomic node (the body of an atomic node is only atomic w.r.t. a specific realization). I noticed all this because it stuck out in a profile. For resnet 50, the rewrite that changed to a linear algorithm took this stage from 185ms down to 6.7ms, and then skipping it entirely when it doesn't find any atomic nodes added 1.5 for the single IRVisitor check. For local laplacian with 100 pyramid levels (which contains many nested allocate nodes due to a large number of skip connections), the times are 5846 ms -> 16 ms -> 4.6 ms This is built on top of #8103 * Fix unintentional mutation of interval in scope --------- Co-authored-by: Steven Johnson --- src/AddAtomicMutex.cpp | 216 ++++++++++++++------------- src/AddAtomicMutex.h | 2 +- src/IRPrinter.cpp | 9 +- src/Lower.cpp | 2 +- src/runtime/HalideRuntime.h | 2 +- src/runtime/fake_thread_pool.cpp | 2 +- src/runtime/synchronization_common.h | 2 +- 7 files changed, 119 insertions(+), 116 deletions(-) diff --git a/src/AddAtomicMutex.cpp b/src/AddAtomicMutex.cpp index a2bf990e38f6..cf3b0ae8bb89 100644 --- a/src/AddAtomicMutex.cpp +++ b/src/AddAtomicMutex.cpp @@ -1,5 +1,4 @@ #include "AddAtomicMutex.h" - #include "ExprUsesVar.h" #include "Func.h" #include "IREquality.h" @@ -11,14 +10,10 @@ namespace Halide { namespace Internal { -using std::map; -using std::set; -using std::string; - namespace { /** Collect names of all stores matching the producer name inside a statement. */ -class CollectProducerStoreNames : public IRGraphVisitor { +class CollectProducerStoreNames : public IRVisitor { public: CollectProducerStoreNames(const std::string &producer_name) : producer_name(producer_name) { @@ -27,12 +22,12 @@ class CollectProducerStoreNames : public IRGraphVisitor { Scope store_names; protected: - using IRGraphVisitor::visit; + using IRVisitor::visit; void visit(const Store *op) override { - IRGraphVisitor::visit(op); + IRVisitor::visit(op); if (op->name == producer_name || starts_with(op->name, producer_name + ".")) { - // This is a Store for the desginated Producer. + // This is a Store for the designated Producer. store_names.push(op->name); } } @@ -42,7 +37,7 @@ class CollectProducerStoreNames : public IRGraphVisitor { /** Find Store inside of an Atomic node for the designated producer * and return their indices. */ -class FindProducerStoreIndex : public IRGraphVisitor { +class FindProducerStoreIndex : public IRVisitor { public: FindProducerStoreIndex(const std::string &producer_name) : producer_name(producer_name) { @@ -51,11 +46,11 @@ class FindProducerStoreIndex : public IRGraphVisitor { Expr index; // The returned index. protected: - using IRGraphVisitor::visit; + using IRVisitor::visit; // Need to also extract the let bindings of a Store index. void visit(const Let *op) override { - IRGraphVisitor::visit(op); // Make sure we visit the Store first. + IRVisitor::visit(op); // Make sure we visit the Store first. if (index.defined()) { if (expr_uses_var(index, op->name)) { index = Let::make(op->name, op->value, index); @@ -63,7 +58,7 @@ class FindProducerStoreIndex : public IRGraphVisitor { } } void visit(const LetStmt *op) override { - IRGraphVisitor::visit(op); // Make sure we visit the Store first. + IRVisitor::visit(op); // Make sure we visit the Store first. if (index.defined()) { if (expr_uses_var(index, op->name)) { index = Let::make(op->name, op->value, index); @@ -72,7 +67,7 @@ class FindProducerStoreIndex : public IRGraphVisitor { } void visit(const Store *op) override { - IRGraphVisitor::visit(op); + IRVisitor::visit(op); if (op->name == producer_name || starts_with(op->name, producer_name + ".")) { // This is a Store for the designated producer. @@ -94,11 +89,13 @@ class FindProducerStoreIndex : public IRGraphVisitor { /** Throws an assertion for cases where the indexing on left-hand-side of * an atomic update references to itself. * e.g. f(clamp(f(r), 0, 100)) = f(r) + 1 should be rejected. */ -class CheckAtomicValidity : public IRGraphVisitor { +class CheckAtomicValidity : public IRVisitor { protected: - using IRGraphVisitor::visit; + using IRVisitor::visit; void visit(const Atomic *op) override { + any_atomic = true; + // Collect the names of all Store nodes inside. CollectProducerStoreNames collector(op->producer_name); op->body.accept(&collector); @@ -115,13 +112,16 @@ class CheckAtomicValidity : public IRGraphVisitor { } op->body.accept(this); } + +public: + bool any_atomic = false; }; /** Search if the value of a Store node has a variable pointing to a let binding, * where the let binding contains the Store location. Use for checking whether * we need a mutex lock for Atomic since some lowering pass before lifted a let * binding from the Store node (currently only SplitTuple would do this). */ -class FindAtomicLetBindings : public IRGraphVisitor { +class FindAtomicLetBindings : public IRVisitor { public: FindAtomicLetBindings(const Scope &store_names) : store_names(store_names) { @@ -133,18 +133,18 @@ class FindAtomicLetBindings : public IRGraphVisitor { using IRVisitor::visit; void visit(const Let *op) override { - include(op->value); + op->value.accept(this); { ScopedBinding bind(let_bindings, op->name, op->value); - include(op->body); + op->body.accept(this); } } void visit(const LetStmt *op) override { - include(op->value); + op->value.accept(this); { ScopedBinding bind(let_bindings, op->name, op->value); - include(op->body); + op->body.accept(this); } } @@ -159,19 +159,19 @@ class FindAtomicLetBindings : public IRGraphVisitor { } void visit(const Store *op) override { - include(op->predicate); + op->predicate.accept(this); + op->index.accept(this); if (store_names.contains(op->name)) { // If we are in a designated store and op->value has a let binding // that uses one of the store_names, we found a lifted let. - ScopedValue old_inside_store(inside_store, op->name); - include(op->value); + ScopedValue old_inside_store(inside_store, op->name); + op->value.accept(this); } else { - include(op->value); + op->value.accept(this); } - include(op->index); } - string inside_store; + std::string inside_store; const Scope &store_names; Scope let_bindings; }; @@ -179,7 +179,7 @@ class FindAtomicLetBindings : public IRGraphVisitor { /** Clear out the Atomic node's mutex usages if it doesn't need one. */ class RemoveUnnecessaryMutexUse : public IRMutator { public: - set remove_mutex_lock_names; + std::set remove_mutex_lock_names; protected: using IRMutator::visit; @@ -200,30 +200,30 @@ class RemoveUnnecessaryMutexUse : public IRMutator { remove_mutex_lock_names.insert(op->mutex_name); Stmt body = mutate(op->body); return Atomic::make(op->producer_name, - string(), + std::string{}, std::move(body)); } } }; /** Find Store inside an Atomic that matches the provided store_names. */ -class FindStoreInAtomicMutex : public IRGraphVisitor { +class FindStoreInAtomicMutex : public IRVisitor { public: - using IRGraphVisitor::visit; + using IRVisitor::visit; FindStoreInAtomicMutex(const std::set &store_names) : store_names(store_names) { } bool found = false; - string producer_name; - string mutex_name; + std::string producer_name; + std::string mutex_name; protected: void visit(const Atomic *op) override { if (!found && !op->mutex_name.empty()) { ScopedValue old_in_atomic_mutex(in_atomic_mutex, true); - include(op->body); + op->body.accept(this); if (found) { // We found a Store inside Atomic with matching name, // record the mutex information. @@ -231,7 +231,7 @@ class FindStoreInAtomicMutex : public IRGraphVisitor { mutex_name = op->mutex_name; } } else { - include(op->body); + op->body.accept(this); } } @@ -241,11 +241,11 @@ class FindStoreInAtomicMutex : public IRGraphVisitor { found = true; } } - IRGraphVisitor::visit(op); + IRVisitor::visit(op); } bool in_atomic_mutex = false; - const set &store_names; + const std::set &store_names; }; /** Replace the indices in the Store nodes with the specified variable. */ @@ -276,26 +276,32 @@ class ReplaceStoreIndexWithVar : public IRMutator { /** Add mutex allocation & lock & unlock if required. */ class AddAtomicMutex : public IRMutator { public: - AddAtomicMutex(const map &env) - : env(env) { + AddAtomicMutex(const std::vector &o) { + for (const Function &f : o) { + outputs.emplace(f.name(), f); + } } protected: using IRMutator::visit; - const map &env; - // The set of producers that have allocated a mutex buffer - set allocated_mutexes; + // Maps from a producer name to a mutex name, for all encountered atomic + // nodes. + Scope needs_mutex_allocation; - Stmt allocate_mutex(const string &mutex_name, Expr extent, Stmt body) { + // Pipeline outputs + std::map outputs; + + Stmt allocate_mutex(const std::string &mutex_name, Expr extent, Stmt body) { Expr mutex_array = Call::make(type_of(), "halide_mutex_array_create", {std::move(extent)}, Call::Extern); + // Allocate a scalar of halide_mutex_array. // This generates halide_mutex_array mutex[1]; body = Allocate::make(mutex_name, - Handle(), + type_of(), MemoryType::Stack, {}, const_true(), @@ -309,37 +315,44 @@ class AddAtomicMutex : public IRMutator { // If this Allocate node is allocating a buffer for a producer, // and there is a Store node inside of an Atomic node requiring mutex lock // matching the name of the Allocate, allocate a mutex lock. - set store_names{op->name}; - FindStoreInAtomicMutex finder(store_names); - op->body.accept(&finder); - if (!finder.found) { - // No Atomic node that requires mutex lock from this node inside. - return IRMutator::visit(op); - } - if (allocated_mutexes.find(finder.mutex_name) != allocated_mutexes.end()) { - // We've already allocated a mutex. - return IRMutator::visit(op); + Stmt body = mutate(op->body); + + std::string producer_name; + if (ends_with(op->name, ".0")) { + producer_name = op->name.substr(0, op->name.size() - 2); + } else { + producer_name = op->name; } - allocated_mutexes.insert(finder.mutex_name); + if (const std::string *mutex_name = needs_mutex_allocation.find(producer_name)) { + Expr extent = cast(1); // uint64_t to handle LargeBuffers + for (const Expr &e : op->extents) { + extent = extent * e; + } - const string &mutex_name = finder.mutex_name; - Stmt body = mutate(op->body); - Expr extent = Expr(1); - for (const Expr &e : op->extents) { - extent = extent * e; + body = allocate_mutex(*mutex_name, extent, body); + + // At this stage in lowering it should be impossible to have an + // allocation that shadows the name of an outer allocation, but may as + // well handle it anyway by using a scope and popping at each allocate + // node. + needs_mutex_allocation.pop(producer_name); + } + + if (body.same_as(op->body)) { + return op; + } else { + return Allocate::make(op->name, + op->type, + op->memory_type, + op->extents, + op->condition, + std::move(body), + op->new_expr, + op->free_function, + op->padding); } - body = allocate_mutex(mutex_name, extent, body); - return Allocate::make(op->name, - op->type, - op->memory_type, - op->extents, - op->condition, - std::move(body), - op->new_expr, - op->free_function, - op->padding); } Stmt visit(const ProducerConsumer *op) override { @@ -348,50 +361,35 @@ class AddAtomicMutex : public IRMutator { // buffer at the producer node. if (!op->is_producer) { - // This is a consumer. + // This is a consumer return IRMutator::visit(op); } - // Find the corresponding output. - auto func_it = env.find(op->name); - if (func_it == env.end()) { - // Not an output. - return IRMutator::visit(op); - } - Func f = Func(func_it->second); - if (f.output_buffers().empty()) { - // Not an output. + auto it = outputs.find(op->name); + if (it == outputs.end()) { + // Not an output return IRMutator::visit(op); } - set store_names; - for (const auto &buffer : f.output_buffers()) { - store_names.insert(buffer.name()); - } + Function f = it->second; - FindStoreInAtomicMutex finder(store_names); - op->body.accept(&finder); - if (!finder.found) { - // No Atomic node that requires mutex lock from this node inside. - return IRMutator::visit(op); - } + Stmt body = mutate(op->body); - if (allocated_mutexes.find(finder.mutex_name) != allocated_mutexes.end()) { - // We've already allocated a mutex. - return IRMutator::visit(op); + if (const std::string *mutex_name = needs_mutex_allocation.find(it->first)) { + // All output buffers in a Tuple have the same extent. + OutputImageParam output_buffer = Func(f).output_buffers()[0]; + Expr extent = cast(1); // uint64_t to handle LargeBuffers + for (int i = 0; i < output_buffer.dimensions(); i++) { + extent *= output_buffer.dim(i).extent(); + } + body = allocate_mutex(*mutex_name, extent, body); } - allocated_mutexes.insert(finder.mutex_name); - - // We assume all output buffers in a Tuple have the same extent. - OutputImageParam output_buffer = f.output_buffers()[0]; - Expr extent = Expr(1); - for (int i = 0; i < output_buffer.dimensions(); i++) { - extent = extent * output_buffer.dim(i).extent(); + if (body.same_as(op->body)) { + return op; + } else { + return ProducerConsumer::make(op->name, op->is_producer, std::move(body)); } - Stmt body = mutate(op->body); - body = allocate_mutex(finder.mutex_name, extent, body); - return ProducerConsumer::make(op->name, op->is_producer, std::move(body)); } Stmt visit(const Atomic *op) override { @@ -414,7 +412,7 @@ class AddAtomicMutex : public IRMutator { // Lift the index outside of the atomic node. // This is for avoiding side-effects inside those expressions // being evaluated twice. - string name = unique_name('t'); + std::string name = unique_name('t'); index_let = index; index = Variable::make(index.type(), name); body = ReplaceStoreIndexWithVar(op->producer_name, index).mutate(body); @@ -444,17 +442,21 @@ class AddAtomicMutex : public IRMutator { internal_assert(index.as() != nullptr); ret = LetStmt::make(index.as()->name, index_let, ret); } + needs_mutex_allocation.push(op->producer_name, op->mutex_name); + return ret; } }; } // namespace -Stmt add_atomic_mutex(Stmt s, const map &env) { +Stmt add_atomic_mutex(Stmt s, const std::vector &outputs) { CheckAtomicValidity check; s.accept(&check); - s = RemoveUnnecessaryMutexUse().mutate(s); - s = AddAtomicMutex(env).mutate(s); + if (check.any_atomic) { + s = RemoveUnnecessaryMutexUse().mutate(s); + s = AddAtomicMutex(outputs).mutate(s); + } return s; } diff --git a/src/AddAtomicMutex.h b/src/AddAtomicMutex.h index c27b0346f349..5b11de621e97 100644 --- a/src/AddAtomicMutex.h +++ b/src/AddAtomicMutex.h @@ -23,7 +23,7 @@ namespace Internal { class Function; -Stmt add_atomic_mutex(Stmt s, const std::map &env); +Stmt add_atomic_mutex(Stmt s, const std::vector &outputs); } // namespace Internal } // namespace Halide diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index bc03dd124d9a..a186be1874d7 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -1112,11 +1112,12 @@ void IRPrinter::visit(const VectorReduce *op) { void IRPrinter::visit(const Atomic *op) { if (op->mutex_name.empty()) { - stream << get_indent() << "atomic {\n"; + stream << get_indent() << "atomic (" + << op->producer_name << ") {\n"; } else { - stream << get_indent() << "atomic ("; - stream << op->mutex_name; - stream << ") {\n"; + stream << get_indent() << "atomic (" + << op->producer_name << ", " + << op->mutex_name << ") {\n"; } indent += 2; print(op->body); diff --git a/src/Lower.cpp b/src/Lower.cpp index 3b357eb3061e..e39d55a65b9f 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -299,7 +299,7 @@ void lower_impl(const vector &output_funcs, log("Lowering after storage flattening:", s); debug(1) << "Adding atomic mutex allocation...\n"; - s = add_atomic_mutex(s, env); + s = add_atomic_mutex(s, outputs); log("Lowering after adding atomic mutex allocation:", s); debug(1) << "Unpacking buffer arguments...\n"; diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 62fbaeb66d43..1a19202745bb 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -195,7 +195,7 @@ extern void halide_cond_wait(struct halide_cond *cond, struct halide_mutex *mute /** Functions for constructing/destroying/locking/unlocking arrays of mutexes. */ struct halide_mutex_array; //@{ -extern struct halide_mutex_array *halide_mutex_array_create(int sz); +extern struct halide_mutex_array *halide_mutex_array_create(uint64_t sz); extern void halide_mutex_array_destroy(void *user_context, void *array); extern int halide_mutex_array_lock(struct halide_mutex_array *array, int entry); extern int halide_mutex_array_unlock(struct halide_mutex_array *array, int entry); diff --git a/src/runtime/fake_thread_pool.cpp b/src/runtime/fake_thread_pool.cpp index 9c3cfddc5a47..531a16d1312e 100644 --- a/src/runtime/fake_thread_pool.cpp +++ b/src/runtime/fake_thread_pool.cpp @@ -96,7 +96,7 @@ WEAK void halide_mutex_unlock(halide_mutex *mutex) { // (e.g. correctness/multiple_scatter). Since we don't have threads, we don't // need to mutex to do anything, but returning a null would trigger an error // condition that would be misrepoted as out-of-memory. -WEAK halide_mutex_array *halide_mutex_array_create(int sz) { +WEAK halide_mutex_array *halide_mutex_array_create(uint64_t sz) { return &halide_fake_mutex_array; } diff --git a/src/runtime/synchronization_common.h b/src/runtime/synchronization_common.h index cb244f360eeb..778c423e4046 100644 --- a/src/runtime/synchronization_common.h +++ b/src/runtime/synchronization_common.h @@ -908,7 +908,7 @@ struct halide_mutex_array { struct halide_mutex *array; }; -WEAK halide_mutex_array *halide_mutex_array_create(int sz) { +WEAK halide_mutex_array *halide_mutex_array_create(uint64_t sz) { // TODO: If sz is huge, we should probably hash it down to something smaller // in the accessors below. Check for deadlocks before doing so. halide_mutex_array *array = (halide_mutex_array *)halide_malloc(