diff --git a/src/AddAtomicMutex.cpp b/src/AddAtomicMutex.cpp index a2bf990e38f6..cf3b0ae8bb89 100644 --- a/src/AddAtomicMutex.cpp +++ b/src/AddAtomicMutex.cpp @@ -1,5 +1,4 @@ #include "AddAtomicMutex.h" - #include "ExprUsesVar.h" #include "Func.h" #include "IREquality.h" @@ -11,14 +10,10 @@ namespace Halide { namespace Internal { -using std::map; -using std::set; -using std::string; - namespace { /** Collect names of all stores matching the producer name inside a statement. */ -class CollectProducerStoreNames : public IRGraphVisitor { +class CollectProducerStoreNames : public IRVisitor { public: CollectProducerStoreNames(const std::string &producer_name) : producer_name(producer_name) { @@ -27,12 +22,12 @@ class CollectProducerStoreNames : public IRGraphVisitor { Scope store_names; protected: - using IRGraphVisitor::visit; + using IRVisitor::visit; void visit(const Store *op) override { - IRGraphVisitor::visit(op); + IRVisitor::visit(op); if (op->name == producer_name || starts_with(op->name, producer_name + ".")) { - // This is a Store for the desginated Producer. + // This is a Store for the designated Producer. store_names.push(op->name); } } @@ -42,7 +37,7 @@ class CollectProducerStoreNames : public IRGraphVisitor { /** Find Store inside of an Atomic node for the designated producer * and return their indices. */ -class FindProducerStoreIndex : public IRGraphVisitor { +class FindProducerStoreIndex : public IRVisitor { public: FindProducerStoreIndex(const std::string &producer_name) : producer_name(producer_name) { @@ -51,11 +46,11 @@ class FindProducerStoreIndex : public IRGraphVisitor { Expr index; // The returned index. protected: - using IRGraphVisitor::visit; + using IRVisitor::visit; // Need to also extract the let bindings of a Store index. void visit(const Let *op) override { - IRGraphVisitor::visit(op); // Make sure we visit the Store first. + IRVisitor::visit(op); // Make sure we visit the Store first. if (index.defined()) { if (expr_uses_var(index, op->name)) { index = Let::make(op->name, op->value, index); @@ -63,7 +58,7 @@ class FindProducerStoreIndex : public IRGraphVisitor { } } void visit(const LetStmt *op) override { - IRGraphVisitor::visit(op); // Make sure we visit the Store first. + IRVisitor::visit(op); // Make sure we visit the Store first. if (index.defined()) { if (expr_uses_var(index, op->name)) { index = Let::make(op->name, op->value, index); @@ -72,7 +67,7 @@ class FindProducerStoreIndex : public IRGraphVisitor { } void visit(const Store *op) override { - IRGraphVisitor::visit(op); + IRVisitor::visit(op); if (op->name == producer_name || starts_with(op->name, producer_name + ".")) { // This is a Store for the designated producer. @@ -94,11 +89,13 @@ class FindProducerStoreIndex : public IRGraphVisitor { /** Throws an assertion for cases where the indexing on left-hand-side of * an atomic update references to itself. * e.g. f(clamp(f(r), 0, 100)) = f(r) + 1 should be rejected. */ -class CheckAtomicValidity : public IRGraphVisitor { +class CheckAtomicValidity : public IRVisitor { protected: - using IRGraphVisitor::visit; + using IRVisitor::visit; void visit(const Atomic *op) override { + any_atomic = true; + // Collect the names of all Store nodes inside. CollectProducerStoreNames collector(op->producer_name); op->body.accept(&collector); @@ -115,13 +112,16 @@ class CheckAtomicValidity : public IRGraphVisitor { } op->body.accept(this); } + +public: + bool any_atomic = false; }; /** Search if the value of a Store node has a variable pointing to a let binding, * where the let binding contains the Store location. Use for checking whether * we need a mutex lock for Atomic since some lowering pass before lifted a let * binding from the Store node (currently only SplitTuple would do this). */ -class FindAtomicLetBindings : public IRGraphVisitor { +class FindAtomicLetBindings : public IRVisitor { public: FindAtomicLetBindings(const Scope &store_names) : store_names(store_names) { @@ -133,18 +133,18 @@ class FindAtomicLetBindings : public IRGraphVisitor { using IRVisitor::visit; void visit(const Let *op) override { - include(op->value); + op->value.accept(this); { ScopedBinding bind(let_bindings, op->name, op->value); - include(op->body); + op->body.accept(this); } } void visit(const LetStmt *op) override { - include(op->value); + op->value.accept(this); { ScopedBinding bind(let_bindings, op->name, op->value); - include(op->body); + op->body.accept(this); } } @@ -159,19 +159,19 @@ class FindAtomicLetBindings : public IRGraphVisitor { } void visit(const Store *op) override { - include(op->predicate); + op->predicate.accept(this); + op->index.accept(this); if (store_names.contains(op->name)) { // If we are in a designated store and op->value has a let binding // that uses one of the store_names, we found a lifted let. - ScopedValue old_inside_store(inside_store, op->name); - include(op->value); + ScopedValue old_inside_store(inside_store, op->name); + op->value.accept(this); } else { - include(op->value); + op->value.accept(this); } - include(op->index); } - string inside_store; + std::string inside_store; const Scope &store_names; Scope let_bindings; }; @@ -179,7 +179,7 @@ class FindAtomicLetBindings : public IRGraphVisitor { /** Clear out the Atomic node's mutex usages if it doesn't need one. */ class RemoveUnnecessaryMutexUse : public IRMutator { public: - set remove_mutex_lock_names; + std::set remove_mutex_lock_names; protected: using IRMutator::visit; @@ -200,30 +200,30 @@ class RemoveUnnecessaryMutexUse : public IRMutator { remove_mutex_lock_names.insert(op->mutex_name); Stmt body = mutate(op->body); return Atomic::make(op->producer_name, - string(), + std::string{}, std::move(body)); } } }; /** Find Store inside an Atomic that matches the provided store_names. */ -class FindStoreInAtomicMutex : public IRGraphVisitor { +class FindStoreInAtomicMutex : public IRVisitor { public: - using IRGraphVisitor::visit; + using IRVisitor::visit; FindStoreInAtomicMutex(const std::set &store_names) : store_names(store_names) { } bool found = false; - string producer_name; - string mutex_name; + std::string producer_name; + std::string mutex_name; protected: void visit(const Atomic *op) override { if (!found && !op->mutex_name.empty()) { ScopedValue old_in_atomic_mutex(in_atomic_mutex, true); - include(op->body); + op->body.accept(this); if (found) { // We found a Store inside Atomic with matching name, // record the mutex information. @@ -231,7 +231,7 @@ class FindStoreInAtomicMutex : public IRGraphVisitor { mutex_name = op->mutex_name; } } else { - include(op->body); + op->body.accept(this); } } @@ -241,11 +241,11 @@ class FindStoreInAtomicMutex : public IRGraphVisitor { found = true; } } - IRGraphVisitor::visit(op); + IRVisitor::visit(op); } bool in_atomic_mutex = false; - const set &store_names; + const std::set &store_names; }; /** Replace the indices in the Store nodes with the specified variable. */ @@ -276,26 +276,32 @@ class ReplaceStoreIndexWithVar : public IRMutator { /** Add mutex allocation & lock & unlock if required. */ class AddAtomicMutex : public IRMutator { public: - AddAtomicMutex(const map &env) - : env(env) { + AddAtomicMutex(const std::vector &o) { + for (const Function &f : o) { + outputs.emplace(f.name(), f); + } } protected: using IRMutator::visit; - const map &env; - // The set of producers that have allocated a mutex buffer - set allocated_mutexes; + // Maps from a producer name to a mutex name, for all encountered atomic + // nodes. + Scope needs_mutex_allocation; - Stmt allocate_mutex(const string &mutex_name, Expr extent, Stmt body) { + // Pipeline outputs + std::map outputs; + + Stmt allocate_mutex(const std::string &mutex_name, Expr extent, Stmt body) { Expr mutex_array = Call::make(type_of(), "halide_mutex_array_create", {std::move(extent)}, Call::Extern); + // Allocate a scalar of halide_mutex_array. // This generates halide_mutex_array mutex[1]; body = Allocate::make(mutex_name, - Handle(), + type_of(), MemoryType::Stack, {}, const_true(), @@ -309,37 +315,44 @@ class AddAtomicMutex : public IRMutator { // If this Allocate node is allocating a buffer for a producer, // and there is a Store node inside of an Atomic node requiring mutex lock // matching the name of the Allocate, allocate a mutex lock. - set store_names{op->name}; - FindStoreInAtomicMutex finder(store_names); - op->body.accept(&finder); - if (!finder.found) { - // No Atomic node that requires mutex lock from this node inside. - return IRMutator::visit(op); - } - if (allocated_mutexes.find(finder.mutex_name) != allocated_mutexes.end()) { - // We've already allocated a mutex. - return IRMutator::visit(op); + Stmt body = mutate(op->body); + + std::string producer_name; + if (ends_with(op->name, ".0")) { + producer_name = op->name.substr(0, op->name.size() - 2); + } else { + producer_name = op->name; } - allocated_mutexes.insert(finder.mutex_name); + if (const std::string *mutex_name = needs_mutex_allocation.find(producer_name)) { + Expr extent = cast(1); // uint64_t to handle LargeBuffers + for (const Expr &e : op->extents) { + extent = extent * e; + } - const string &mutex_name = finder.mutex_name; - Stmt body = mutate(op->body); - Expr extent = Expr(1); - for (const Expr &e : op->extents) { - extent = extent * e; + body = allocate_mutex(*mutex_name, extent, body); + + // At this stage in lowering it should be impossible to have an + // allocation that shadows the name of an outer allocation, but may as + // well handle it anyway by using a scope and popping at each allocate + // node. + needs_mutex_allocation.pop(producer_name); + } + + if (body.same_as(op->body)) { + return op; + } else { + return Allocate::make(op->name, + op->type, + op->memory_type, + op->extents, + op->condition, + std::move(body), + op->new_expr, + op->free_function, + op->padding); } - body = allocate_mutex(mutex_name, extent, body); - return Allocate::make(op->name, - op->type, - op->memory_type, - op->extents, - op->condition, - std::move(body), - op->new_expr, - op->free_function, - op->padding); } Stmt visit(const ProducerConsumer *op) override { @@ -348,50 +361,35 @@ class AddAtomicMutex : public IRMutator { // buffer at the producer node. if (!op->is_producer) { - // This is a consumer. + // This is a consumer return IRMutator::visit(op); } - // Find the corresponding output. - auto func_it = env.find(op->name); - if (func_it == env.end()) { - // Not an output. - return IRMutator::visit(op); - } - Func f = Func(func_it->second); - if (f.output_buffers().empty()) { - // Not an output. + auto it = outputs.find(op->name); + if (it == outputs.end()) { + // Not an output return IRMutator::visit(op); } - set store_names; - for (const auto &buffer : f.output_buffers()) { - store_names.insert(buffer.name()); - } + Function f = it->second; - FindStoreInAtomicMutex finder(store_names); - op->body.accept(&finder); - if (!finder.found) { - // No Atomic node that requires mutex lock from this node inside. - return IRMutator::visit(op); - } + Stmt body = mutate(op->body); - if (allocated_mutexes.find(finder.mutex_name) != allocated_mutexes.end()) { - // We've already allocated a mutex. - return IRMutator::visit(op); + if (const std::string *mutex_name = needs_mutex_allocation.find(it->first)) { + // All output buffers in a Tuple have the same extent. + OutputImageParam output_buffer = Func(f).output_buffers()[0]; + Expr extent = cast(1); // uint64_t to handle LargeBuffers + for (int i = 0; i < output_buffer.dimensions(); i++) { + extent *= output_buffer.dim(i).extent(); + } + body = allocate_mutex(*mutex_name, extent, body); } - allocated_mutexes.insert(finder.mutex_name); - - // We assume all output buffers in a Tuple have the same extent. - OutputImageParam output_buffer = f.output_buffers()[0]; - Expr extent = Expr(1); - for (int i = 0; i < output_buffer.dimensions(); i++) { - extent = extent * output_buffer.dim(i).extent(); + if (body.same_as(op->body)) { + return op; + } else { + return ProducerConsumer::make(op->name, op->is_producer, std::move(body)); } - Stmt body = mutate(op->body); - body = allocate_mutex(finder.mutex_name, extent, body); - return ProducerConsumer::make(op->name, op->is_producer, std::move(body)); } Stmt visit(const Atomic *op) override { @@ -414,7 +412,7 @@ class AddAtomicMutex : public IRMutator { // Lift the index outside of the atomic node. // This is for avoiding side-effects inside those expressions // being evaluated twice. - string name = unique_name('t'); + std::string name = unique_name('t'); index_let = index; index = Variable::make(index.type(), name); body = ReplaceStoreIndexWithVar(op->producer_name, index).mutate(body); @@ -444,17 +442,21 @@ class AddAtomicMutex : public IRMutator { internal_assert(index.as() != nullptr); ret = LetStmt::make(index.as()->name, index_let, ret); } + needs_mutex_allocation.push(op->producer_name, op->mutex_name); + return ret; } }; } // namespace -Stmt add_atomic_mutex(Stmt s, const map &env) { +Stmt add_atomic_mutex(Stmt s, const std::vector &outputs) { CheckAtomicValidity check; s.accept(&check); - s = RemoveUnnecessaryMutexUse().mutate(s); - s = AddAtomicMutex(env).mutate(s); + if (check.any_atomic) { + s = RemoveUnnecessaryMutexUse().mutate(s); + s = AddAtomicMutex(outputs).mutate(s); + } return s; } diff --git a/src/AddAtomicMutex.h b/src/AddAtomicMutex.h index c27b0346f349..5b11de621e97 100644 --- a/src/AddAtomicMutex.h +++ b/src/AddAtomicMutex.h @@ -23,7 +23,7 @@ namespace Internal { class Function; -Stmt add_atomic_mutex(Stmt s, const std::map &env); +Stmt add_atomic_mutex(Stmt s, const std::vector &outputs); } // namespace Internal } // namespace Halide diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index bc03dd124d9a..a186be1874d7 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -1112,11 +1112,12 @@ void IRPrinter::visit(const VectorReduce *op) { void IRPrinter::visit(const Atomic *op) { if (op->mutex_name.empty()) { - stream << get_indent() << "atomic {\n"; + stream << get_indent() << "atomic (" + << op->producer_name << ") {\n"; } else { - stream << get_indent() << "atomic ("; - stream << op->mutex_name; - stream << ") {\n"; + stream << get_indent() << "atomic (" + << op->producer_name << ", " + << op->mutex_name << ") {\n"; } indent += 2; print(op->body); diff --git a/src/Lower.cpp b/src/Lower.cpp index 3b357eb3061e..e39d55a65b9f 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -299,7 +299,7 @@ void lower_impl(const vector &output_funcs, log("Lowering after storage flattening:", s); debug(1) << "Adding atomic mutex allocation...\n"; - s = add_atomic_mutex(s, env); + s = add_atomic_mutex(s, outputs); log("Lowering after adding atomic mutex allocation:", s); debug(1) << "Unpacking buffer arguments...\n"; diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 62fbaeb66d43..1a19202745bb 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -195,7 +195,7 @@ extern void halide_cond_wait(struct halide_cond *cond, struct halide_mutex *mute /** Functions for constructing/destroying/locking/unlocking arrays of mutexes. */ struct halide_mutex_array; //@{ -extern struct halide_mutex_array *halide_mutex_array_create(int sz); +extern struct halide_mutex_array *halide_mutex_array_create(uint64_t sz); extern void halide_mutex_array_destroy(void *user_context, void *array); extern int halide_mutex_array_lock(struct halide_mutex_array *array, int entry); extern int halide_mutex_array_unlock(struct halide_mutex_array *array, int entry); diff --git a/src/runtime/fake_thread_pool.cpp b/src/runtime/fake_thread_pool.cpp index 9c3cfddc5a47..531a16d1312e 100644 --- a/src/runtime/fake_thread_pool.cpp +++ b/src/runtime/fake_thread_pool.cpp @@ -96,7 +96,7 @@ WEAK void halide_mutex_unlock(halide_mutex *mutex) { // (e.g. correctness/multiple_scatter). Since we don't have threads, we don't // need to mutex to do anything, but returning a null would trigger an error // condition that would be misrepoted as out-of-memory. -WEAK halide_mutex_array *halide_mutex_array_create(int sz) { +WEAK halide_mutex_array *halide_mutex_array_create(uint64_t sz) { return &halide_fake_mutex_array; } diff --git a/src/runtime/synchronization_common.h b/src/runtime/synchronization_common.h index cb244f360eeb..778c423e4046 100644 --- a/src/runtime/synchronization_common.h +++ b/src/runtime/synchronization_common.h @@ -908,7 +908,7 @@ struct halide_mutex_array { struct halide_mutex *array; }; -WEAK halide_mutex_array *halide_mutex_array_create(int sz) { +WEAK halide_mutex_array *halide_mutex_array_create(uint64_t sz) { // TODO: If sz is huge, we should probably hash it down to something smaller // in the accessors below. Check for deadlocks before doing so. halide_mutex_array *array = (halide_mutex_array *)halide_malloc(