diff --git a/Makefile b/Makefile index f6ca8bcb3f80..1458773c22d0 100644 --- a/Makefile +++ b/Makefile @@ -491,6 +491,7 @@ SOURCE_FILES = \ LLVM_Runtime_Linker.cpp \ LoopCarry.cpp \ Lower.cpp \ + LowerParallelTasks.cpp \ LowerWarpShuffles.cpp \ MatlabWrapper.cpp \ Memoization.cpp \ @@ -669,6 +670,7 @@ HEADER_FILES = \ LLVM_Runtime_Linker.h \ LoopCarry.h \ Lower.h \ + LowerParallelTasks.h \ LowerWarpShuffles.h \ MainPage.h \ MatlabWrapper.h \ diff --git a/src/AsyncProducers.cpp b/src/AsyncProducers.cpp index 0931145a5ebf..3f64e4b1ab34 100644 --- a/src/AsyncProducers.cpp +++ b/src/AsyncProducers.cpp @@ -333,7 +333,7 @@ class ForkAsyncProducers : public IRMutator { vector sema_vars; for (int i = 0; i < consumes.count; i++) { sema_names.push_back(op->name + ".semaphore_" + std::to_string(i)); - sema_vars.push_back(Variable::make(Handle(), sema_names.back())); + sema_vars.push_back(Variable::make(type_of(), sema_names.back())); } Stmt producer = GenerateProducerBody(op->name, sema_vars, cloned_acquires).mutate(body); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 86074c5f3d20..4cb3b5d53741 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -98,6 +98,7 @@ set(HEADER_FILES LLVM_Runtime_Linker.h LoopCarry.h Lower.h + LowerParallelTasks.h LowerWarpShuffles.h MainPage.h MatlabWrapper.h @@ -258,6 +259,7 @@ set(SOURCE_FILES LLVM_Runtime_Linker.cpp LoopCarry.cpp Lower.cpp + LowerParallelTasks.cpp LowerWarpShuffles.cpp MatlabWrapper.cpp Memoization.cpp diff --git a/src/Closure.cpp b/src/Closure.cpp index c4a2f7e63b1f..5c5125a9b291 100644 --- a/src/Closure.cpp +++ b/src/Closure.cpp @@ -1,11 +1,18 @@ #include "Closure.h" #include "Debug.h" +#include "ExprUsesVar.h" +#include "IRMutator.h" +#include "IROperator.h" namespace Halide { namespace Internal { using std::string; +namespace { +constexpr int DBG = 3; +} // namespace + void Closure::include(const Stmt &s, const string &loop_variable) { if (!loop_variable.empty()) { ignore.push(loop_variable); @@ -38,7 +45,7 @@ void Closure::visit(const For *op) { void Closure::found_buffer_ref(const string &name, Type type, bool read, bool written, const Halide::Buffer<> &image) { if (!ignore.contains(name)) { - debug(3) << "Adding buffer " << name << " to closure\n"; + debug(DBG) << "Adding buffer " << name << " to closure:\n"; Buffer &ref = buffers[name]; ref.type = type.element_of(); // TODO: Validate type is the same as existing refs? 
ref.read = ref.read || read; @@ -49,8 +56,15 @@ void Closure::found_buffer_ref(const string &name, Type type, ref.size = image.size_in_bytes(); ref.dimensions = image.dimensions(); } + debug(DBG) << " " + << " t=" << ref.type + << " d=" << (int)ref.dimensions + << " r=" << ref.read + << " w=" << ref.write + << " mt=" << (int)ref.memory_type + << " sz=" << ref.size << "\n"; } else { - debug(3) << "Not adding " << name << " to closure\n"; + debug(DBG) << "Not adding buffer " << name << " to closure\n"; } } @@ -81,9 +95,9 @@ void Closure::visit(const Allocate *op) { void Closure::visit(const Variable *op) { if (ignore.contains(op->name)) { - debug(3) << "Not adding " << op->name << " to closure\n"; + debug(DBG) << "Not adding var " << op->name << " to closure\n"; } else { - debug(3) << "Adding " << op->name << " to closure\n"; + debug(DBG) << "Adding var " << op->name << " to closure\n"; vars[op->name] = op->type; } } @@ -95,5 +109,71 @@ void Closure::visit(const Atomic *op) { op->body.accept(this); } +Expr Closure::pack_into_struct() const { + std::vector elements; + + for (const auto &b : buffers) { + Expr ptr_var = Variable::make(type_of(), b.first); + elements.emplace_back(ptr_var); + } + for (const auto &v : vars) { + Expr var = Variable::make(v.second, v.first); + elements.emplace_back(var); + } + + // Sort by decreasing size, to guarantee the struct is densely packed in + // memory. We don't actually rely on this, it's just nice to have. + std::stable_sort(elements.begin(), elements.end(), + [&](const Expr &a, const Expr &b) { + return a.type().bytes() > b.type().bytes(); + }); + + Expr result = Call::make(Handle(), + Call::make_struct, elements, Call::Intrinsic); + return result; +} + +Stmt Closure::unpack_from_struct(const Expr &e, const Stmt &s) const { + // Use the struct-packing code just to make sure the order of elements is + // the same. + Expr packed = pack_into_struct(); + + // Make a prototype of the packed struct + class ReplaceCallArgsWithZero : public IRMutator { + public: + using IRMutator::mutate; + Expr mutate(const Expr &e) override { + if (!e.as()) { + return make_zero(e.type()); + } else { + return IRMutator::mutate(e); + } + } + } replacer; + string prototype_name = unique_name("closure_prototype"); + Expr prototype = replacer.mutate(packed); + Expr prototype_var = Variable::make(Handle(), prototype_name); + + const Call *c = packed.as(); + + Stmt result = s; + for (int idx = (int)c->args.size() - 1; idx >= 0; idx--) { + Expr arg = c->args[idx]; + const Variable *var = arg.as(); + Expr val = Call::make(var->type, + Call::load_typed_struct_member, + {e, prototype_var, idx}, + Call::Intrinsic); + if (stmt_uses_var(result, var->name)) { + // If a closure is generated for multiple consuming blocks of IR, + // then some of those blocks might only need some of the field. + result = LetStmt::make(var->name, val, result); + } + } + result = LetStmt::make(prototype_name, prototype, result); + + return result; +} + } // namespace Internal } // namespace Halide diff --git a/src/Closure.h b/src/Closure.h index 06ceb4ffb074..85b23a1cb31c 100644 --- a/src/Closure.h +++ b/src/Closure.h @@ -90,11 +90,19 @@ class Closure : public IRVisitor { **/ void include(const Stmt &s, const std::string &loop_variable = ""); - /** External variables referenced. */ + /** External variables referenced. There's code that assumes iterating over + * this repeatedly gives a consistent order, so don't swap out the data type + * for something non-deterministic. 
*/ std::map vars; /** External allocations referenced. */ std::map buffers; + + /** Pack a closure into a struct. */ + Expr pack_into_struct() const; + + /** Unpack a closure around a Stmt, putting all the names in scope. */ + Stmt unpack_from_struct(const Expr &, const Stmt &) const; }; } // namespace Internal diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 50a25f6748f4..2aea178100c0 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -1820,7 +1820,7 @@ void CodeGen_C::compile(const LoweredFunc &f, const std::map namespaces; - std::string simple_name = extract_namespaces(f.name, namespaces); + std::string simple_name = c_print_name(extract_namespaces(f.name, namespaces), false); if (!is_c_plus_plus_interface()) { user_assert(namespaces.empty()) << "Namespace qualifiers not allowed on function name if not compiling with Target::CPlusPlusNameMangling.\n"; } @@ -2024,7 +2024,13 @@ string CodeGen_C::print_assignment(Type t, const std::string &rhs) { if (cached == cache.end()) { id = unique_name('_'); const char *const_flag = output_kind == CPlusPlusImplementation ? "const " : ""; - stream << get_indent() << print_type(t, AppendSpace) << const_flag << id << " = " << rhs << ";\n"; + if (t.is_handle()) { + // Don't print void *, which might lose useful type information. just use auto. + stream << get_indent() << "auto *"; + } else { + stream << get_indent() << print_type(t, AppendSpace); + } + stream << const_flag << id << " = " << rhs << ";\n"; cache[rhs] = id; } else { id = cached->second; @@ -2051,7 +2057,12 @@ void CodeGen_C::close_scope(const std::string &comment) { } void CodeGen_C::visit(const Variable *op) { - id = print_name(op->name); + if (starts_with(op->name, "::")) { + // This is the name of a global, so we can't modify it. + id = op->name; + } else { + id = print_name(op->name); + } } void CodeGen_C::visit(const Cast *op) { @@ -2369,12 +2380,19 @@ void CodeGen_C::visit(const Call *op) { } else if (op->is_intrinsic(Call::alloca)) { internal_assert(op->args.size() == 1); internal_assert(op->type.is_handle()); + const int64_t *sz = as_const_int(op->args[0]); if (op->type == type_of() && Call::as_intrinsic(op->args[0], {Call::size_of_halide_buffer_t})) { stream << get_indent(); string buf_name = unique_name('b'); stream << "halide_buffer_t " << buf_name << ";\n"; rhs << "&" << buf_name; + } else if (op->type == type_of() && + sz && *sz == 16) { + stream << get_indent(); + string semaphore_name = unique_name("sema"); + stream << "halide_semaphore_t " << semaphore_name << ";\n"; + rhs << "&" << semaphore_name; } else { // Make a stack of uint64_ts string size = print_expr(simplify((op->args[0] + 7) / 8)); @@ -2422,7 +2440,7 @@ void CodeGen_C::visit(const Call *op) { rhs << shape_name; } else { // Emit a declaration like: - // struct {const int f_0, const char f_1, const int f_2} foo = {3, 'c', 4}; + // struct {int f_0, int f_1, char f_2} foo = {3, 4, 'c'}; // Get the args vector values; @@ -2433,7 +2451,7 @@ void CodeGen_C::visit(const Call *op) { // List the types. 
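For reference, the C that this emission path now produces for a three-element make_struct looks roughly like the following; identifiers such as s3 and _7 are illustrative:

    struct {
        int32_t f_0;
        int32_t f_1;
        char f_2;
    } s3 = {3, 4, 'c'};
    auto *_7 = (&s3);  // the make_struct value is a pointer to the stack-allocated struct
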
indent++; for (size_t i = 0; i < op->args.size(); i++) { - stream << get_indent() << "const " << print_type(op->args[i].type()) << " f_" << i << ";\n"; + stream << get_indent() << print_type(op->args[i].type()) << " f_" << i << ";\n"; } indent--; string struct_name = unique_name('s'); @@ -2460,6 +2478,26 @@ void CodeGen_C::visit(const Call *op) { } rhs << "(&" << struct_name << ")"; } + } else if (op->is_intrinsic(Call::load_typed_struct_member)) { + // Given a void * instance of a typed struct, an in-scope prototype + // struct of the same type, and the index of a slot, load the value of + // that slot. + // + // It is assumed that the slot index is valid for the given typed struct. + // + // TODO: this comment is replicated in CodeGen_LLVM and should be updated there too. + // TODO: https://github.com/halide/Halide/issues/6468 + + internal_assert(op->args.size() == 3); + std::string struct_instance = print_expr(op->args[0]); + std::string struct_prototype = print_expr(op->args[1]); + const int64_t *index = as_const_int(op->args[2]); + internal_assert(index != nullptr); + rhs << "((decltype(" << struct_prototype << "))" + << struct_instance << ")->f_" << *index; + } else if (op->is_intrinsic(Call::get_user_context)) { + internal_assert(op->args.empty()); + rhs << "_ucon"; } else if (op->is_intrinsic(Call::stringify)) { // Rewrite to an snprintf vector printf_args; @@ -2490,7 +2528,6 @@ void CodeGen_C::visit(const Call *op) { stream << get_indent() << "char " << buf_name << "[1024];\n"; stream << get_indent() << "snprintf(" << buf_name << ", 1024, \"" << format_string << "\", " << with_commas(printf_args) << ");\n"; rhs << buf_name; - } else if (op->is_intrinsic(Call::register_destructor)) { internal_assert(op->args.size() == 2); const StringImm *fn = op->args[0].as(); @@ -2507,7 +2544,7 @@ void CodeGen_C::visit(const Call *op) { << "" << struct_name << "(void *ucon, void *a) : ucon(ucon), arg((void *)a) {} " << "~" << struct_name << "() { " << fn->value + "(ucon, arg); } " << "} " << instance_name << "(_ucon, " << arg << ");\n"; - rhs << print_expr(0); + rhs << "(void *)nullptr"; } else if (op->is_intrinsic(Call::div_round_to_zero)) { rhs << print_expr(op->args[0]) << " / " << print_expr(op->args[1]); } else if (op->is_intrinsic(Call::mod_round_to_zero)) { @@ -2710,9 +2747,10 @@ void CodeGen_C::visit(const Let *op) { if (op->value.type().is_handle()) { // The body might contain a Load that references this directly // by name, so we can't rewrite the name. - stream << get_indent() << print_type(op->value.type()) - << " " << print_name(op->name) - << " = " << id_value << ";\n"; + std::string name = print_name(op->name); + stream << get_indent() << "auto " + << name << " = " << id_value << ";\n"; + stream << get_indent() << "halide_unused(" << name << ");\n"; } else { Expr new_var = Variable::make(op->value.type(), id_value); body = substitute(op->name, new_var, body); @@ -2800,12 +2838,14 @@ void CodeGen_C::visit(const VectorReduce *op) { void CodeGen_C::visit(const LetStmt *op) { string id_value = print_expr(op->value); Stmt body = op->body; + if (op->value.type().is_handle()) { // The body might contain a Load or Store that references this // directly by name, so we can't rewrite the name. 
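Tying the two new intrinsics handled above together: for load_typed_struct_member the C backend casts the void * instance through the decltype of the in-scope prototype and reads the requested slot, while get_user_context simply resolves to the existing _ucon local. A sketch with hypothetical identifiers (_7 is the prototype pointer produced by a make_struct, _closure_arg the void * instance):

    int32_t _9 = ((decltype(_7))_closure_arg)->f_1;  // load slot 1 of the typed struct
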
- stream << get_indent() << print_type(op->value.type()) - << " " << print_name(op->name) - << " = " << id_value << ";\n"; + std::string name = print_name(op->name); + stream << get_indent() << "auto " + << name << " = " << id_value << ";\n"; + stream << get_indent() << "halide_unused(" << name << ");\n"; } else { Expr new_var = Variable::make(op->value.type(), id_value); body = substitute(op->name, new_var, body); @@ -3221,8 +3261,9 @@ extern "C" { HALIDE_FUNCTION_ATTRS int test1(struct halide_buffer_t *_buf_buffer, float _alpha, int32_t _beta, void const *__user_context) { void * const _ucon = const_cast(__user_context); - void *_0 = _halide_buffer_get_host(_buf_buffer); - void * _buf = _0; + auto *_0 = _halide_buffer_get_host(_buf_buffer); + auto _buf = _0; + halide_unused(_buf); { int64_t _1 = 43; int64_t _2 = _1 * _beta; @@ -3248,7 +3289,7 @@ int test1(struct halide_buffer_t *_buf_buffer, float _alpha, int32_t _beta, void { char b0[1024]; snprintf(b0, 1024, "%lld%s", (long long)(3), "\n"); - char const *_8 = b0; + auto *_8 = b0; halide_print(_ucon, _8); int32_t _9 = 0; int32_t _10 = return_second(_9, 3); diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 847fe588d758..a32bca98ff7d 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -431,6 +431,30 @@ class InjectHVXLocks : public IRMutator { } Expr visit(const Call *op) override { uses_hvx = uses_hvx || op->type.is_vector(); + + if (op->name == "halide_do_par_for") { + // If we see a call to halide_do_par_for() at this point, it should mean that + // this statement was produced via HexagonOffload calling lower_parallel_tasks() + // explicitly; in this case, we won't see any parallel For statements, since they've + // all been transformed into closures already. To mirror the pattern above, + // we need to wrap the halide_do_par_for() call with an unlock/lock pair, but + // that's hard to do in Halide IR (we'd need to produce a Stmt to enforce the ordering, + // and the resulting Stmt can't easily be substituted for the Expr here). Rather than + // make fragile assumptions about the structure of the IR produced by lower_parallel_tasks(), + // we'll use a trick: we'll define a WEAK_INLINE function, _halide_hexagon_do_par_for, + // which simply encapsulates the unlock()/do_par_for()/lock() sequences, and swap out + // the call here. Since it is inlined, and since uses_hvx_var gets substituted at the end, + // we end up with LLVM IR that properly includes (or omits) the unlock/lock pair depending + // on the final value of uses_hvx_var in this scope. 
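The helper named in this comment is not defined in this file; a sketch of what the runtime-side definition is assumed to look like, mirroring halide_do_par_for plus the trailing use_hvx flag appended below (exact name and signature are assumptions):

    extern "C" WEAK_INLINE int _halide_hexagon_do_par_for(void *user_context, halide_task_t f,
                                                          int min, int size, uint8_t *closure,
                                                          int use_hvx) {
        if (use_hvx) {
            int unlock_result = halide_qurt_hvx_unlock(user_context);
            if (unlock_result != 0) {
                return unlock_result;
            }
        }
        int result = halide_do_par_for(user_context, f, min, size, closure);
        if (use_hvx) {
            int lock_result = halide_qurt_hvx_lock(user_context);
            if (result == 0) {
                result = lock_result;
            }
        }
        return result;
    }
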
+ + internal_assert(op->call_type == Call::Extern); + internal_assert(op->args.size() == 4); + + std::vector args = op->args; + args.push_back(cast(uses_hvx_var)); + + return Call::make(Int(32), "_halide_hexagon_do_par_for", args, Call::Extern); + } return op; } diff --git a/src/CodeGen_Internal.cpp b/src/CodeGen_Internal.cpp index bcb506741e02..f7f7b2917991 100644 --- a/src/CodeGen_Internal.cpp +++ b/src/CodeGen_Internal.cpp @@ -13,81 +13,9 @@ namespace Halide { namespace Internal { using std::string; -using std::vector; using namespace llvm; -namespace { - -vector llvm_types(const Closure &closure, llvm::StructType *halide_buffer_t_type, LLVMContext &context) { - vector res; - for (const auto &v : closure.vars) { - res.push_back(llvm_type_of(&context, v.second)); - } - for (const auto &b : closure.buffers) { - res.push_back(llvm_type_of(&context, b.second.type)->getPointerTo()); - res.push_back(halide_buffer_t_type->getPointerTo()); - } - return res; -} - -} // namespace - -StructType *build_closure_type(const Closure &closure, - llvm::StructType *halide_buffer_t_type, - LLVMContext *context) { - StructType *struct_t = StructType::create(*context, "closure_t"); - struct_t->setBody(llvm_types(closure, halide_buffer_t_type, *context), false); - return struct_t; -} - -void pack_closure(llvm::StructType *type, - Value *dst, - const Closure &closure, - const Scope &src, - llvm::StructType *halide_buffer_t_type, - IRBuilder<> *builder) { - // type, type of dst should be a pointer to a struct of the type returned by build_type - int idx = 0; - - auto add_to_closure = [&](const std::string &name) { - llvm::Type *t = type->elements()[idx]; - Value *ptr = builder->CreateConstInBoundsGEP2_32(type, dst, 0, idx++); - Value *val = src.get(name); - val = builder->CreateBitCast(val, t); - builder->CreateStore(val, ptr); - }; - - for (const auto &v : closure.vars) { - add_to_closure(v.first); - } - for (const auto &b : closure.buffers) { - add_to_closure(b.first); - } -} - -void unpack_closure(const Closure &closure, - Scope &dst, - llvm::StructType *type, - Value *src, - IRBuilder<> *builder) { - // type, type of src should be a pointer to a struct of the type returned by build_type - int idx = 0; - - auto load_from_closure = [&](const std::string &name) { - Value *ptr = builder->CreateConstInBoundsGEP2_32(type, src, 0, idx++); - LoadInst *load = builder->CreateLoad(ptr->getType()->getPointerElementType(), ptr); - dst.push(name, load); - load->setName(name); - }; - for (const auto &v : closure.vars) { - load_from_closure(v.first); - } - for (const auto &b : closure.buffers) { - load_from_closure(b.first); - } -} - llvm::Type *llvm_type_of(LLVMContext *c, Halide::Type t) { if (t.lanes() == 1) { if (t.is_float() && !t.is_bfloat()) { @@ -209,6 +137,7 @@ bool function_takes_user_context(const std::string &name) { "_halide_buffer_crop", "_halide_buffer_retire_crop_after_extern_stage", "_halide_buffer_retire_crops_after_extern_stage", + "_halide_hexagon_do_par_for", }; for (const char *user_context_runtime_func : user_context_runtime_funcs) { if (name == user_context_runtime_func) { diff --git a/src/CodeGen_Internal.h b/src/CodeGen_Internal.h index 52f6e51ba2b9..3fe1b8b696f5 100644 --- a/src/CodeGen_Internal.h +++ b/src/CodeGen_Internal.h @@ -37,30 +37,6 @@ struct Target; namespace Internal { -/** The llvm type of a struct containing all of the externally referenced state of a Closure. 
*/ -llvm::StructType *build_closure_type(const Closure &closure, llvm::StructType *halide_buffer_t_type, llvm::LLVMContext *context); - -/** Emit code that builds a struct containing all the externally - * referenced state. Requires you to pass it a type and struct to fill in, - * a scope to retrieve the llvm values from and a builder to place - * the packing code. */ -void pack_closure(llvm::StructType *type, - llvm::Value *dst, - const Closure &closure, - const Scope &src, - llvm::StructType *halide_buffer_t_type, - llvm::IRBuilder *builder); - -/** Emit code that unpacks a struct containing all the externally - * referenced state into a symbol table. Requires you to pass it a - * state struct type and value, a scope to fill, and a builder to place the - * unpacking code. */ -void unpack_closure(const Closure &closure, - Scope &dst, - llvm::StructType *type, - llvm::Value *src, - llvm::IRBuilder *builder); - /** Get the llvm type equivalent to a given halide type */ llvm::Type *llvm_type_of(llvm::LLVMContext *context, Halide::Type t); diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 97b407638dbc..47b3a8dab0c5 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -22,6 +22,7 @@ #include "LLVM_Headers.h" #include "LLVM_Runtime_Linker.h" #include "Lerp.h" +#include "LowerParallelTasks.h" #include "MatlabWrapper.h" #include "Pipeline.h" #include "Simplify.h" @@ -503,12 +504,36 @@ std::unique_ptr CodeGen_LLVM::compile(const Module &input) { for (const auto &b : input.buffers()) { compile_buffer(b); } + + vector function_names; + + // Declare all functions for (const auto &f : input.functions()) { const auto names = get_mangled_names(f, get_target()); + function_names.push_back(names); - run_with_large_stack([&]() { - compile_func(f, names.simple_name, names.extern_name); - }); + // Deduce the types of the arguments to our function + vector arg_types(f.args.size()); + for (size_t i = 0; i < f.args.size(); i++) { + if (f.args[i].is_buffer()) { + arg_types[i] = halide_buffer_t_type->getPointerTo(); + } else { + arg_types[i] = llvm_type_of(upgrade_type_for_argument_passing(f.args[i].type)); + } + } + FunctionType *func_t = FunctionType::get(i32_t, arg_types, false); + function = llvm::Function::Create(func_t, llvm_linkage(f.linkage), names.extern_name, module.get()); + set_function_attributes_for_target(function, target); + + // Mark the buffer args as no alias + for (size_t i = 0; i < f.args.size(); i++) { + if (f.args[i].is_buffer()) { + function->addParamAttr(i, Attribute::NoAlias); + } + } + + // sym_push helpfully calls setName, which we don't want + symbol_table.push("::" + f.name, function); // If the Func is externally visible, also create the argv wrapper and metadata. // (useful for calling from JIT and other machine interfaces). 
@@ -524,8 +549,17 @@ std::unique_ptr CodeGen_LLVM::compile(const Module &input) { } } } + // Define all functions + int idx = 0; + for (const auto &f : input.functions()) { + const auto names = function_names[idx++]; - debug(2) << module.get() << "\n"; + run_with_large_stack([&]() { + compile_func(f, names.simple_name, names.extern_name); + }); + } + + debug(2) << "llvm::Module pointer: " << module.get() << "\n"; return finish_codegen(); } @@ -550,41 +584,9 @@ std::unique_ptr CodeGen_LLVM::finish_codegen() { void CodeGen_LLVM::begin_func(LinkageType linkage, const std::string &name, const std::string &extern_name, const std::vector &args) { current_function_args = args; - - // Deduce the types of the arguments to our function - vector arg_types(args.size()); - for (size_t i = 0; i < args.size(); i++) { - if (args[i].is_buffer()) { - arg_types[i] = halide_buffer_t_type->getPointerTo(); - } else { - arg_types[i] = llvm_type_of(upgrade_type_for_argument_passing(args[i].type)); - } - } - FunctionType *func_t = FunctionType::get(i32_t, arg_types, false); - - // Make our function. There may already be a declaration of it. function = module->getFunction(extern_name); if (!function) { - function = llvm::Function::Create(func_t, llvm_linkage(linkage), extern_name, module.get()); - } else { - user_assert(function->isDeclaration()) - << "Another function with the name " << extern_name - << " already exists in the same module\n"; - if (func_t != function->getFunctionType()) { - std::cerr << "Desired function type for " << extern_name << ":\n"; - func_t->print(dbgs(), true); - std::cerr << "Declared function type of " << extern_name << ":\n"; - function->getFunctionType()->print(dbgs(), true); - user_error << "Cannot create a function with a declaration of mismatched type.\n"; - } - } - set_function_attributes_for_target(function, target); - - // Mark the buffer args as no alias - for (size_t i = 0; i < args.size(); i++) { - if (args[i].is_buffer()) { - function->addParamAttr(i, Attribute::NoAlias); - } + internal_assert(function) << "Could not find a function of name " << extern_name << " in module\n"; } debug(1) << "Generating llvm bitcode prolog for function " << name << "...\n"; @@ -971,7 +973,7 @@ llvm::Function *CodeGen_LLVM::embed_metadata_getter(const std::string &metadata_ vector arguments_array_entries; for (int arg = 0; arg < num_args; ++arg) { - StructType *type_t_type = get_llvm_struct_type_by_name(module.get(), "struct.halide_type_t"); + llvm::StructType *type_t_type = get_llvm_struct_type_by_name(module.get(), "struct.halide_type_t"); internal_assert(type_t_type) << "Did not find halide_type_t in module.\n"; Constant *type_fields[] = { @@ -2807,7 +2809,7 @@ void CodeGen_LLVM::visit(const Call *op) { value = create_alloca_at_entry(types[0], 1); builder->CreateStore(args[0], value); } else { - llvm::Type *aggregate_t = (all_same_type ? (llvm::Type *)ArrayType::get(types[0], types.size()) : (llvm::Type *)StructType::get(*context, types)); + llvm::Type *aggregate_t = (all_same_type ? (llvm::Type *)ArrayType::get(types[0], types.size()) : (llvm::Type *)llvm::StructType::get(*context, types)); value = create_alloca_at_entry(aggregate_t, 1); for (size_t i = 0; i < args.size(); i++) { @@ -2816,6 +2818,38 @@ void CodeGen_LLVM::visit(const Call *op) { } } } + } else if (op->is_intrinsic(Call::load_typed_struct_member)) { + // Given a void * instance of a typed struct, an in-scope prototype + // struct of the same type, and the index of a slot, load the value of + // that slot. 
+ // + // It is assumed that the slot index is valid for the given typed struct. + // + // TODO: this comment is replicated in CodeGen_LLVM and should be updated there too. + // TODO: https://github.com/halide/Halide/issues/6468 + internal_assert(op->args.size() == 3); + llvm::Value *struct_instance = codegen(op->args[0]); + llvm::Value *struct_prototype = codegen(op->args[1]); + llvm::Value *typed_struct_instance = builder->CreatePointerCast(struct_instance, struct_prototype->getType()); + const int64_t *index = as_const_int(op->args[2]); + + // make_struct can use a fixed-size struct, an array type, or a scalar + llvm::Type *pointee_type = struct_prototype->getType()->getPointerElementType(); + llvm::Type *struct_type = llvm::dyn_cast(pointee_type); + llvm::Type *array_type = llvm::dyn_cast(pointee_type); + if (struct_type || array_type) { + internal_assert(index != nullptr); + llvm::Value *gep = CreateInBoundsGEP(builder, typed_struct_instance, + {ConstantInt::get(i32_t, 0), + ConstantInt::get(i32_t, (int)*index)}); + value = builder->CreateLoad(gep->getType()->getPointerElementType(), gep); + } else { + // The struct is actually just a scalar + value = builder->CreateLoad(pointee_type, typed_struct_instance); + } + } else if (op->is_intrinsic(Call::get_user_context)) { + internal_assert(op->args.empty()); + value = get_user_context(); } else if (op->is_intrinsic(Call::saturating_add) || op->is_intrinsic(Call::saturating_sub)) { internal_assert(op->args.size() == 2); @@ -2934,7 +2968,9 @@ void CodeGen_LLVM::visit(const Call *op) { dst = builder->CreateCall(append_buffer, call_args); } else { internal_assert(t.is_handle()); - call_args.push_back(codegen(arg)); + Value *ptr = codegen(arg); + ptr = builder->CreatePointerCast(ptr, i8_t->getPointerTo()); + call_args.push_back(ptr); dst = builder->CreateCall(append_pointer, call_args); } } @@ -2965,11 +3001,15 @@ void CodeGen_LLVM::visit(const Call *op) { // restrictions if we recognize the most common types we // expect to get alloca'd. 
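One of the recognized types is new in this patch: a constant 16-byte alloca whose result type is halide_semaphore_t * (and the matching case in CodeGen_C above), where 16 is sizeof(halide_semaphore_t). That is the shape in which semaphores are materialized in IR before being handed to halide_semaphore_init; the allocation is built along these lines (a sketch, not a quote of the producing pass):

    Expr sema_space = Call::make(type_of<halide_semaphore_t *>(), Call::alloca,
                                 {(int)sizeof(halide_semaphore_t)}, Call::Intrinsic);
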
const Call *call = op->args[0].as(); + const int64_t *sz = as_const_int(op->args[0]); if (op->type == type_of() && call && call->is_intrinsic(Call::size_of_halide_buffer_t)) { value = create_alloca_at_entry(halide_buffer_t_type, 1); + } else if (op->type == type_of() && + semaphore_t_type != nullptr && + sz && *sz == 16) { + value = create_alloca_at_entry(semaphore_t_type, 1); } else { - const int64_t *sz = as_const_int(op->args[0]); internal_assert(sz != nullptr); if (op->type == type_of()) { value = create_alloca_at_entry(dimension_t_type, *sz / sizeof(halide_dimension_t)); @@ -3529,12 +3569,13 @@ void CodeGen_LLVM::visit(const For *op) { Value *extent = codegen(op->extent); const Acquire *acquire = op->body.as(); - if (op->for_type == ForType::Parallel || - (op->for_type == ForType::Serial && - acquire && - !expr_uses_var(acquire->count, op->name))) { - do_as_parallel_task(op); - } else if (op->for_type == ForType::Serial) { + // TODO(zalman): remove this after validating it doesn't happen + internal_assert(!(op->for_type == ForType::Parallel || + (op->for_type == ForType::Serial && + acquire && + !expr_uses_var(acquire->count, op->name)))); + + if (op->for_type == ForType::Serial) { Value *max = builder->CreateNSWAdd(min, extent); @@ -3579,412 +3620,6 @@ void CodeGen_LLVM::visit(const For *op) { } } -void CodeGen_LLVM::do_parallel_tasks(const vector &tasks) { - Closure closure; - for (const auto &t : tasks) { - Stmt s = t.body; - if (!t.loop_var.empty()) { - s = LetStmt::make(t.loop_var, 0, s); - } - s.accept(&closure); - } - - // Allocate a closure - StructType *closure_t = build_closure_type(closure, halide_buffer_t_type, context); - Value *closure_ptr = create_alloca_at_entry(closure_t, 1); - - // Fill in the closure - pack_closure(closure_t, closure_ptr, closure, symbol_table, halide_buffer_t_type, builder); - - closure_ptr = builder->CreatePointerCast(closure_ptr, i8_t->getPointerTo()); - - int num_tasks = (int)tasks.size(); - - // Make space on the stack for the tasks - llvm::Value *task_stack_ptr = create_alloca_at_entry(parallel_task_t_type, num_tasks); - - llvm::Type *args_t[] = {i8_t->getPointerTo(), i32_t, i8_t->getPointerTo()}; - FunctionType *task_t = FunctionType::get(i32_t, args_t, false); - llvm::Type *loop_args_t[] = {i8_t->getPointerTo(), i32_t, i32_t, i8_t->getPointerTo(), i8_t->getPointerTo()}; - FunctionType *loop_task_t = FunctionType::get(i32_t, loop_args_t, false); - - Value *result = nullptr; - - for (int i = 0; i < num_tasks; i++) { - ParallelTask t = tasks[i]; - - // Analyze the task body - class MayBlock : public IRVisitor { - using IRVisitor::visit; - void visit(const Acquire *op) override { - result = true; - } - - public: - bool result = false; - }; - - // TODO(zvookin|abadams): This makes multiple passes over the - // IR to cover each node. (One tree walk produces the min - // thread count for all nodes, but we redo each subtree when - // compiling a given node.) Ideally we'd move to a lowering pass - // that converts our parallelism constructs to Call nodes, or - // direct hardware operations in some cases. - // Also, this code has to exactly mirror the logic in get_parallel_tasks. - // It would be better to do one pass on the tree and centralize the task - // deduction logic in one place. 
- class MinThreads : public IRVisitor { - using IRVisitor::visit; - - std::pair skip_acquires(Stmt first) { - int count = 0; - while (first.defined()) { - const Acquire *acq = first.as(); - if (acq == nullptr) { - break; - } - count++; - first = acq->body; - } - return {first, count}; - } - - void visit(const Fork *op) override { - int total_threads = 0; - int direct_acquires = 0; - // Take the sum of min threads across all - // cascaded Fork nodes. - const Fork *node = op; - while (node != nullptr) { - result = 0; - auto after_acquires = skip_acquires(node->first); - direct_acquires += after_acquires.second; - - after_acquires.first.accept(this); - total_threads += result; - - const Fork *continued_branches = node->rest.as(); - if (continued_branches == nullptr) { - result = 0; - after_acquires = skip_acquires(node->rest); - direct_acquires += after_acquires.second; - after_acquires.first.accept(this); - total_threads += result; - } - node = continued_branches; - } - if (direct_acquires == 0 && total_threads == 0) { - result = 0; - } else { - result = total_threads + 1; - } - } - - void visit(const For *op) override { - result = 0; - - if (op->for_type == ForType::Parallel) { - IRVisitor::visit(op); - if (result > 0) { - result += 1; - } - } else if (op->for_type == ForType::Serial) { - auto after_acquires = skip_acquires(op->body); - if (after_acquires.second > 0 && - !expr_uses_var(op->body.as()->count, op->name)) { - after_acquires.first.accept(this); - result++; - } else { - IRVisitor::visit(op); - } - } else { - IRVisitor::visit(op); - } - } - - // This is a "standalone" Acquire and will result in its own task. - // Treat it requiring one more thread than its body. - void visit(const Acquire *op) override { - result = 0; - auto after_inner_acquires = skip_acquires(op); - after_inner_acquires.first.accept(this); - result = result + 1; - } - - void visit(const Block *op) override { - result = 0; - op->first.accept(this); - int result_first = result; - result = 0; - op->rest.accept(this); - result = std::max(result, result_first); - } - - public: - int result = 0; - }; - MinThreads min_threads; - t.body.accept(&min_threads); - - // Decide if we're going to call do_par_for or - // do_parallel_tasks. halide_do_par_for is simpler, but - // assumes a bunch of things. Programs that don't use async - // can also enter the task system via do_par_for. - Value *task_parent = sym_get("__task_parent", false); - bool use_do_par_for = (num_tasks == 1 && - min_threads.result == 0 && - t.semaphores.empty() && - !task_parent); - - // Make the array of semaphore acquisitions this task needs to do before it runs. - Value *semaphores; - Value *num_semaphores = ConstantInt::get(i32_t, (int)t.semaphores.size()); - if (!t.semaphores.empty()) { - semaphores = create_alloca_at_entry(semaphore_acquire_t_type, (int)t.semaphores.size()); - for (int i = 0; i < (int)t.semaphores.size(); i++) { - Value *semaphore = codegen(t.semaphores[i].semaphore); - semaphore = builder->CreatePointerCast(semaphore, semaphore_t_type->getPointerTo()); - Value *count = codegen(t.semaphores[i].count); - Value *slot_ptr = builder->CreateConstGEP2_32(semaphore_acquire_t_type, semaphores, i, 0); - builder->CreateStore(semaphore, slot_ptr); - slot_ptr = builder->CreateConstGEP2_32(semaphore_acquire_t_type, semaphores, i, 1); - builder->CreateStore(count, slot_ptr); - } - } else { - semaphores = ConstantPointerNull::get(semaphore_acquire_t_type->getPointerTo()); - } - - FunctionType *fn_type = use_do_par_for ? 
task_t : loop_task_t; - int closure_arg_idx = use_do_par_for ? 2 : 3; - - // Make a new function that does the body - llvm::Function *containing_function = function; - function = llvm::Function::Create(fn_type, llvm::Function::InternalLinkage, - t.name, module.get()); - - llvm::Value *task_ptr = builder->CreatePointerCast(function, fn_type->getPointerTo()); - - function->addParamAttr(closure_arg_idx, Attribute::NoAlias); - - set_function_attributes_for_target(function, target); - - // Make the initial basic block and jump the builder into the new function - IRBuilderBase::InsertPoint call_site = builder->saveIP(); - BasicBlock *block = BasicBlock::Create(*context, "entry", function); - builder->SetInsertPoint(block); - - // Save the destructor block - BasicBlock *parent_destructor_block = destructor_block; - destructor_block = nullptr; - - // Make a new scope to use - Scope saved_symbol_table; - symbol_table.swap(saved_symbol_table); - - // Get the function arguments - - // The user context is first argument of the function; it's - // important that we override the name to be "__user_context", - // since the LLVM function has a random auto-generated name for - // this argument. - llvm::Function::arg_iterator iter = function->arg_begin(); - sym_push("__user_context", iterator_to_pointer(iter)); - - if (use_do_par_for) { - // Next is the loop variable. - ++iter; - sym_push(t.loop_var, iterator_to_pointer(iter)); - } else if (!t.loop_var.empty()) { - // We peeled off a loop. Wrap a new loop around the body - // that just does the slice given by the arguments. - string loop_min_name = unique_name('t'); - string loop_extent_name = unique_name('t'); - t.body = For::make(t.loop_var, - Variable::make(Int(32), loop_min_name), - Variable::make(Int(32), loop_extent_name), - ForType::Serial, - DeviceAPI::None, - t.body); - ++iter; - sym_push(loop_min_name, iterator_to_pointer(iter)); - ++iter; - sym_push(loop_extent_name, iterator_to_pointer(iter)); - } else { - // This task is not any kind of loop, so skip these args. - ++iter; - ++iter; - } - - // The closure pointer is either the last (for halide_do_par_for) or - // second to last argument (for halide_do_parallel_tasks). - ++iter; - iter->setName("closure"); - Value *closure_handle = builder->CreatePointerCast(iterator_to_pointer(iter), - closure_t->getPointerTo()); - - // Load everything from the closure into the new scope - unpack_closure(closure, symbol_table, closure_t, closure_handle, builder); - - if (!use_do_par_for) { - // For halide_do_parallel_tasks the threading runtime task parent - // is the last argument. - ++iter; - iter->setName("task_parent"); - sym_push("__task_parent", iterator_to_pointer(iter)); - } - - // Generate the new function body - codegen(t.body); - - // Return success - return_with_error_code(ConstantInt::get(i32_t, 0)); - - // Move the builder back to the main function. 
- builder->restoreIP(call_site); - - // Now restore the scope - symbol_table.swap(saved_symbol_table); - function = containing_function; - - // Restore the destructor block - destructor_block = parent_destructor_block; - - Value *min = codegen(t.min); - Value *extent = codegen(t.extent); - Value *serial = codegen(cast(UInt(8), t.serial)); - - if (use_do_par_for) { - llvm::Function *do_par_for = module->getFunction("halide_do_par_for"); - internal_assert(do_par_for) << "Could not find halide_do_par_for in initial module\n"; - do_par_for->addParamAttr(4, Attribute::NoAlias); - Value *args[] = {get_user_context(), task_ptr, min, extent, closure_ptr}; - debug(4) << "Creating call to do_par_for\n"; - result = builder->CreateCall(do_par_for, args); - } else { - // Populate the task struct - Value *slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 0); - builder->CreateStore(task_ptr, slot_ptr); - slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 1); - builder->CreateStore(closure_ptr, slot_ptr); - slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 2); - builder->CreateStore(create_string_constant(t.name), slot_ptr); - slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 3); - builder->CreateStore(semaphores, slot_ptr); - slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 4); - builder->CreateStore(num_semaphores, slot_ptr); - slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 5); - builder->CreateStore(min, slot_ptr); - slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 6); - builder->CreateStore(extent, slot_ptr); - slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 7); - builder->CreateStore(ConstantInt::get(i32_t, min_threads.result), slot_ptr); - slot_ptr = builder->CreateConstGEP2_32(parallel_task_t_type, task_stack_ptr, i, 8); - builder->CreateStore(serial, slot_ptr); - } - } - - if (!result) { - llvm::Function *do_parallel_tasks = module->getFunction("halide_do_parallel_tasks"); - internal_assert(do_parallel_tasks) << "Could not find halide_do_parallel_tasks in initial module\n"; - do_parallel_tasks->addParamAttr(2, Attribute::NoAlias); - Value *task_parent = sym_get("__task_parent", false); - if (!task_parent) { - task_parent = ConstantPointerNull::get(i8_t->getPointerTo()); // void* - } - Value *args[] = {get_user_context(), - ConstantInt::get(i32_t, num_tasks), - task_stack_ptr, - task_parent}; - result = builder->CreateCall(do_parallel_tasks, args); - } - - // Check for success - Value *did_succeed = builder->CreateICmpEQ(result, ConstantInt::get(i32_t, 0)); - create_assertion(did_succeed, Expr(), result); -} - -namespace { - -string task_debug_name(const std::pair &prefix) { - if (prefix.second <= 1) { - return prefix.first; - } else { - return prefix.first + "_" + std::to_string(prefix.second - 1); - } -} - -void add_fork(std::pair &prefix) { - if (prefix.second == 0) { - prefix.first += ".fork"; - } - prefix.second++; -} - -void add_suffix(std::pair &prefix, const string &suffix) { - if (prefix.second > 1) { - prefix.first += "_" + std::to_string(prefix.second - 1); - prefix.second = 0; - } - prefix.first += suffix; -} - -} // namespace - -void CodeGen_LLVM::get_parallel_tasks(const Stmt &s, vector &result, std::pair prefix) { - const For *loop = s.as(); - const Acquire *acquire = loop ? 
loop->body.as() : s.as(); - if (const Fork *f = s.as()) { - add_fork(prefix); - get_parallel_tasks(f->first, result, prefix); - get_parallel_tasks(f->rest, result, prefix); - } else if (!loop && acquire) { - const Variable *v = acquire->semaphore.as(); - internal_assert(v); - add_suffix(prefix, "." + v->name); - ParallelTask t{s, {}, "", 0, 1, const_false(), task_debug_name(prefix)}; - while (acquire) { - t.semaphores.push_back({acquire->semaphore, acquire->count}); - t.body = acquire->body; - acquire = t.body.as(); - } - result.push_back(t); - } else if (loop && loop->for_type == ForType::Parallel) { - add_suffix(prefix, ".par_for." + loop->name); - result.push_back(ParallelTask{loop->body, {}, loop->name, loop->min, loop->extent, const_false(), task_debug_name(prefix)}); - } else if (loop && - loop->for_type == ForType::Serial && - acquire && - !expr_uses_var(acquire->count, loop->name)) { - const Variable *v = acquire->semaphore.as(); - internal_assert(v); - add_suffix(prefix, ".for." + v->name); - ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent, const_true(), task_debug_name(prefix)}; - while (acquire) { - t.semaphores.push_back({acquire->semaphore, acquire->count}); - t.body = acquire->body; - acquire = t.body.as(); - } - result.push_back(t); - } else { - add_suffix(prefix, "." + std::to_string(result.size())); - result.push_back(ParallelTask{s, {}, "", 0, 1, const_false(), task_debug_name(prefix)}); - } -} - -void CodeGen_LLVM::do_as_parallel_task(const Stmt &s) { - vector tasks; - get_parallel_tasks(s, tasks, {function->getName().str(), 0}); - do_parallel_tasks(tasks); -} - -void CodeGen_LLVM::visit(const Acquire *op) { - do_as_parallel_task(op); -} - -void CodeGen_LLVM::visit(const Fork *op) { - do_as_parallel_task(op); -} - void CodeGen_LLVM::visit(const Store *op) { if (!emit_atomic_stores) { // Peel lets off the index to make us more likely to pattern diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index ed3f0a0d34f9..a2bf7dec2478 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -262,24 +262,6 @@ class CodeGen_LLVM : public IRVisitor { /** Codegen a block of asserts with pure conditions */ void codegen_asserts(const std::vector &asserts); - /** Codegen a call to do_parallel_tasks */ - struct ParallelTask { - Stmt body; - struct SemAcquire { - Expr semaphore; - Expr count; - }; - std::vector semaphores; - std::string loop_var; - Expr min, extent; - Expr serial; - std::string name; - }; - int task_depth; - void get_parallel_tasks(const Stmt &s, std::vector &tasks, std::pair prefix); - void do_parallel_tasks(const std::vector &tasks); - void do_as_parallel_task(const Stmt &s); - /** Return the the pipeline with the given error code. Will run * the destructor block. 
*/ void return_with_error_code(llvm::Value *error_code); @@ -356,10 +338,8 @@ class CodeGen_LLVM : public IRVisitor { void visit(const AssertStmt *) override; void visit(const ProducerConsumer *) override; void visit(const For *) override; - void visit(const Acquire *) override; void visit(const Store *) override; void visit(const Block *) override; - void visit(const Fork *) override; void visit(const IfThenElse *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; diff --git a/src/HexagonOffload.cpp b/src/HexagonOffload.cpp index 1b71f995b1fe..8ffd1d0c2e4d 100644 --- a/src/HexagonOffload.cpp +++ b/src/HexagonOffload.cpp @@ -9,6 +9,7 @@ #include "InjectHostDevBufferCopies.h" #include "LLVM_Headers.h" #include "LLVM_Output.h" +#include "LowerParallelTasks.h" #include "Module.h" #include "Param.h" #include "Substitute.h" @@ -753,11 +754,20 @@ class InjectHexagonRpc : public IRMutator { } // Build a closure for the device code. + // Note that we must do this *before* calling lower_parallel_tasks(); + // otherwise the Closure may fail to find buffers that are referenced + // only in the closure. // TODO: Should this move the body of the loop to Hexagon, // or the loop itself? Currently, this moves the loop itself. Closure c; c.include(body); + std::vector closure_implementations; + body = lower_parallel_tasks(body, closure_implementations, hex_name, device_code.target()); + for (auto &lowered_func : closure_implementations) { + device_code.append(lowered_func); + } + // A buffer parameter potentially generates 3 scalar parameters (min, // extent, stride) per dimension. Pipelines with many buffers may // generate extreme numbers of scalar parameters, which can cause diff --git a/src/IR.cpp b/src/IR.cpp index 308261b3c459..d74f51d4d090 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -597,11 +597,12 @@ const char *const intrinsic_op_names[] = { "cast_mask", "count_leading_zeros", "count_trailing_zeros", - "declare_box_touched", "debug_to_file", + "declare_box_touched", "div_round_to_zero", "dynamic_shuffle", "extract_mask_element", + "get_user_context", "gpu_thread_barrier", "halving_add", "halving_sub", @@ -616,6 +617,7 @@ const char *const intrinsic_op_names[] = { "lerp", "likely", "likely_if_innermost", + "load_typed_struct_member", "make_struct", "memoize_expr", "mod_round_to_zero", diff --git a/src/IR.h b/src/IR.h index d125b3cfae4a..a0311f3c86a4 100644 --- a/src/IR.h +++ b/src/IR.h @@ -506,11 +506,12 @@ struct Call : public ExprNode { cast_mask, count_leading_zeros, count_trailing_zeros, - declare_box_touched, debug_to_file, + declare_box_touched, div_round_to_zero, dynamic_shuffle, extract_mask_element, + get_user_context, gpu_thread_barrier, halving_add, halving_sub, @@ -525,6 +526,7 @@ struct Call : public ExprNode { lerp, likely, likely_if_innermost, + load_typed_struct_member, make_struct, memoize_expr, mod_round_to_zero, @@ -564,6 +566,7 @@ struct Call : public ExprNode { widening_shift_left, widening_shift_right, widening_sub, + IntrinsicOpCount // Sentinel: keep last. 
}; diff --git a/src/Lower.cpp b/src/Lower.cpp index fd8ed5305d14..61eef0e1ae35 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -40,6 +40,7 @@ #include "Inline.h" #include "LICM.h" #include "LoopCarry.h" +#include "LowerParallelTasks.h" #include "LowerWarpShuffles.h" #include "Memoization.h" #include "OffloadGPULoops.h" @@ -110,6 +111,8 @@ void lower_impl(const vector &output_funcs, Module &result_module) { auto time_start = std::chrono::high_resolution_clock::now(); + size_t initial_lowered_function_count = result_module.functions().size(); + // Create a deep-copy of the entire graph of Funcs. auto [outputs, env] = deep_copy(output_funcs, build_environment(output_funcs)); @@ -440,6 +443,34 @@ void lower_impl(const vector &output_funcs, debug(1) << "Skipping GPU offload...\n"; } + // TODO: This needs to happen before lowering parallel tasks, because global + // images used inside parallel loops are rewritten from loads from images to + // loads from closure parameters. Closure parameters are missing the Buffer<> + // object, which needs to be found by infer_arguments here. Running + // infer_arguments prior to lower_parallel_tasks is a hacky solution to this + // problem. It would be better if closures could directly reference globals + // so they don't add overhead to the closure. + vector inferred_args = infer_arguments(s, outputs); + + std::vector closure_implementations; + debug(1) << "Lowering Parallel Tasks...\n"; + s = lower_parallel_tasks(s, closure_implementations, pipeline_name, t); + // Process any LoweredFunctions added by other passes. In practice, this + // will likely not work well enough due to ordering issues with + // closure generating passes and instead all such passes will need to + // be done at once. + for (size_t i = initial_lowered_function_count; i < result_module.functions().size(); i++) { + // Note that lower_parallel_tasks() appends to the end of closure_implementations + result_module.functions()[i].body = + lower_parallel_tasks(result_module.functions()[i].body, closure_implementations, + result_module.functions()[i].name, t); + } + for (auto &lowered_func : closure_implementations) { + result_module.append(lowered_func); + } + debug(2) << "Lowering after generating parallel tasks and closures:\n" + << s << "\n\n"; + vector public_args = args; for (const auto &out : outputs) { for (const Parameter &buf : out.output_buffers()) { @@ -449,7 +480,6 @@ void lower_impl(const vector &output_funcs, } } - vector inferred_args = infer_arguments(s, outputs); for (const InferredArgument &arg : inferred_args) { if (arg.param.defined() && arg.param.name() == "__user_context") { // The user context is always in the inferred args, but is diff --git a/src/LowerParallelTasks.cpp b/src/LowerParallelTasks.cpp new file mode 100644 index 000000000000..5ea1f9d539f1 --- /dev/null +++ b/src/LowerParallelTasks.cpp @@ -0,0 +1,436 @@ +#include "LowerParallelTasks.h" + +#include + +#include "Argument.h" +#include "Closure.h" +#include "DebugArguments.h" +#include "ExprUsesVar.h" +#include "IRMutator.h" +#include "IROperator.h" +#include "Module.h" +#include "Param.h" +#include "Simplify.h" + +namespace Halide { +namespace Internal { + +namespace { + +// TODO(zalman): Find a better place for this code to live. 
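Before the helper definitions, the overall shape of the rewrite this file performs, in pseudo-IR with illustrative names: a parallel loop such as

    parallel (f.s0.y, 0, 64) { ... uses buf, alpha ... }

becomes

    let parallel_closure = make_struct(buf, alpha)
    let closure_result = halide_do_par_for(::f_par_for_f_s0_y, 0, 64,
                                           (uint8 *)parallel_closure)
    assert(closure_result == 0, closure_result)

(codegen prepends __user_context to the extern call), and the loop body is split off into the LoweredFunc f_par_for_f_s0_y, which unpacks buf and alpha from its closure argument. Async pipelines take the analogous halide_do_parallel_tasks path, with a task array and semaphore acquisitions instead of a single function pointer.
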
+LoweredFunc generate_closure_ir(const std::string &name, const Closure &closure, + std::vector &args, int closure_arg_index, + const Stmt &body, const Target &t) { + + std::string closure_arg_name = unique_name("closure_arg"); + args[closure_arg_index] = LoweredArgument(closure_arg_name, Argument::Kind::InputScalar, + type_of(), 0, ArgumentEstimates()); + + Expr closure_arg = Variable::make(closure.pack_into_struct().type(), closure_arg_name); + + Stmt wrapped_body = closure.unpack_from_struct(closure_arg, body); + + // TODO(zvookin): Figure out how we want to handle name mangling of closures. + // For now, the C++ backend makes them extern "C" so they have to be NameMangling::C. + LoweredFunc result{name, args, wrapped_body, LinkageType::External, NameMangling::C}; + if (t.has_feature(Target::Debug)) { + debug_arguments(&result, t); + } + return result; +} + +std::string task_debug_name(const std::pair &prefix) { + if (prefix.second <= 1) { + return prefix.first; + } else { + return prefix.first + "_" + std::to_string(prefix.second - 1); + } +} + +void add_fork(std::pair &prefix) { + if (prefix.second == 0) { + prefix.first += ".fork"; + } + prefix.second++; +} + +void add_suffix(std::pair &prefix, const std::string &suffix) { + if (prefix.second > 1) { + prefix.first += "_" + std::to_string(prefix.second - 1); + prefix.second = 0; + } + prefix.first += suffix; +} + +// TODO(zvookin|abadams): This makes multiple passes over the +// IR to cover each node. (One tree walk produces the min +// thread count for all nodes, but we redo each subtree when +// compiling a given node.) Ideally we'd move to a lowering pass +// that converts our parallelism constructs to Call nodes, or +// direct hardware operations in some cases. +// Also, this code has to exactly mirror the logic in get_parallel_tasks. +// It would be better to do one pass on the tree and centralize the task +// deduction logic in one place. +class MinThreads : public IRVisitor { + using IRVisitor::visit; + + std::pair skip_acquires(Stmt first) { + int count = 0; + while (first.defined()) { + const Acquire *acq = first.as(); + if (acq == nullptr) { + break; + } + count++; + first = acq->body; + } + return {first, count}; + } + + void visit(const Fork *op) override { + int total_threads = 0; + int direct_acquires = 0; + // Take the sum of min threads across all + // cascaded Fork nodes. 
+ const Fork *node = op; + while (node != nullptr) { + result = 0; + auto after_acquires = skip_acquires(node->first); + direct_acquires += after_acquires.second; + + after_acquires.first.accept(this); + total_threads += result; + + const Fork *continued_branches = node->rest.as(); + if (continued_branches == nullptr) { + result = 0; + after_acquires = skip_acquires(node->rest); + direct_acquires += after_acquires.second; + after_acquires.first.accept(this); + total_threads += result; + } + node = continued_branches; + } + if (direct_acquires == 0 && total_threads == 0) { + result = 0; + } else { + result = total_threads + 1; + } + } + + void visit(const For *op) override { + result = 0; + + if (op->for_type == ForType::Parallel) { + IRVisitor::visit(op); + if (result > 0) { + result += 1; + } + } else if (op->for_type == ForType::Serial) { + auto after_acquires = skip_acquires(op->body); + if (after_acquires.second > 0 && + !expr_uses_var(op->body.as()->count, op->name)) { + after_acquires.first.accept(this); + result++; + } else { + IRVisitor::visit(op); + } + } else { + IRVisitor::visit(op); + } + } + + // This is a "standalone" Acquire and will result in its own task. + // Treat it requiring one more thread than its body. + void visit(const Acquire *op) override { + result = 0; + auto after_inner_acquires = skip_acquires(op); + after_inner_acquires.first.accept(this); + result = result + 1; + } + + void visit(const Block *op) override { + result = 0; + op->first.accept(this); + int result_first = result; + result = 0; + op->rest.accept(this); + result = std::max(result, result_first); + } + +public: + int result = 0; +}; + +int calculate_min_threads(const Stmt &body) { + MinThreads min_threads; + body.accept(&min_threads); + return min_threads.result; +} + +struct LowerParallelTasks : public IRMutator { + + /** Codegen a call to do_parallel_tasks */ + struct ParallelTask { + Stmt body; + struct SemAcquire { + Expr semaphore; + Expr count; + }; + std::vector semaphores; + std::string loop_var; + Expr min, extent; + Expr serial; + std::string name; + }; + + using IRMutator::visit; + + Stmt visit(const For *op) override { + const Acquire *acquire = op->body.as(); + + if (op->for_type == ForType::Parallel || + (op->for_type == ForType::Serial && + acquire && + !expr_uses_var(acquire->count, op->name))) { + return do_as_parallel_task(op); + } + return IRMutator::visit(op); + } + + Stmt visit(const Acquire *op) override { + return do_as_parallel_task(op); + } + + Stmt visit(const Fork *op) override { + return do_as_parallel_task(op); + } + + Stmt rewrite_parallel_tasks(const std::vector &tasks) { + Stmt body; + + Closure closure; + for (const auto &t : tasks) { + Stmt s = t.body; + if (!t.loop_var.empty()) { + s = LetStmt::make(t.loop_var, 0, s); + } + closure.include(s); + } + + // The same name can appear as a var and a buffer. Remove the var name in this case. 
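The Closure built here feeds the new Closure::pack_into_struct()/unpack_from_struct() pair added in Closure.cpp above; in isolation the flow looks like this (a sketch; task_body and closure_arg are placeholders):

    Closure closure;
    closure.include(task_body);

    // A make_struct intrinsic whose elements are the captured buffers and vars,
    // stable-sorted by decreasing size.
    Expr packed = closure.pack_into_struct();

    // Wraps task_body in LetStmts that rebind each captured name from the opaque
    // closure argument via load_typed_struct_member (skipping names the body never uses).
    Stmt rebound = closure.unpack_from_struct(closure_arg, task_body);
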
+ for (auto const &b : closure.buffers) { + closure.vars.erase(b.first); + } + + int num_tasks = (int)(tasks.size()); + std::vector tasks_array_args; + + std::string closure_name = unique_name("parallel_closure"); + Expr closure_struct_allocation = closure.pack_into_struct(); + Expr closure_struct = Variable::make(Handle(), closure_name); + + const bool has_task_parent = !task_parents.empty() && task_parents.top_ref().defined(); + + Expr result; + for (int i = 0; i < num_tasks; i++) { + ParallelTask t = tasks[i]; + + const int min_threads = calculate_min_threads(t.body); + + // Decide if we're going to call do_par_for or + // do_parallel_tasks. halide_do_par_for is simpler, but + // assumes a bunch of things. Programs that don't use async + // can also enter the task system via do_par_for. + const bool use_parallel_for = (num_tasks == 1 && + min_threads == 0 && + t.semaphores.empty() && + !has_task_parent); + + std::string semaphores_array_name = unique_name("task_semaphores"); + Expr semaphores_array; + std::vector semaphore_args(t.semaphores.size() * 2); + for (int i = 0; i < (int)t.semaphores.size(); i++) { + semaphore_args[i * 2] = t.semaphores[i].semaphore; + semaphore_args[i * 2 + 1] = t.semaphores[i].count; + } + semaphores_array = Call::make(type_of(), Call::make_struct, semaphore_args, Call::PureIntrinsic); + + Expr closure_task_parent; + std::vector closure_args(use_parallel_for ? 3 : 5); + int closure_arg_index; + closure_args[0] = LoweredArgument("__user_context", Argument::Kind::InputScalar, + type_of(), 0, ArgumentEstimates()); + if (use_parallel_for) { + closure_arg_index = 2; + // closure_task_parent remains undefined here. + closure_args[1] = LoweredArgument(t.loop_var, Argument::Kind::InputScalar, + Int(32), 0, ArgumentEstimates()); + } else { + closure_arg_index = 3; + const std::string closure_task_parent_name = unique_name("__task_parent"); + closure_task_parent = Variable::make(type_of(), closure_task_parent_name); + // We peeled off a loop. Wrap a new loop around the body + // that just does the slice given by the arguments. 
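For orientation, the LoweredFunc that generate_closure_ir() emits for this branch has the five-argument shape of halide_loop_task_t (matching the closure_args filled in here), while the use_parallel_for branch above yields the three-argument halide_task_t shape; roughly, with illustrative names and bodies supplied by unpack_from_struct():

    // halide_do_parallel_tasks path (halide_loop_task_t):
    extern "C" int f_task_0(void *__user_context, int32_t task_min, int32_t task_extent,
                            uint8_t *closure_arg, void *__task_parent);

    // halide_do_par_for path (halide_task_t):
    extern "C" int f_par_for_0(void *__user_context, int32_t loop_index, uint8_t *closure_arg);
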
+ std::string loop_min_name = unique_name('t'); + std::string loop_extent_name = unique_name('t'); + if (!t.loop_var.empty()) { + t.body = For::make(t.loop_var, + Variable::make(Int(32), loop_min_name), + Variable::make(Int(32), loop_extent_name), + ForType::Serial, + DeviceAPI::None, + t.body); + } else { + internal_assert(is_const_one(t.extent)); + } + closure_args[1] = LoweredArgument(loop_min_name, Argument::Kind::InputScalar, + Int(32), 0, ArgumentEstimates()); + closure_args[2] = LoweredArgument(loop_extent_name, Argument::Kind::InputScalar, + Int(32), 0, ArgumentEstimates()); + closure_args[4] = LoweredArgument(closure_task_parent_name, Argument::Kind::InputScalar, + type_of(), 0, ArgumentEstimates()); + } + + { + ScopedValue save_name(function_name, t.name); + + task_parents.push(closure_task_parent); + t.body = mutate(t.body); + task_parents.pop(); + } + + std::string new_function_name = c_print_name(unique_name(t.name), false); + // Note that closure_args[closure_arg_index] will be filled in by the call to generate_closure_ir() + closure_implementations.emplace_back(generate_closure_ir(new_function_name, closure, closure_args, + closure_arg_index, t.body, target)); + + if (use_parallel_for) { + std::vector args(4); + // Codegen will add user_context for us + + // Prefix the function name with "::" as we would in C++ to make + // it clear we're talking about something in global scope in + // case some joker names an intermediate Func or Var the same + // name as the pipeline. This prefix works transparently in the + // C++ backend. + args[0] = Variable::make(Handle(), "::" + new_function_name); + args[1] = t.min; + args[2] = t.extent; + args[3] = Cast::make(type_of(), closure_struct); + result = Call::make(Int(32), "halide_do_par_for", args, Call::Extern); + } else { + tasks_array_args.emplace_back(Variable::make(Handle(), "::" + new_function_name)); + tasks_array_args.emplace_back(Cast::make(type_of(), closure_struct)); + tasks_array_args.emplace_back(StringImm::make(t.name)); + tasks_array_args.emplace_back(semaphores_array); + tasks_array_args.emplace_back((int)t.semaphores.size()); + tasks_array_args.emplace_back(t.min); + tasks_array_args.emplace_back(t.extent); + tasks_array_args.emplace_back(min_threads); + tasks_array_args.emplace_back(Cast::make(Bool(), t.serial)); + } + } + + if (!tasks_array_args.empty()) { + // Allocate task list array + Expr tasks_list = Call::make(Handle(), Call::make_struct, tasks_array_args, Call::PureIntrinsic); + Expr user_context = Call::make(type_of(), Call::get_user_context, {}, Call::PureIntrinsic); + Expr task_parent = has_task_parent ? task_parents.top() : make_zero(Handle()); + result = Call::make(Int(32), "halide_do_parallel_tasks", + {user_context, make_const(Int(32), num_tasks), tasks_list, task_parent}, + Call::Extern); + } + + std::string closure_result_name = unique_name("closure_result"); + Expr closure_result = Variable::make(Int(32), closure_result_name); + Stmt stmt = AssertStmt::make(closure_result == 0, closure_result); + stmt = LetStmt::make(closure_result_name, result, stmt); + stmt = LetStmt::make(closure_name, closure_struct_allocation, stmt); + return stmt; + } + + void get_parallel_tasks(const Stmt &s, std::vector &result, std::pair prefix) { + const For *loop = s.as(); + const Acquire *acquire = loop ? 
+        if (const Fork *f = s.as<Fork>()) {
+            add_fork(prefix);
+            get_parallel_tasks(f->first, result, prefix);
+            get_parallel_tasks(f->rest, result, prefix);
+        } else if (!loop && acquire) {
+            const Variable *v = acquire->semaphore.as<Variable>();
+            internal_assert(v);
+            add_suffix(prefix, "." + v->name);
+            ParallelTask t{s, {}, "", 0, 1, const_false(), task_debug_name(prefix)};
+            while (acquire) {
+                t.semaphores.push_back({acquire->semaphore, acquire->count});
+                t.body = acquire->body;
+                acquire = t.body.as<Acquire>();
+            }
+            result.emplace_back(std::move(t));
+        } else if (loop && loop->for_type == ForType::Parallel) {
+            add_suffix(prefix, ".par_for." + loop->name);
+            ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent, const_false(), task_debug_name(prefix)};
+            result.emplace_back(std::move(t));
+        } else if (loop &&
+                   loop->for_type == ForType::Serial &&
+                   acquire &&
+                   !expr_uses_var(acquire->count, loop->name)) {
+            const Variable *v = acquire->semaphore.as<Variable>();
+            internal_assert(v);
+            add_suffix(prefix, ".for." + v->name);
+            ParallelTask t{loop->body, {}, loop->name, loop->min, loop->extent, const_true(), task_debug_name(prefix)};
+            while (acquire) {
+                t.semaphores.push_back({acquire->semaphore, acquire->count});
+                t.body = acquire->body;
+                acquire = t.body.as<Acquire>();
+            }
+            result.emplace_back(std::move(t));
+        } else {
+            add_suffix(prefix, "." + std::to_string(result.size()));
+            ParallelTask t{s, {}, "", 0, 1, const_false(), task_debug_name(prefix)};
+            result.emplace_back(std::move(t));
+        }
+    }
+
+    Stmt do_as_parallel_task(const Stmt &s) {
+        std::vector<ParallelTask> tasks;
+        get_parallel_tasks(s, tasks, {function_name, 0});
+        return rewrite_parallel_tasks(tasks);
+    }
+
+    LowerParallelTasks(const std::string &name, const Target &t)
+        : function_name(name), target(t) {
+    }
+
+    std::string function_name;
+    const Target &target;
+    std::vector<LoweredFunc> closure_implementations;
+    SmallStack<Expr> task_parents;
+};
+
+} // namespace
+
+Stmt lower_parallel_tasks(const Stmt &s, std::vector<LoweredFunc> &closure_implementations,
+                          const std::string &name, const Target &t) {
+    LowerParallelTasks lowering_mutator(name, t);
+    Stmt result = lowering_mutator.mutate(s);
+
+    // Main body will be dumped as part of standard lowering debugging, but closures will not be.
+    if (debug::debug_level() >= 2) {
+        for (const auto &lf : lowering_mutator.closure_implementations) {
+            debug(2) << "lower_parallel_tasks generated closure lowered function " << lf.name << ":\n"
+                     << lf.body << "\n\n";
+        }
+    }
+
+    // Append to the end rather than replacing the list entirely.
+    closure_implementations.insert(closure_implementations.end(),
+                                   lowering_mutator.closure_implementations.begin(),
+                                   lowering_mutator.closure_implementations.end());
+
+    return result;
+}
+
+} // namespace Internal
+} // namespace Halide
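For context (commentary, not part of the patch): the calls constructed above bottom out in the existing runtime entry points declared in HalideRuntime.h. The sketch below is paraphrased from memory of that header; the halide_parallel_task_t field order is assumed from the order in which tasks_array_args is populated above, so treat HalideRuntime.h as the authoritative definition.

    /* Sketch only -- paraphrased, not introduced by this diff. */
    typedef int (*halide_task_t)(void *user_context, int task_number, uint8_t *closure);
    extern int halide_do_par_for(void *user_context, halide_task_t task,
                                 int min, int size, uint8_t *closure);

    struct halide_semaphore_acquire_t {
        struct halide_semaphore_t *semaphore;
        int count;
    };

    struct halide_parallel_task_t {
        /* Task body: runs a [min, min+extent) slice, given the closure and a task-parent handle. */
        int (*fn)(void *user_context, int min, int extent, uint8_t *closure, void *task_parent);
        uint8_t *closure;
        const char *name;
        struct halide_semaphore_acquire_t *semaphores;
        int num_semaphores;
        int min, extent;
        int min_threads;
        bool serial;
    };

    extern int halide_do_parallel_tasks(void *user_context, int num_tasks,
                                        struct halide_parallel_task_t *tasks,
                                        void *task_parent);

The use_parallel_for fast path above only takes halide_do_par_for when there is a single task with no semaphores, no minimum thread requirement, and no enclosing task parent; everything else goes through halide_do_parallel_tasks with one halide_parallel_task_t entry per task.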
diff --git a/src/LowerParallelTasks.h b/src/LowerParallelTasks.h
new file mode 100644
index 000000000000..509436659fe2
--- /dev/null
+++ b/src/LowerParallelTasks.h
@@ -0,0 +1,21 @@
+#ifndef HALIDE_LOWER_PARALLEL_TASKS_H
+#define HALIDE_LOWER_PARALLEL_TASKS_H
+
+/** \file
+ *
+ * Support for platform independent lowering of Halide parallel and async mechanisms.
+ * May eventually become a lowering pass.
+ */
+
+#include "IRVisitor.h"
+
+namespace Halide {
+namespace Internal {
+
+Stmt lower_parallel_tasks(const Stmt &s, std::vector<LoweredFunc> &closure_implementations,
+                          const std::string &name, const Target &t);
+
+} // namespace Internal
+} // namespace Halide
+
+#endif // HALIDE_LOWER_PARALLEL_TASKS_H
diff --git a/src/Type.h b/src/Type.h
index f281cba2ac93..6e9feee8b613 100644
--- a/src/Type.h
+++ b/src/Type.h
@@ -82,7 +82,7 @@ struct halide_handle_cplusplus_type {
     std::vector<halide_cplusplus_type_name> enclosing_types;
 
     /// One set of modifiers on a type.
-    /// The const/volatile/restrict propertises are "inside" the pointer property.
+    /// The const/volatile/restrict properties are "inside" the pointer property.
     enum Modifier : uint8_t {
         Const = 1 << 0,     ///< Bitmask flag for "const"
         Volatile = 1 << 1,  ///< Bitmask flag for "volatile"
@@ -170,6 +170,7 @@ HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_dimension_t);
 HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_device_interface_t);
 HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_filter_metadata_t);
 HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_semaphore_t);
+HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_semaphore_acquire_t);
 HALIDE_DECLARE_EXTERN_STRUCT_TYPE(halide_parallel_task_t);
 
 // You can make arbitrary user-defined types be "Known" using the
diff --git a/src/Util.cpp b/src/Util.cpp
index 081bd1e777f4..53f15fa14b2e 100644
--- a/src/Util.cpp
+++ b/src/Util.cpp
@@ -575,11 +575,12 @@ void halide_toc_impl(const char *file, int line) {
     debug(1) << t1.file << ":" << t1.line << " ... " << f << ":" << line << " : "
              << diff.count() * 1000 << " ms\n";
 }
 
-std::string c_print_name(const std::string &name) {
+std::string c_print_name(const std::string &name,
+                         bool prefix_underscore) {
     ostringstream oss;
 
     // Prefix an underscore to avoid reserved words (e.g. a variable named "while")
-    if (isalpha(name[0])) {
+    if (prefix_underscore && isalpha(name[0])) {
         oss << "_";
     }
diff --git a/src/Util.h b/src/Util.h
index 5404560394f3..6f551d4174d9 100644
--- a/src/Util.h
+++ b/src/Util.h
@@ -402,8 +402,11 @@ struct IsRoundtrippable {
     }
 };
 
-/** Emit a version of a string that is a valid identifier in C (. is replaced with _) */
-std::string c_print_name(const std::string &name);
+/** Emit a version of a string that is a valid identifier in C (. is replaced with _)
+ * If prefix_underscore is true (the default), an underscore will be prepended if the
+ * input starts with an alphabetic character to avoid reserved word clashes.
+ */
+std::string c_print_name(const std::string &name, bool prefix_underscore = true);
 
 /** Return the LLVM_VERSION against which this libHalide is compiled. This is provided
  * only for internal tests which need to verify behavior; please don't use this outside
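The prefix_underscore parameter added to c_print_name above exists so that the closure functions minted by LowerParallelTasks (see the c_print_name(unique_name(t.name), false) call earlier in this diff) keep their already-unique names when emitted as global symbols, rather than gaining a leading underscore. A small illustration of the intended behavior, using hypothetical names and the documented "." to "_" mapping:

    // Illustration only; not part of the patch.
    using Halide::Internal::c_print_name;
    std::string a = c_print_name("f.x");         // "_f_x" -- default still prepends "_" to dodge C reserved words
    std::string b = c_print_name("f.x", false);  // "f_x"  -- no prefix; the name is used verbatim as a global symbol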
diff --git a/src/runtime/qurt_hvx.cpp b/src/runtime/qurt_hvx.cpp
index 9748cc662142..afc607bfce47 100644
--- a/src/runtime/qurt_hvx.cpp
+++ b/src/runtime/qurt_hvx.cpp
@@ -69,4 +69,29 @@ WEAK_INLINE uint8_t *_halide_hexagon_buffer_get_host(const hexagon_buffer_t_arg
 WEAK_INLINE uint64_t _halide_hexagon_buffer_get_device(const hexagon_buffer_t_arg *buf) {
     return buf->device;
 }
+
+WEAK_INLINE int _halide_hexagon_do_par_for(void *user_context, halide_task_t f,
+                                           int min, int size, uint8_t *closure,
+                                           int use_hvx) {
+    if (use_hvx) {
+        const int result = halide_qurt_hvx_unlock(user_context);
+        if (result != 0) {
+            return result;
+        }
+    }
+
+    const int result = halide_do_par_for(user_context, f, min, size, closure);
+    if (result != 0) {
+        return result;
+    }
+
+    if (use_hvx) {
+        const int result = halide_qurt_hvx_lock(user_context);
+        if (result != 0) {
+            return result;
+        }
+    }
+
+    return 0;
+}
 }
diff --git a/test/common/check_call_graphs.h b/test/common/check_call_graphs.h
index 2f65f86d01ca..c1db800bfee0 100644
--- a/test/common/check_call_graphs.h
+++ b/test/common/check_call_graphs.h
@@ -14,29 +14,29 @@ typedef std::map<std::string, std::vector<std::string>> CallGraphs;
 // For each producer node, find all functions that it calls.
-class CheckCalls : public Halide::Internal::IRVisitor {
+class CheckCalls : public Halide::Internal::IRMutator {
 public:
     CallGraphs calls;  // Caller -> vector of callees
     std::string producer = "";
 
 private:
-    using Halide::Internal::IRVisitor::visit;
+    using Halide::Internal::IRMutator::visit;
 
-    void visit(const Halide::Internal::ProducerConsumer *op) override {
+    Halide::Internal::Stmt visit(const Halide::Internal::ProducerConsumer *op) override {
         if (op->is_producer) {
             std::string old_producer = producer;
             producer = op->name;
             calls[producer];  // Make sure each producer is allocated a slot
             // Group the callees of the 'produce' and 'update' together
-            op->body.accept(this);
+            auto new_stmt = mutate(op->body);
             producer = old_producer;
+            return new_stmt;
         } else {
-            Halide::Internal::IRVisitor::visit(op);
+            return Halide::Internal::IRMutator::visit(op);
         }
     }
 
-    void visit(const Halide::Internal::Load *op) override {
-        Halide::Internal::IRVisitor::visit(op);
+    Halide::Expr visit(const Halide::Internal::Load *op) override {
         if (!producer.empty()) {
             assert(calls.count(producer) > 0);
             std::vector<std::string> &callees = calls[producer];
@@ -44,11 +44,19 @@ class CheckCalls : public Halide::Internal::IRVisitor {
                 callees.push_back(op->name);
             }
         }
+        return Halide::Internal::IRMutator::visit(op);
     }
 };
 
 // These are declared "inline" to avoid "unused function" warnings
-inline int check_call_graphs(CallGraphs &result, CallGraphs &expected) {
+inline int check_call_graphs(Halide::Pipeline p, CallGraphs &expected) {
+    // Add a custom lowering pass that scrapes the call graph. We give ownership
+    // of it to the Pipeline, whose lifetime escapes this function.
+    CheckCalls *checker = new CheckCalls;
+    p.add_custom_lowering_pass(checker);
+    p.compile_to_module(p.infer_arguments(), "");
+    CallGraphs &result = checker->calls;
+
     if (result.size() != expected.size()) {
         printf("Expect %d callers instead of %d\n", (int)expected.size(), (int)result.size());
         return -1;
@@ -74,7 +82,7 @@ inline int check_call_graphs(CallGraphs &result, CallGraphs &expected) {
                        return a.empty() ?
b : a + ", " + b; }); - printf("Expect calless of %s to be (%s); got (%s) instead\n", + printf("Expect callees of %s to be (%s); got (%s) instead\n", iter.first.c_str(), expected_str.c_str(), result_str.c_str()); return -1; } diff --git a/test/correctness/func_clone.cpp b/test/correctness/func_clone.cpp index 26142b9c86d2..6009a53d0f83 100644 --- a/test/correctness/func_clone.cpp +++ b/test/correctness/func_clone.cpp @@ -64,15 +64,11 @@ int func_clone_test() { // Check the call graphs. // Expect 'g' to call 'clone', 'clone' to call nothing, and 'f' not // in the final IR. - Module m = g.compile_to_module({}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g.name(), {clone.name()}}, {clone.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } @@ -100,10 +96,6 @@ int multiple_funcs_sharing_clone_test() { // Expect 'g1' and 'g2' to call 'f_clone', 'g3' to call 'f', // f_clone' to call nothing, 'f' to call nothing Pipeline p({g1, g2, g3}); - Module m = p.compile_to_module({}, ""); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g1.name(), {f_clone.name()}}, {g2.name(), {f_clone.name()}}, @@ -111,7 +103,7 @@ int multiple_funcs_sharing_clone_test() { {f_clone.name(), {}}, {f.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(p, expected) != 0) { return -1; } @@ -156,50 +148,27 @@ int update_defined_after_clone_test() { f.compute_root(); clone.compute_root().vectorize(x, 8).unroll(x, 2).split(x, x, xi, 4).parallel(x); - { - param.set(true); - - // Check the call graphs. - // Expect initialization of 'g' to call 'clone' and its update to call - // 'clone' and 'g', clone' to call nothing, and 'f' not in the final IR. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - - CallGraphs expected = { - {g.name(), {clone.name(), g.name()}}, - {clone.name(), {}}, - }; - if (check_call_graphs(c.calls, expected) != 0) { - return -1; - } + // Check the call graphs. + // Expect initialization of 'g' to call 'clone' and its update to call + // 'clone' and 'g', clone' to call nothing, and 'f' not in the final IR. + CallGraphs expected = { + {g.name(), {clone.name(), g.name()}}, + {clone.name(), {}}, + }; + if (check_call_graphs(g, expected) != 0) { + return -1; + } - Buffer im = g.realize({200, 200}); - auto func = [](int x, int y) { - return ((0 <= x && x <= 99) && (0 <= y && y <= 99) && (x < y)) ? 3 * (x + y) : (x + y); - }; - if (check_image(im, func)) { - return -1; - } + Buffer im = g.realize({200, 200}); + auto func = [](int x, int y) { + return ((0 <= x && x <= 99) && (0 <= y && y <= 99) && (x < y)) ? 3 * (x + y) : (x + y); + }; + if (check_image(im, func)) { + return -1; } - { - param.set(false); - - // Check the call graphs. - // Expect initialization of 'g' to call 'clone' and its update to call - // 'clone' and 'g', clone' to call nothing, and 'f' not in the final IR. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - - CallGraphs expected = { - {g.name(), {clone.name(), g.name()}}, - {clone.name(), {}}, - }; - if (check_call_graphs(c.calls, expected) != 0) { - return -1; - } + for (bool param_value : {false, true}) { + param.set(param_value); Buffer im = g.realize({200, 200}); auto func = [](int x, int y) { @@ -236,10 +205,6 @@ int clone_depend_on_mutated_func_test() { // Check the call graphs. 
Pipeline p({d, e, f}); - Module m = p.compile_to_module({}, ""); - CheckCalls check; - m.functions().front().body.accept(&check); - CallGraphs expected = { {e.name(), {a.name()}}, {a.name(), {}}, @@ -250,7 +215,7 @@ int clone_depend_on_mutated_func_test() { {b.name(), {a_clone_in_b.name()}}, {a_clone_in_b.name(), {}}, }; - if (check_call_graphs(check.calls, expected) != 0) { + if (check_call_graphs(p, expected) != 0) { return -1; } @@ -298,10 +263,6 @@ int clone_on_clone_test() { // Check the call graphs. Pipeline p({c, d, e, f}); - Module m = p.compile_to_module({}, ""); - CheckCalls check; - m.functions().front().body.accept(&check); - CallGraphs expected = { {e.name(), {b.name(), a_clone_in_b_e_in_e.name()}}, {c.name(), {b.name()}}, @@ -313,7 +274,7 @@ int clone_on_clone_test() { {b_clone_in_d_f.name(), {a.name()}}, {a.name(), {}}, }; - if (check_call_graphs(check.calls, expected) != 0) { + if (check_call_graphs(p, expected) != 0) { return -1; } diff --git a/test/correctness/func_wrapper.cpp b/test/correctness/func_wrapper.cpp index 48107808d227..8d68e38a8f8b 100644 --- a/test/correctness/func_wrapper.cpp +++ b/test/correctness/func_wrapper.cpp @@ -78,16 +78,12 @@ int func_wrapper_test() { // Check the call graphs. // Expect 'g' to call 'wrapper', 'wrapper' to call 'f', 'f' to call nothing - Module m = g.compile_to_module({}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g.name(), {wrapper.name()}}, {wrapper.name(), {f.name()}}, {f.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } @@ -115,10 +111,6 @@ int multiple_funcs_sharing_wrapper_test() { // Expect 'g1' and 'g2' to call 'f_wrapper', 'g3' to call 'f', // f_wrapper' to call 'f', 'f' to call nothing Pipeline p({g1, g2, g3}); - Module m = p.compile_to_module({}, ""); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g1.name(), {f_wrapper.name()}}, {g2.name(), {f_wrapper.name()}}, @@ -126,7 +118,7 @@ int multiple_funcs_sharing_wrapper_test() { {f_wrapper.name(), {f.name()}}, {f.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(p, expected) != 0) { return -1; } @@ -165,17 +157,13 @@ int global_wrapper_test() { // Check the call graphs. // Expect 'g' to call 'wrapper', 'wrapper' to call 'f', 'f' to call nothing, // 'h' to call 'wrapper' and 'g' - Module m = h.compile_to_module({}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {h.name(), {g.name(), wrapper.name()}}, {g.name(), {wrapper.name()}}, {wrapper.name(), {f.name()}}, {f.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(h, expected) != 0) { return -1; } @@ -211,52 +199,17 @@ int update_defined_after_wrapper_test() { f.compute_root(); wrapper.compute_root().vectorize(x, 8).unroll(x, 2).split(x, x, xi, 4).parallel(x); - { - param.set(true); - - // Check the call graphs. 
- // Expect initialization of 'g' to call 'wrapper' and its update to call - // 'wrapper' and 'g', wrapper' to call 'f', 'f' to call nothing - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - - CallGraphs expected = { - {g.name(), {wrapper.name(), g.name()}}, - {wrapper.name(), {f.name()}}, - {f.name(), {}}, - }; - if (check_call_graphs(c.calls, expected) != 0) { - return -1; - } - - Buffer im = g.realize({200, 200}); - auto func = [](int x, int y) { - return ((0 <= x && x <= 99) && (0 <= y && y <= 99) && (x < y)) ? 3 * (x + y) : (x + y); - }; - if (check_image(im, func)) { - return -1; - } + CallGraphs expected = { + {g.name(), {wrapper.name(), g.name()}}, + {wrapper.name(), {f.name()}}, + {f.name(), {}}, + }; + if (check_call_graphs(g, expected) != 0) { + return -1; } - { - param.set(false); - - // Check the call graphs. - // Expect initialization of 'g' to call 'wrapper' and its update to call - // 'wrapper' and 'g', wrapper' to call 'f', 'f' to call nothing - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - - CallGraphs expected = { - {g.name(), {wrapper.name(), g.name()}}, - {wrapper.name(), {f.name()}}, - {f.name(), {}}, - }; - if (check_call_graphs(c.calls, expected) != 0) { - return -1; - } + for (bool param_value : {false, true}) { + param.set(param_value); Buffer im = g.realize({200, 200}); auto func = [](int x, int y) { @@ -292,16 +245,12 @@ int rdom_wrapper_test() { // Check the call graphs. // Expect 'wrapper' to call 'g', initialization of 'g' to call nothing // and its update to call 'f' and 'g', 'f' to call nothing - Module m = wrapper.compile_to_module({}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g.name(), {f.name(), g.name()}}, {wrapper.name(), {g.name()}}, {f.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(wrapper, expected) != 0) { return -1; } @@ -329,9 +278,6 @@ int global_and_custom_wrapper_test() { // Check the call graphs. // Expect 'result' to call 'g' and 'f_wrapper', 'g' to call 'f_in_g', // 'f_wrapper' to call 'f', f_in_g' to call 'f', 'f' to call nothing - Module m = result.compile_to_module({}); - CheckCalls c; - m.functions().front().body.accept(&c); CallGraphs expected = { {result.name(), {g.name(), f_wrapper.name()}}, @@ -340,7 +286,7 @@ int global_and_custom_wrapper_test() { {f_in_g.name(), {f.name()}}, {f.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(result, expected) != 0) { return -1; } @@ -373,10 +319,6 @@ int wrapper_depend_on_mutated_func_test() { // Check the call graphs. // Expect 'h' to call 'g_in_h', 'g_in_h' to call 'g', 'g' to call 'f', // 'f' to call 'e_in_f', e_in_f' to call 'e', 'e' to call nothing - Module m = h.compile_to_module({}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {h.name(), {g_in_h.name()}}, {g_in_h.name(), {g.name()}}, @@ -385,7 +327,7 @@ int wrapper_depend_on_mutated_func_test() { {e_in_f.name(), {e.name()}}, {e.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(h, expected) != 0) { return -1; } @@ -415,10 +357,6 @@ int wrapper_on_wrapper_test() { Func g_in_h = g.in(h).compute_root(); // Check the call graphs. 
- Module m = h.compile_to_module({}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {h.name(), {f_in_h.name(), g_in_h.name(), f_in_f_in_g.name()}}, {f_in_h.name(), {f.name()}}, @@ -429,7 +367,7 @@ int wrapper_on_wrapper_test() { {f.name(), {e.name()}}, {e.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(h, expected) != 0) { return -1; } @@ -461,10 +399,6 @@ int wrapper_on_rdom_predicate_test() { // Check the call graphs. // Expect 'g' to call nothing, update of 'g' to call 'g', f_in_g', and 'h_wrapper', // 'f_in_g' to call 'f', 'f' to call nothing, 'h_wrapper' to call 'h', 'h' to call nothing - Module m = g.compile_to_module({}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g.name(), {g.name(), f_in_g.name(), h_wrapper.name()}}, {f_in_g.name(), {f.name()}}, @@ -472,7 +406,7 @@ int wrapper_on_rdom_predicate_test() { {h_wrapper.name(), {h.name()}}, {h.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } @@ -502,17 +436,13 @@ int two_fold_wrapper_test() { input_in_output_in_output = input_in_output.in(output).compute_at(output, x).unroll(x).unroll(y); // Check the call graphs. - Module m = output.compile_to_module({}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {output.name(), {input_in_output_in_output.name()}}, {input_in_output_in_output.name(), {input_in_output.name()}}, {input_in_output.name(), {input.name()}}, {input.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(output, expected) != 0) { return -1; } @@ -545,10 +475,6 @@ int multi_folds_wrapper_test() { h.compute_root().tile(x, y, xi, yi, 8, 8); Pipeline p({g, h}); - Module m = p.compile_to_module({}, ""); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g.name(), {f_in_g_in_g.name()}}, {f_in_g_in_g.name(), {f_in_g.name()}}, @@ -558,7 +484,7 @@ int multi_folds_wrapper_test() { {f_in_g_in_g_in_h_in_h.name(), {f_in_g_in_g_in_h.name()}}, {f_in_g_in_g_in_h.name(), {f_in_g_in_g.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(p, expected) != 0) { return -1; } diff --git a/test/correctness/image_wrapper.cpp b/test/correctness/image_wrapper.cpp index 95b3f17d24ae..4fb5d9c1ef19 100644 --- a/test/correctness/image_wrapper.cpp +++ b/test/correctness/image_wrapper.cpp @@ -82,16 +82,12 @@ int func_wrapper_test() { // Check the call graphs. 
// Expect 'g' to call 'wrapper', 'wrapper' to call 'img_f', 'img_f' to call 'img' - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g.name(), {wrapper.name()}}, {wrapper.name(), {img_f.name()}}, {img_f.name(), {img.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } @@ -124,10 +120,6 @@ int multiple_funcs_sharing_wrapper_test() { // Expect 'g1' and 'g2' to call 'im_wrapper', 'g3' to call 'img_f', // im_wrapper' to call 'img_f', 'img_f' to call 'img' Pipeline p({g1, g2, g3}); - Module m = p.compile_to_module({p.infer_arguments()}, ""); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g1.name(), {im_wrapper.name()}}, {g2.name(), {im_wrapper.name()}}, @@ -135,7 +127,7 @@ int multiple_funcs_sharing_wrapper_test() { {im_wrapper.name(), {img_f.name()}}, {img_f.name(), {img.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(p, expected) != 0) { return -1; } @@ -179,17 +171,13 @@ int global_wrapper_test() { // Check the call graphs. // Expect 'g' to call 'wrapper', 'wrapper' to call 'img_f', 'img_f' to call 'img', // 'h' to call 'wrapper' and 'g' - Module m = h.compile_to_module({h.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {h.name(), {g.name(), wrapper.name()}}, {g.name(), {wrapper.name()}}, {wrapper.name(), {img_f.name()}}, {img_f.name(), {img.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(h, expected) != 0) { return -1; } @@ -230,52 +218,20 @@ int update_defined_after_wrapper_test() { img_f.compute_root(); wrapper.compute_root().vectorize(_0, 8).unroll(_0, 2).split(_0, _0, xi, 4).parallel(_0); - { - param.set(true); - - // Check the call graphs. - // Expect initialization of 'g' to call 'wrapper' and its update to call - // 'wrapper' and 'g', wrapper' to call 'img_f', 'img_f' to call 'img' - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - - CallGraphs expected = { - {g.name(), {wrapper.name(), g.name()}}, - {wrapper.name(), {img_f.name()}}, - {img_f.name(), {img.name()}}, - }; - if (check_call_graphs(c.calls, expected) != 0) { - return -1; - } - - Buffer im = g.realize({200, 200}); - auto func = [](int x, int y) { - return ((0 <= x && x <= 99) && (0 <= y && y <= 99) && (x < y)) ? 3 * (x + y) : (x + y); - }; - if (check_image(im, func)) { - return -1; - } + // Check the call graphs. + // Expect initialization of 'g' to call 'wrapper' and its update to call + // 'wrapper' and 'g', wrapper' to call 'img_f', 'img_f' to call 'img' + CallGraphs expected = { + {g.name(), {wrapper.name(), g.name()}}, + {wrapper.name(), {img_f.name()}}, + {img_f.name(), {img.name()}}, + }; + if (check_call_graphs(g, expected) != 0) { + return -1; } - { - param.set(false); - - // Check the call graphs. 
- // Expect initialization of 'g' to call 'wrapper' and its update to call - // 'wrapper' and 'g', wrapper' to call 'img_f', 'img_f' to call 'img' - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - - CallGraphs expected = { - {g.name(), {wrapper.name(), g.name()}}, - {wrapper.name(), {img_f.name()}}, - {img_f.name(), {img.name()}}, - }; - if (check_call_graphs(c.calls, expected) != 0) { - return -1; - } + for (bool param_value : {false, true}) { + param.set(param_value); Buffer im = g.realize({200, 200}); auto func = [](int x, int y) { @@ -316,16 +272,12 @@ int rdom_wrapper_test() { // Check the call graphs. // Expect 'wrapper' to call 'g', initialization of 'g' to call nothing // and its update to call 'img_f' and 'g', 'img_f' to call 'img' - Module m = wrapper.compile_to_module({wrapper.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g.name(), {img_f.name(), g.name()}}, {wrapper.name(), {g.name()}}, {img_f.name(), {img.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(wrapper, expected) != 0) { return -1; } @@ -358,10 +310,6 @@ int global_and_custom_wrapper_test() { // Check the call graphs. // Expect 'result' to call 'g' and 'img_wrapper', 'g' to call 'img_in_g', // 'img_wrapper' to call 'f', img_in_g' to call 'img_f', 'f' to call 'img' - Module m = result.compile_to_module({result.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {result.name(), {g.name(), img_wrapper.name()}}, {g.name(), {img_in_g.name()}}, @@ -369,7 +317,7 @@ int global_and_custom_wrapper_test() { {img_in_g.name(), {img_f.name()}}, {img_f.name(), {img.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(result, expected) != 0) { return -1; } @@ -407,10 +355,6 @@ int wrapper_depend_on_mutated_func_test() { // Check the call graphs. // Expect 'h' to call 'g_in_h', 'g_in_h' to call 'g', 'g' to call 'f', // 'f' to call 'img_in_f', img_in_f' to call 'img_f', 'img_f' to call 'img' - Module m = h.compile_to_module({h.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {h.name(), {g_in_h.name()}}, {g_in_h.name(), {g.name()}}, @@ -419,7 +363,7 @@ int wrapper_depend_on_mutated_func_test() { {img_in_f.name(), {img_f.name()}}, {img_f.name(), {img.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(h, expected) != 0) { return -1; } @@ -452,10 +396,6 @@ int wrapper_on_wrapper_test() { Func g_in_h = g.in(h).compute_root(); // Check the call graphs. 
- Module m = h.compile_to_module({h.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {h.name(), {img_in_h.name(), g_in_h.name(), img_in_img_in_g.name()}}, {img_in_h.name(), {img_f.name()}}, @@ -465,7 +405,7 @@ int wrapper_on_wrapper_test() { {img_in_img_in_g.name(), {img_f.name()}}, {img_f.name(), {img.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(h, expected) != 0) { return -1; } @@ -503,10 +443,6 @@ int wrapper_on_rdom_predicate_test() { // Expect 'g' to call nothing, update of 'g' to call 'g', img_in_g', and 'h_wrapper', // 'img_in_g' to call 'img_f', 'img_f' to call 'img', 'h_wrapper' to call 'h', // 'h' to call nothing - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g.name(), {g.name(), img_in_g.name(), h_wrapper.name()}}, {img_in_g.name(), {img_f.name()}}, @@ -514,7 +450,7 @@ int wrapper_on_rdom_predicate_test() { {h_wrapper.name(), {h.name()}}, {h.name(), {}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } @@ -549,17 +485,13 @@ int two_fold_wrapper_test() { img_in_output_in_output = img_in_output.in(output).compute_at(output, x).unroll(_0).unroll(_1); // Check the call graphs. - Module m = output.compile_to_module({output.infer_arguments()}); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {output.name(), {img_in_output_in_output.name()}}, {img_in_output_in_output.name(), {img_in_output.name()}}, {img_in_output.name(), {img_f.name()}}, {img_f.name(), {img.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(output, expected) != 0) { return -1; } @@ -597,10 +529,6 @@ int multi_folds_wrapper_test() { h.compute_root().tile(x, y, xi, yi, 8, 8); Pipeline p({g, h}); - Module m = p.compile_to_module({p.infer_arguments()}, ""); - CheckCalls c; - m.functions().front().body.accept(&c); - CallGraphs expected = { {g.name(), {img_in_g_in_g.name()}}, {img_in_g_in_g.name(), {img_in_g.name()}}, @@ -610,7 +538,7 @@ int multi_folds_wrapper_test() { {img_in_g_in_g_in_h_in_h.name(), {img_in_g_in_g_in_h.name()}}, {img_in_g_in_g_in_h.name(), {img_in_g_in_g.name()}}, }; - if (check_call_graphs(c.calls, expected) != 0) { + if (check_call_graphs(p, expected) != 0) { return -1; } diff --git a/test/correctness/parallel_fork.cpp b/test/correctness/parallel_fork.cpp index 78b45bf78e22..e7b588da6f74 100644 --- a/test/correctness/parallel_fork.cpp +++ b/test/correctness/parallel_fork.cpp @@ -35,8 +35,9 @@ enum Schedule { }; Func make(Schedule schedule) { - Var x, y, z; - Func both, f, g; + Var x("x"), y("y"), z("z"); + std::string suffix = "_" + std::to_string((int)schedule); + Func both("both" + suffix), f("f" + suffix), g("g" + suffix); f(x, y) = halide_externs::five_ms(x + y); g(x, y) = halide_externs::five_ms(x - y); diff --git a/test/correctness/parallel_nested_1.cpp b/test/correctness/parallel_nested_1.cpp index dc642145908e..a59e36673541 100644 --- a/test/correctness/parallel_nested_1.cpp +++ b/test/correctness/parallel_nested_1.cpp @@ -24,14 +24,17 @@ int main(int argc, char **argv) { g.hexagon().vectorize(x, 32); f.vectorize(x, 32); } + printf("Using Target = %s\n", target.to_string().c_str()); - Buffer im = g.realize({64, 64, 64}); + Buffer im = g.realize({64, 64, 64}, target); for (int x = 0; x < 64; x++) { for (int y = 0; y < 64; y++) { for (int z = 0; z < 64; z++) { - if (im(x, y, 
z) != x * y + z * 3 + 3) { - printf("im(%d, %d, %d) = %d\n", x, y, z, im(x, y, z)); + const int expected = x * y + z * 3 + 3; + const int actual = im(x, y, z); + if (actual != expected) { + fprintf(stderr, "im(%d, %d, %d) = %d, expected %d\n", x, y, z, actual, expected); return -1; } } diff --git a/test/correctness/rfactor.cpp b/test/correctness/rfactor.cpp index 8511e7ac8e09..0cbaace189f2 100644 --- a/test/correctness/rfactor.cpp +++ b/test/correctness/rfactor.cpp @@ -32,16 +32,12 @@ int simple_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {intm.name(), g.name()}}, {intm.name(), {f.name(), intm.name()}}, {f.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -80,17 +76,13 @@ int reorder_split_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {intm2.name(), g.name()}}, {intm2.name(), {intm1.name(), intm2.name()}}, {intm1.name(), {f.name(), intm1.name()}}, {f.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -132,17 +124,13 @@ int multi_split_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {intm2.name(), g.name()}}, {intm2.name(), {intm1.name(), intm2.name()}}, {intm1.name(), {f.name(), intm1.name()}}, {f.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -183,17 +171,13 @@ int reorder_fuse_wrapper_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {intm.name(), g.name()}}, {wrapper.name(), {f.name()}}, {intm.name(), {wrapper.name(), intm.name()}}, {f.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -257,10 +241,6 @@ int non_trivial_lhs_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {f.name()}}, {f.name(), {f.name(), intm.name()}}, @@ -269,7 +249,7 @@ int non_trivial_lhs_rfactor_test(bool compile_module) { {b.name(), {}}, {c.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -306,16 +286,12 @@ int simple_rfactor_with_specialize_test(bool compile_module) { if (compile_module) { p.set(20); // Check the call graphs. 
- Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {f.name(), intm.name(), g.name()}}, {intm.name(), {f.name(), intm.name()}}, {f.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -365,16 +341,12 @@ int rdom_with_predicate_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {intm.name(), g.name()}}, {intm.name(), {f.name(), intm.name()}}, {f.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -426,17 +398,13 @@ int histogram_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {hist.name()}}, {hist.name(), {intm.name(), hist.name()}}, {intm.name(), {in.name(), intm.name()}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -493,9 +461,6 @@ int parallel_dot_product_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = dot.compile_to_module({dot.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); CallGraphs expected = { {dot.name(), {intm1.name(), dot.name()}}, @@ -504,7 +469,7 @@ int parallel_dot_product_rfactor_test(bool compile_module) { {a.name(), {}}, {b.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(dot, expected) != 0) { return -1; } } else { @@ -553,17 +518,13 @@ int tuple_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {intm1.name() + ".0", intm1.name() + ".1", g.name() + ".0", g.name() + ".1"}}, {intm1.name(), {intm2.name() + ".0", intm2.name() + ".1", intm1.name() + ".0", intm1.name() + ".1"}}, {intm2.name(), {f.name() + ".0", f.name() + ".1", intm2.name() + ".0", intm2.name() + ".1"}}, {f.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -629,10 +590,6 @@ int tuple_specialize_rdom_predicate_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. 
- Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {intm1.name() + ".0", intm1.name() + ".1", intm4.name() + ".0", intm4.name() + ".1", g.name() + ".0", g.name() + ".1"}}, {intm1.name(), {intm2.name() + ".0", intm2.name() + ".1", intm3.name() + ".0", intm3.name() + ".1", intm1.name() + ".0", intm1.name() + ".1"}}, @@ -641,7 +598,7 @@ int tuple_specialize_rdom_predicate_rfactor_test(bool compile_module) { {intm4.name(), {f.name() + ".0", f.name() + ".1", intm4.name() + ".0", intm4.name() + ".1"}}, {f.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else { @@ -966,17 +923,13 @@ int tuple_partial_reduction_rfactor_test(bool compile_module) { if (compile_module) { // Check the call graphs. - Module m = g.compile_to_module({g.infer_arguments()}); - CheckCalls checker; - m.functions().front().body.accept(&checker); - CallGraphs expected = { {g.name(), {intm1.name() + ".0", g.name() + ".0"}}, {intm1.name(), {intm2.name() + ".0", intm1.name() + ".0"}}, {intm2.name(), {f.name() + ".0", intm2.name() + ".0"}}, {f.name(), {}}, }; - if (check_call_graphs(checker.calls, expected) != 0) { + if (check_call_graphs(g, expected) != 0) { return -1; } } else {