From 2cac3cc1c6ff53996e481d4ced6ce51d575a8b30 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Mon, 13 Feb 2023 16:51:36 -0800 Subject: [PATCH] Change early-bound default args in Python bindings to late-bound In PyBind11, if you specify a default argument for a method, it is evaluated when the Python module is initialized, *not* when the method is called (as you might expect in C++). For defaults that are just constants/literals, this is no big deal, but when calling get_*_target_from_environment, this means it is called at module init time -- also normally not a big deal (since the values ~never change at runtime anyway), with one big exception (no pun intended): if the function throws an exception (e.g. via calling user_assert() or similar), that exception is thrown at Module-initialization time, which is a much more inscrutable crash, and one that is very hard to recover from. This may seem unlikely, but can happen pretty easily if you set (say) HL_JIT_TARGET=host-cuda (or other gpu) and the given GPU runtime isn't present on the given system; the current behavior is basically "make if impossible for the libHalidePython bindings to run", whereas what we want is "runtime exception thrown when you call the method". This changes the relevant methods to use `Target()` as the default value, and inside the method wrapper, if the value passed equals `Target()`, it replaces the value with the righ `get_*_target_from_environment()` call. (This turned up while doing some testing of https://github.com/halide/Halide/pull/6924 on a system without Vulkan available) --- .../src/halide/halide_/PyBuffer.cpp | 24 ++-- python_bindings/src/halide/halide_/PyFunc.cpp | 120 ++++++++++++---- .../src/halide/halide_/PyHalide.cpp | 14 ++ python_bindings/src/halide/halide_/PyHalide.h | 2 + .../src/halide/halide_/PyPipeline.cpp | 135 +++++++++++++----- 5 files changed, 215 insertions(+), 80 deletions(-) diff --git a/python_bindings/src/halide/halide_/PyBuffer.cpp b/python_bindings/src/halide/halide_/PyBuffer.cpp index 7b32f9824349..6d1461a33ec4 100644 --- a/python_bindings/src/halide/halide_/PyBuffer.cpp +++ b/python_bindings/src/halide/halide_/PyBuffer.cpp @@ -587,27 +587,27 @@ void define_buffer(py::module &m) { }) .def( - "copy_to_device", [](Buffer<> &b, const Target &t) -> int { - return b.copy_to_device(t); + "copy_to_device", [](Buffer<> &b, const Target &target) -> int { + return b.copy_to_device(to_jit_target(target)); }, - py::arg("target") = get_jit_target_from_environment()) + py::arg("target") = Target()) .def( - "copy_to_device", [](Buffer<> &b, const DeviceAPI &d, const Target &t) -> int { - return b.copy_to_device(d, t); + "copy_to_device", [](Buffer<> &b, const DeviceAPI &d, const Target &target) -> int { + return b.copy_to_device(d, to_jit_target(target)); }, - py::arg("device_api"), py::arg("target") = get_jit_target_from_environment()) + py::arg("device_api"), py::arg("target") = Target()) .def( - "device_malloc", [](Buffer<> &b, const Target &t) -> int { - return b.device_malloc(t); + "device_malloc", [](Buffer<> &b, const Target &target) -> int { + return b.device_malloc(to_jit_target(target)); }, - py::arg("target") = get_jit_target_from_environment()) + py::arg("target") = Target()) .def( - "device_malloc", [](Buffer<> &b, const DeviceAPI &d, const Target &t) -> int { - return b.device_malloc(d, t); + "device_malloc", [](Buffer<> &b, const DeviceAPI &d, const Target &target) -> int { + return b.device_malloc(d, to_jit_target(target)); }, - py::arg("device_api"), py::arg("target") = get_jit_target_from_environment()) + py::arg("device_api"), py::arg("target") = Target()) .def( "set_min", [](Buffer<> &b, const std::vector &mins) -> void { diff --git a/python_bindings/src/halide/halide_/PyFunc.cpp b/python_bindings/src/halide/halide_/PyFunc.cpp index 966b35bb4388..6c6e38ec7501 100644 --- a/python_bindings/src/halide/halide_/PyFunc.cpp +++ b/python_bindings/src/halide/halide_/PyFunc.cpp @@ -212,39 +212,98 @@ void define_func(py::module &m) { .def("store_in", &Func::store_in, py::arg("memory_type")) - .def("compile_to", &Func::compile_to, py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_bitcode", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_bitcode", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_bitcode, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_llvm_assembly", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_llvm_assembly", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_llvm_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_object", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_object", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_object, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_header", &Func::compile_to_header, py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) - - .def("compile_to_assembly", (void(Func::*)(const std::string &, const std::vector &, const std::string &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_assembly", (void(Func::*)(const std::string &, const std::vector &, const Target &target)) & Func::compile_to_assembly, py::arg("filename"), py::arg("arguments"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_c", &Func::compile_to_c, py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) - - .def("compile_to_lowered_stmt", &Func::compile_to_lowered_stmt, py::arg("filename"), py::arg("arguments"), py::arg("fmt") = Text, py::arg("target") = get_target_from_environment()) - - .def("compile_to_file", &Func::compile_to_file, py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) - - .def("compile_to_static_library", &Func::compile_to_static_library, py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) + .def( + "compile_to", [](Func &f, const std::map &output_files, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to(output_files, args, fn_name, to_aot_target(target)); + }, + py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_bitcode", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_bitcode(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_bitcode", [](Func &f, const std::string &filename, const std::vector &args, const Target &target) { + f.compile_to_bitcode(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_llvm_assembly", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_llvm_assembly(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_llvm_assembly", [](Func &f, const std::string &filename, const std::vector &args, const Target &target) { + f.compile_to_llvm_assembly(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_object", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_object(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_object", [](Func &f, const std::string &filename, const std::vector &args, const Target &target) { + f.compile_to_object(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_header", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_header(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_assembly", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_assembly(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_assembly", [](Func &f, const std::string &filename, const std::vector &args, const Target &target) { + f.compile_to_assembly(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_c", [](Func &f, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_c(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_lowered_stmt", [](Func &f, const std::string &filename, const std::vector &args, StmtOutputFormat fmt, const Target &target) { + f.compile_to_lowered_stmt(filename, args, fmt, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fmt") = Text, py::arg("target") = Target()) + .def( + "compile_to_file", [](Func &f, const std::string &filename_prefix, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_file(filename_prefix, args, fn_name, to_aot_target(target)); + }, + py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_static_library", [](Func &f, const std::string &filename_prefix, const std::vector &args, const std::string &fn_name, const Target &target) { + f.compile_to_static_library(filename_prefix, args, fn_name, to_aot_target(target)); + }, + py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) .def("compile_to_multitarget_static_library", &Func::compile_to_multitarget_static_library, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets")) .def("compile_to_multitarget_object_files", &Func::compile_to_multitarget_object_files, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets"), py::arg("suffixes")) // TODO: useless until Module is defined. - .def("compile_to_module", &Func::compile_to_module, py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = get_target_from_environment()) + .def( + "compile_to_module", [](Func &f, const std::vector &args, const std::string &fn_name, const Target &target) -> Module { + return f.compile_to_module(args, fn_name, to_aot_target(target)); + }, + py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) - .def("compile_jit", &Func::compile_jit, py::arg("target") = get_jit_target_from_environment()) + .def( + "compile_jit", [](Func &f, const Target &target) { + f.compile_jit(to_jit_target(target)); + }, + py::arg("target") = Target()) - .def("compile_to_callable", &Func::compile_to_callable, py::arg("arguments"), py::arg("target") = get_jit_target_from_environment()) + .def( + "compile_to_callable", [](Func &f, const std::vector &args, const Target &target) { + return f.compile_to_callable(args, to_jit_target(target)); + }, + py::arg("arguments"), py::arg("target") = Target()) .def("has_update_definition", &Func::has_update_definition) .def("num_update_definitions", &Func::num_update_definitions) @@ -285,10 +344,11 @@ void define_func(py::module &m) { .def( "infer_input_bounds", [](Func &f, const py::object &dst, const Target &target) -> void { + const Target t = to_jit_target(target); // dst could be Buffer<>, vector, or vector try { Buffer<> b = dst.cast>(); - f.infer_input_bounds(b, target); + f.infer_input_bounds(b, t); return; } catch (...) { // fall thru @@ -296,7 +356,7 @@ void define_func(py::module &m) { try { std::vector> v = dst.cast>>(); - f.infer_input_bounds(Realization(std::move(v)), target); + f.infer_input_bounds(Realization(std::move(v)), t); return; } catch (...) { // fall thru @@ -304,7 +364,7 @@ void define_func(py::module &m) { try { std::vector v = dst.cast>(); - f.infer_input_bounds(v, target); + f.infer_input_bounds(v, t); return; } catch (...) { // fall thru @@ -312,7 +372,7 @@ void define_func(py::module &m) { throw py::value_error("Invalid arguments to infer_input_bounds"); }, - py::arg("dst"), py::arg("target") = get_jit_target_from_environment()) + py::arg("dst"), py::arg("target") = Target()) .def("in_", (Func(Func::*)(const Func &)) & Func::in, py::arg("f")) .def("in_", (Func(Func::*)(const std::vector &fs)) & Func::in, py::arg("fs")) diff --git a/python_bindings/src/halide/halide_/PyHalide.cpp b/python_bindings/src/halide/halide_/PyHalide.cpp index d18ccf01b725..430ad690420d 100644 --- a/python_bindings/src/halide/halide_/PyHalide.cpp +++ b/python_bindings/src/halide/halide_/PyHalide.cpp @@ -111,5 +111,19 @@ std::vector collect_print_args(const py::args &args) { return v; } +Target to_jit_target(const Target &target) { + if (target != Target()) { + return target; + } + return get_jit_target_from_environment(); +} + +Target to_aot_target(const Target &target) { + if (target != Target()) { + return target; + } + return get_target_from_environment(); +} + } // namespace PythonBindings } // namespace Halide diff --git a/python_bindings/src/halide/halide_/PyHalide.h b/python_bindings/src/halide/halide_/PyHalide.h index 2eefb1f463bf..64003c339bb3 100644 --- a/python_bindings/src/halide/halide_/PyHalide.h +++ b/python_bindings/src/halide/halide_/PyHalide.h @@ -34,6 +34,8 @@ std::vector args_to_vector(const py::args &args, size_t start_offset = 0, siz std::vector collect_print_args(const py::args &args); Expr double_to_expr_check(double v); +Target to_jit_target(const Target &target); +Target to_aot_target(const Target &target); } // namespace PythonBindings } // namespace Halide diff --git a/python_bindings/src/halide/halide_/PyPipeline.cpp b/python_bindings/src/halide/halide_/PyPipeline.cpp index 2300ab5e76cb..069ea9394df8 100644 --- a/python_bindings/src/halide/halide_/PyPipeline.cpp +++ b/python_bindings/src/halide/halide_/PyPipeline.cpp @@ -77,40 +77,98 @@ void define_pipeline(py::module &m) { py::arg("index")) .def("print_loop_nest", &Pipeline::print_loop_nest) - .def("compile_to", &Pipeline::compile_to, - py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_bitcode", &Pipeline::compile_to_bitcode, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_llvm_assembly", &Pipeline::compile_to_llvm_assembly, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_object", &Pipeline::compile_to_object, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_header", &Pipeline::compile_to_header, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_assembly", &Pipeline::compile_to_assembly, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_c", &Pipeline::compile_to_c, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_file", &Pipeline::compile_to_file, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - .def("compile_to_static_library", &Pipeline::compile_to_static_library, - py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment()) - - .def("compile_to_lowered_stmt", &Pipeline::compile_to_lowered_stmt, - py::arg("filename"), py::arg("arguments"), py::arg("format") = StmtOutputFormat::Text, py::arg("target") = get_target_from_environment()) - - .def("compile_to_multitarget_static_library", &Pipeline::compile_to_multitarget_static_library, - py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets")) - .def("compile_to_multitarget_object_files", &Pipeline::compile_to_multitarget_object_files, - py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets"), py::arg("suffixes")) - - .def("compile_to_module", &Pipeline::compile_to_module, - py::arg("arguments"), py::arg("fn_name"), py::arg("target") = get_target_from_environment(), py::arg("linkage") = LinkageType::ExternalPlusMetadata) - - .def("compile_jit", &Pipeline::compile_jit, py::arg("target") = get_jit_target_from_environment()) - - .def("compile_to_callable", &Pipeline::compile_to_callable, py::arg("arguments"), py::arg("target") = get_jit_target_from_environment()) + .def( + "compile_to", [](Pipeline &p, const std::map &output_files, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to(output_files, args, fn_name, to_aot_target(target)); + }, + py::arg("outputs"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + + .def( + "compile_to_bitcode", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_bitcode(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_bitcode", [](Pipeline &p, const std::string &filename, const std::vector &args, const Target &target) { + p.compile_to_bitcode(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_llvm_assembly", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_llvm_assembly(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_llvm_assembly", [](Pipeline &p, const std::string &filename, const std::vector &args, const Target &target) { + p.compile_to_llvm_assembly(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_object", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_object(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_object", [](Pipeline &p, const std::string &filename, const std::vector &args, const Target &target) { + p.compile_to_object(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_header", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_header(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_assembly", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_assembly(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target()) + .def( + "compile_to_assembly", [](Pipeline &p, const std::string &filename, const std::vector &args, const Target &target) { + p.compile_to_assembly(filename, args, "", to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("target") = Target()) + .def( + "compile_to_c", [](Pipeline &p, const std::string &filename, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_c(filename, args, fn_name, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_lowered_stmt", [](Pipeline &p, const std::string &filename, const std::vector &args, StmtOutputFormat fmt, const Target &target) { + p.compile_to_lowered_stmt(filename, args, fmt, to_aot_target(target)); + }, + py::arg("filename"), py::arg("arguments"), py::arg("fmt") = Text, py::arg("target") = Target()) + .def( + "compile_to_file", [](Pipeline &p, const std::string &filename_prefix, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_file(filename_prefix, args, fn_name, to_aot_target(target)); + }, + py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + .def( + "compile_to_static_library", [](Pipeline &p, const std::string &filename_prefix, const std::vector &args, const std::string &fn_name, const Target &target) { + p.compile_to_static_library(filename_prefix, args, fn_name, to_aot_target(target)); + }, + py::arg("filename_prefix"), py::arg("arguments"), py::arg("fn_name") = "", py::arg("target") = Target()) + + .def("compile_to_multitarget_static_library", &Pipeline::compile_to_multitarget_static_library, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets")) + .def("compile_to_multitarget_object_files", &Pipeline::compile_to_multitarget_object_files, py::arg("filename_prefix"), py::arg("arguments"), py::arg("targets"), py::arg("suffixes")) + + .def( + "compile_to_module", [](Pipeline &p, const std::vector &args, const std::string &fn_name, const Target &target, LinkageType linkage_type) -> Module { + return p.compile_to_module(args, fn_name, to_aot_target(target), linkage_type); + }, + py::arg("arguments"), py::arg("fn_name"), py::arg("target") = Target(), py::arg("linkage") = LinkageType::ExternalPlusMetadata) + + .def( + "compile_jit", [](Pipeline &p, const Target &target) { + p.compile_jit(to_jit_target(target)); + }, + py::arg("target") = Target()) + + .def( + "compile_to_callable", [](Pipeline &p, const std::vector &args, const Target &target) { + return p.compile_to_callable(args, to_jit_target(target)); + }, + py::arg("arguments"), py::arg("target") = Target()) .def( "realize", [](Pipeline &p, Buffer<> buffer, const Target &target) -> void { @@ -146,10 +204,11 @@ void define_pipeline(py::module &m) { .def( "infer_input_bounds", [](Pipeline &p, const py::object &dst, const Target &target) -> void { + const Target t = to_jit_target(target); // dst could be Buffer<>, vector, or vector try { Buffer<> b = dst.cast>(); - p.infer_input_bounds(b, target); + p.infer_input_bounds(b, t); return; } catch (...) { // fall thru @@ -157,7 +216,7 @@ void define_pipeline(py::module &m) { try { std::vector> v = dst.cast>>(); - p.infer_input_bounds(Realization(std::move(v)), target); + p.infer_input_bounds(Realization(std::move(v)), t); return; } catch (...) { // fall thru @@ -165,7 +224,7 @@ void define_pipeline(py::module &m) { try { std::vector v = dst.cast>(); - p.infer_input_bounds(v, target); + p.infer_input_bounds(v, t); return; } catch (...) { // fall thru @@ -173,7 +232,7 @@ void define_pipeline(py::module &m) { throw py::value_error("Invalid arguments to infer_input_bounds"); }, - py::arg("dst"), py::arg("target") = get_jit_target_from_environment()) + py::arg("dst"), py::arg("target") = Target()) .def("infer_arguments", [](Pipeline &p) -> std::vector { return p.infer_arguments();