diff --git a/Makefile b/Makefile index c9b179a28da0..92f46164fa67 100644 --- a/Makefile +++ b/Makefile @@ -1609,7 +1609,7 @@ $(FILTERS_DIR)/autograd_grad.a: $(BIN_DIR)/autograd.generator $(BIN_MULLAPUDI201 # all have the form nested_externs_*). $(FILTERS_DIR)/nested_externs_%.a: $(BIN_DIR)/nested_externs.generator @mkdir -p $(@D) - $(CURDIR)/$< -g nested_externs_$* $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime-c_plus_plus_name_mangling + $(CURDIR)/$< -g nested_externs_$* $(GEN_AOT_OUTPUTS) -o $(CURDIR)/$(FILTERS_DIR) target=$(TARGET)-no_runtime-user_context-c_plus_plus_name_mangling # Similarly, gpu_multi needs two different kernels to test compilation caching. # Also requies user-context. diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index f91aa37d2a93..6713d62d7ec2 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -1937,6 +1937,24 @@ void CodeGen_C::compile(const LoweredFunc &f, const MetadataNameMap &metadata_na } if (output_kind != CPlusPlusFunctionInfoHeader) { + const auto emit_arg_decls = [&](const Type &ucon_type = Type()) { + const char *comma = ""; + for (const auto &arg : args) { + stream << comma; + if (arg.is_buffer()) { + stream << "struct halide_buffer_t *" + << print_name(arg.name) + << "_buffer"; + } else { + // If this arg is the user_context value, *and* ucon_type is valid, + // use ucon_type instead of arg.type. + const Type &t = (arg.name == "__user_context" && ucon_type.bits() != 0) ? ucon_type : arg.type; + stream << print_type(t, AppendSpace) << print_name(arg.name); + } + comma = ", "; + } + }; + // Emit the function prototype if (f.linkage == LinkageType::Internal) { // If the function isn't public, mark it static. @@ -1944,20 +1962,7 @@ void CodeGen_C::compile(const LoweredFunc &f, const MetadataNameMap &metadata_na } stream << "HALIDE_FUNCTION_ATTRS\n"; stream << "int " << simple_name << "("; - for (size_t i = 0; i < args.size(); i++) { - if (args[i].is_buffer()) { - stream << "struct halide_buffer_t *" - << print_name(args[i].name) - << "_buffer"; - } else { - stream << print_type(args[i].type, AppendSpace) - << print_name(args[i].name); - } - - if (i < args.size() - 1) { - stream << ", "; - } - } + emit_arg_decls(); if (is_header_or_extern_decl()) { stream << ");\n"; @@ -1995,6 +2000,59 @@ void CodeGen_C::compile(const LoweredFunc &f, const MetadataNameMap &metadata_na close_scope(""); } + // Workaround for https://github.com/halide/Halide/issues/635: + // For historical reasons, Halide-generated AOT code + // defines user_context as `void const*`, but expects all + // define_extern code with user_context usage to use `void *`. This + // usually isn't an issue, but if both the caller and callee of the + // pass a user_context, *and* c_plus_plus_name_mangling is enabled, + // we get link errors because of this dichotomy. Fixing this + // "correctly" (ie so that everything always uses identical types for + // user_context in all cases) will require a *lot* of downstream + // churn (see https://github.com/halide/Halide/issues/7298), + // so this is a workaround: Add a wrapper with `void*` + // ucon -> `void const*` ucon. In most cases this will be ignored + // (and probably dead-stripped), but in these cases it's critical. + // + // (Note that we don't check to see if c_plus_plus_name_mangling is + // enabled, since that would have to be done on the caller side, and + // this is purely a callee-side fix.) + if (f.linkage != LinkageType::Internal && + output_kind == CPlusPlusImplementation && + target.has_feature(Target::CPlusPlusMangling) && + get_target().has_feature(Target::UserContext)) { + + Type ucon_type = Type(); + for (const auto &arg : args) { + if (arg.name == "__user_context") { + ucon_type = arg.type; + break; + } + } + if (ucon_type == type_of()) { + stream << "\nHALIDE_FUNCTION_ATTRS\n"; + stream << "int " << simple_name << "("; + emit_arg_decls(type_of()); + stream << ") "; + open_scope(); + stream << get_indent() << " return " << simple_name << "("; + const char *comma = ""; + for (const auto &arg : args) { + if (arg.name == "__user_context") { + // Add an explicit cast here so we won't call ourselves into oblivion + stream << "(void const *)"; + } + stream << comma << print_name(arg.name); + if (arg.is_buffer()) { + stream << "_buffer"; + } + comma = ", "; + } + stream << ");\n"; + close_scope(""); + } + } + if (f.linkage == LinkageType::ExternalPlusArgv || f.linkage == LinkageType::ExternalPlusMetadata) { // Emit the argv version emit_argv_wrapper(simple_name, args); diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index fae9702f8db7..df0fc3159e71 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -536,6 +536,65 @@ std::unique_ptr CodeGen_LLVM::compile(const Module &input) { names.simple_name, f.args, input.get_metadata_name_map()); } } + + // Workaround for https://github.com/halide/Halide/issues/635: + // For historical reasons, Halide-generated AOT code + // defines user_context as `void const*`, but expects all + // define_extern code with user_context usage to use `void *`. This + // usually isn't an issue, but if both the caller and callee of the + // pass a user_context, *and* c_plus_plus_name_mangling is enabled, + // we get link errors because of this dichotomy. Fixing this + // "correctly" (ie so that everything always uses identical types for + // user_context in all cases) will require a *lot* of downstream + // churn (see https://github.com/halide/Halide/issues/7298), + // so this is a workaround: Add a wrapper with `void*` + // ucon -> `void const*` ucon. In most cases this will be ignored + // (and probably dead-stripped), but in these cases it's critical. + // + // (Note that we don't check to see if c_plus_plus_name_mangling is + // enabled, since that would have to be done on the caller side, and + // this is purely a callee-side fix.) + if (f.linkage != LinkageType::Internal && + target.has_feature(Target::CPlusPlusMangling) && + target.has_feature(Target::UserContext)) { + + int wrapper_ucon_index = -1; + auto wrapper_args = f.args; // make a copy + auto wrapper_llvm_arg_types = arg_types; // make a copy + for (int i = 0; i < (int)wrapper_args.size(); i++) { + if (wrapper_args[i].name == "__user_context" && wrapper_args[i].type == type_of()) { + // Update the type of the user_context argument to be void* rather than void const* + wrapper_args[i].type = type_of(); + wrapper_llvm_arg_types[i] = llvm_type_of(upgrade_type_for_argument_passing(wrapper_args[i].type)); + wrapper_ucon_index = i; + } + } + if (wrapper_ucon_index >= 0) { + const auto wrapper_names = get_mangled_names(f.name, f.linkage, f.name_mangling, wrapper_args, target); + + FunctionType *wrapper_func_t = FunctionType::get(i32_t, wrapper_llvm_arg_types, false); + llvm::Function *wrapper_func = llvm::Function::Create(wrapper_func_t, + llvm::GlobalValue::ExternalLinkage, + wrapper_names.extern_name, + module.get()); + set_function_attributes_from_halide_target_options(*wrapper_func); + llvm::BasicBlock *wrapper_block = llvm::BasicBlock::Create(module->getContext(), "entry", wrapper_func); + builder->SetInsertPoint(wrapper_block); + + std::vector wrapper_call_args; + for (auto &arg : wrapper_func->args()) { + wrapper_call_args.push_back(&arg); + } + wrapper_call_args[wrapper_ucon_index] = builder->CreatePointerCast(wrapper_call_args[wrapper_ucon_index], + llvm_type_of(type_of())); + + llvm::CallInst *wrapper_result = builder->CreateCall(function, wrapper_call_args); + // This call should never inline + wrapper_result->setIsNoInline(); + builder->CreateRet(wrapper_result); + internal_assert(!verifyFunction(*wrapper_func, &llvm::errs())); + } + } } // Define all functions int idx = 0; diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 3b4aa2ea424c..c6194267ee58 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -30,11 +30,8 @@ ostream &operator<<(ostream &out, const Type &type) { out << "float"; break; case Type::Handle: - if (type.handle_type) { - out << "(" << type.handle_type->inner_name.name << " *)"; - } else { - out << "(void *)"; - } + // ensure that 'const' (etc) qualifiers are emitted when appropriate + out << "(" << type_to_c_type(type, false) << ")"; break; case Type::BFloat: out << "bfloat"; diff --git a/test/generator/CMakeLists.txt b/test/generator/CMakeLists.txt index cbf6ca5e3564..9ed7d38dd37a 100644 --- a/test/generator/CMakeLists.txt +++ b/test/generator/CMakeLists.txt @@ -508,7 +508,7 @@ _add_halide_aot_tests(multitarget add_halide_generator(nested_externs.generator SOURCES nested_externs_generator.cpp) set(NESTED_EXTERNS_LIBS nested_externs_root nested_externs_inner nested_externs_combine nested_externs_leaf) foreach (LIB IN LISTS NESTED_EXTERNS_LIBS) - _add_halide_libraries(${LIB} FROM nested_externs.generator GENERATOR_NAME ${LIB} FEATURES c_plus_plus_name_mangling) + _add_halide_libraries(${LIB} FROM nested_externs.generator GENERATOR_NAME ${LIB} FEATURES user_context c_plus_plus_name_mangling) endforeach () _add_halide_aot_tests(nested_externs HALIDE_LIBRARIES ${NESTED_EXTERNS_LIBS}) diff --git a/test/generator/nested_externs_aottest.cpp b/test/generator/nested_externs_aottest.cpp index 920118d56211..0f32cb5c382b 100644 --- a/test/generator/nested_externs_aottest.cpp +++ b/test/generator/nested_externs_aottest.cpp @@ -11,7 +11,8 @@ int main(int argc, char **argv) { auto val = Buffer::make_scalar(); val() = 38.5f; - nested_externs_root(val, buf); + void const *ucon = nullptr; + nested_externs_root(ucon, val, buf); buf.for_each_element([&](int x, int y, int c) { const float correct = 158.0f; diff --git a/test/generator/nested_externs_generator.cpp b/test/generator/nested_externs_generator.cpp index c5ab216b7aa2..124362991887 100644 --- a/test/generator/nested_externs_generator.cpp +++ b/test/generator/nested_externs_generator.cpp @@ -17,7 +17,7 @@ class NestedExternsCombine : public Generator { Output> combine{"combine"}; // unspecified type-and-dim will be inferred void generate() { - Var x, y, c; + Var x{"x"}, y{"y"}, c{"c"}; combine(x, y, c) = input_a(x, y, c) + input_b(x, y, c); } @@ -35,10 +35,11 @@ class NestedExternsInner : public Generator { Output> inner{"inner"}; void generate() { - extern_stage_1.define_extern("nested_externs_leaf", {value}, Float(32), 3); - extern_stage_2.define_extern("nested_externs_leaf", {value + 1}, Float(32), 3); + Expr ucon = user_context_value(); + extern_stage_1.define_extern("nested_externs_leaf", {ucon, value}, Float(32), 3); + extern_stage_2.define_extern("nested_externs_leaf", {ucon, value + 1}, Float(32), 3); extern_stage_combine.define_extern("nested_externs_combine", - {extern_stage_1, extern_stage_2}, Float(32), 3); + {ucon, extern_stage_1, extern_stage_2}, Float(32), 3); inner(x, y, c) = extern_stage_combine(x, y, c); } @@ -51,8 +52,10 @@ class NestedExternsInner : public Generator { } private: - Var x, y, c; - Func extern_stage_1, extern_stage_2, extern_stage_combine; + Var x{"x"}, y{"y"}, c{"c"}; + Func extern_stage_1{"extern_stage_1_inner"}, + extern_stage_2{"extern_stage_2_inner"}, + extern_stage_combine{"extern_stage_combine_inner"}; }; // Basically a memset. @@ -62,7 +65,7 @@ class NestedExternsLeaf : public Generator { Output> leaf{"leaf"}; void generate() { - Var x, y, c; + Var x{"x"}, y{"y"}, c{"c"}; leaf(x, y, c) = value; } @@ -80,10 +83,11 @@ class NestedExternsRoot : public Generator { Output> root{"root"}; void generate() { - extern_stage_1.define_extern("nested_externs_inner", {value()}, Float(32), 3); - extern_stage_2.define_extern("nested_externs_inner", {value() + 1}, Float(32), 3); + Expr ucon = user_context_value(); + extern_stage_1.define_extern("nested_externs_inner", {ucon, value()}, Float(32), 3); + extern_stage_2.define_extern("nested_externs_inner", {ucon, value() + 1}, Float(32), 3); extern_stage_combine.define_extern("nested_externs_combine", - {extern_stage_1, extern_stage_2}, Float(32), 3); + {ucon, extern_stage_1, extern_stage_2}, Float(32), 3); root(x, y, c) = extern_stage_combine(x, y, c); } @@ -94,11 +98,14 @@ class NestedExternsRoot : public Generator { } set_interleaved(root); root.reorder_storage(c, x, y); + root.parallel(y, 8); } private: - Var x, y, c; - Func extern_stage_1, extern_stage_2, extern_stage_combine; + Var x{"x"}, y{"y"}, c{"c"}; + Func extern_stage_1{"extern_stage_1_root"}, + extern_stage_2{"extern_stage_2_root"}, + extern_stage_combine{"extern_stage_combine_root"}; }; } // namespace