From cdf93247aaca1491354e74f54499e38ff11c72ff Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Fri, 23 Sep 2022 11:15:13 -0700 Subject: [PATCH] add_requirement() maintenance (#7045) * add_requirement() maintenance This PR started out as a quick fix to add Python bindings for the `add_requirements` methods on Pipeline and Generator (which were missing), but expanded a bit to fix other issues as well: - The implementation of `Generator::add_requirement` was subtly wrong, in that it only worked if you called the method after everything else in your `generate()` method. Now we accumulate requirements and insert them at the end, so you can call the method anywhere. - We had C++ methods that took both an explicit `vector` and also a variadic-template version, but the former required a mutable vector... and fixing this to not require that ended up creating ambiguity about which overloaded call to use. Added an ugly enable_if thing to resolve this. (Side note #1: overloading methods to have both templated and non-templated versions with the same name is probably something to avoid in the future.) (Side note #2: we should probably thing more carefully about using variadic templates in our public API in the future; we currently use it pretty heavily, but it tends to be messy and hard to reason about IMHO.) * tidy * remove underscores --- .../src/halide/_generator_helpers.py | 7 +++ .../src/halide/halide_/PyHalide.cpp | 16 +++++++ python_bindings/src/halide/halide_/PyHalide.h | 1 + .../src/halide/halide_/PyIROperator.cpp | 28 +----------- .../src/halide/halide_/PyPipeline.cpp | 7 +++ .../test/correctness/addconstant_test.py | 44 ++++++++++++++++++- python_bindings/test/correctness/basics.py | 31 +++++++++++++ .../generators/addconstantcpp_generator.cpp | 3 ++ .../generators/addconstantpy_generator.py | 3 ++ src/Generator.cpp | 8 ++++ src/Generator.h | 17 +++++-- src/IROperator.h | 9 ++++ src/Pipeline.cpp | 2 +- src/Pipeline.h | 17 ++++--- 14 files changed, 155 insertions(+), 38 deletions(-) diff --git a/python_bindings/src/halide/_generator_helpers.py b/python_bindings/src/halide/_generator_helpers.py index 5349f7409e0b..f7a939224569 100644 --- a/python_bindings/src/halide/_generator_helpers.py +++ b/python_bindings/src/halide/_generator_helpers.py @@ -394,6 +394,10 @@ def using_autoscheduler(self): def natural_vector_size(self, type: Type) -> int: return self.target().natural_vector_size(type) + def add_requirement(self, condition: Expr, *args) -> None: + assert self._stage < _Stage.pipeline_built + self._pipeline_requirements.append((condition, [*args])) + @classmethod def call(cls, *args, **kwargs): generator = cls() @@ -475,6 +479,7 @@ def __init__(self, generator_params: dict = {}): self._requirements = {} self._replacements = {} self._in_configure = 0 + self._pipeline_requirements = [] self._advance_to_gp_created() if generator_params: @@ -699,6 +704,8 @@ def _build_pipeline(self) -> Pipeline: funcs.append(f) self._pipeline = Pipeline(funcs) + for condition, error_args in self._pipeline_requirements: + self._pipeline.add_requirement(condition, *error_args) self._stage = _Stage.pipeline_built return self._pipeline diff --git a/python_bindings/src/halide/halide_/PyHalide.cpp b/python_bindings/src/halide/halide_/PyHalide.cpp index 4b9a14bf07aa..26a816e76b14 100644 --- a/python_bindings/src/halide/halide_/PyHalide.cpp +++ b/python_bindings/src/halide/halide_/PyHalide.cpp @@ -101,5 +101,21 @@ Expr double_to_expr_check(double v) { return Expr(f); } +std::vector collect_print_args(const py::args &args) { + std::vector v; + v.reserve(args.size()); + for (size_t i = 0; i < args.size(); ++i) { + // No way to see if a cast will work: just have to try + // and fail. Normally we don't want string to be convertible + // to Expr, but in this unusual case we do. + try { + v.emplace_back(args[i].cast()); + } catch (...) { + v.push_back(args[i].cast()); + } + } + return v; +} + } // namespace PythonBindings } // namespace Halide diff --git a/python_bindings/src/halide/halide_/PyHalide.h b/python_bindings/src/halide/halide_/PyHalide.h index 7e325c4952b3..2eefb1f463bf 100644 --- a/python_bindings/src/halide/halide_/PyHalide.h +++ b/python_bindings/src/halide/halide_/PyHalide.h @@ -32,6 +32,7 @@ std::vector args_to_vector(const py::args &args, size_t start_offset = 0, siz return v; } +std::vector collect_print_args(const py::args &args); Expr double_to_expr_check(double v); } // namespace PythonBindings diff --git a/python_bindings/src/halide/halide_/PyIROperator.cpp b/python_bindings/src/halide/halide_/PyIROperator.cpp index 74eedd9e723e..bcf6a18bfee9 100644 --- a/python_bindings/src/halide/halide_/PyIROperator.cpp +++ b/python_bindings/src/halide/halide_/PyIROperator.cpp @@ -7,30 +7,6 @@ namespace Halide { namespace PythonBindings { -namespace { - -// TODO: clever template usage could generalize this to list-of-types-to-try. -std::vector args_to_vector_for_print(const py::args &args, size_t start_offset = 0) { - if (args.size() < start_offset) { - throw py::value_error("Not enough arguments"); - } - std::vector v; - v.reserve(args.size() - (start_offset)); - for (size_t i = start_offset; i < args.size(); ++i) { - // No way to see if a cast will work: just have to try - // and fail. Normally we don't want string to be convertible - // to Expr, but in this unusual case we do. - try { - v.emplace_back(args[i].cast()); - } catch (...) { - v.push_back(args[i].cast()); - } - } - return v; -} - -} // namespace - void define_operators(py::module &m) { m.def("max", [](const py::args &args) -> Expr { if (args.size() < 2) { @@ -149,11 +125,11 @@ void define_operators(py::module &m) { m.def("reinterpret", (Expr(*)(Type, Expr)) & reinterpret); m.def("cast", (Expr(*)(Type, Expr)) & cast); m.def("print", [](const py::args &args) -> Expr { - return print(args_to_vector_for_print(args)); + return print(collect_print_args(args)); }); m.def( "print_when", [](const Expr &condition, const py::args &args) -> Expr { - return print_when(condition, args_to_vector_for_print(args)); + return print_when(condition, collect_print_args(args)); }, py::arg("condition")); m.def( diff --git a/python_bindings/src/halide/halide_/PyPipeline.cpp b/python_bindings/src/halide/halide_/PyPipeline.cpp index 18d932f01649..eaa6faf03b27 100644 --- a/python_bindings/src/halide/halide_/PyPipeline.cpp +++ b/python_bindings/src/halide/halide_/PyPipeline.cpp @@ -196,6 +196,13 @@ void define_pipeline(py::module &m) { .def("defined", &Pipeline::defined) .def("invalidate_cache", &Pipeline::invalidate_cache) + .def( + "add_requirement", [](Pipeline &p, const Expr &condition, const py::args &error_args) -> void { + auto v = collect_print_args(error_args); + p.add_requirement(condition, v); + }, + py::arg("condition")) + .def("__repr__", [](const Pipeline &p) -> std::string { std::ostringstream o; o << " 0, "negative values are bad", delta) + + delta.set(1) + p.realize([10]) + + try: + delta.set(0) + p.realize([10]) + except hl.HalideError as e: + assert 'Requirement Failed: (false)' in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + delta.set(-1) + p.realize([10]) + except hl.HalideError as e: + assert 'Requirement Failed: (false) negative values are bad -1' in str(e) + else: + assert False, 'Did not see expected exception!' + if __name__ == "__main__": test_compiletime_error() test_runtime_error() @@ -402,3 +432,4 @@ def test_typed_funcs(): test_basics5() test_scalar_funcs() test_bool_conversion() + test_requirements() diff --git a/python_bindings/test/generators/addconstantcpp_generator.cpp b/python_bindings/test/generators/addconstantcpp_generator.cpp index 24b1a892d8ae..577de4dab890 100644 --- a/python_bindings/test/generators/addconstantcpp_generator.cpp +++ b/python_bindings/test/generators/addconstantcpp_generator.cpp @@ -48,6 +48,9 @@ class AddConstantGenerator : public Halide::Generator { Var x, y, z; void generate() { + add_requirement(scalar_int32 != 0); // error_args omitted for this case + add_requirement(scalar_int32 > 0, "negative values are bad", scalar_int32); + output_uint8(x) = input_uint8(x) + scalar_uint8; output_uint16(x) = input_uint16(x) + scalar_uint16; output_uint32(x) = input_uint32(x) + scalar_uint32; diff --git a/python_bindings/test/generators/addconstantpy_generator.py b/python_bindings/test/generators/addconstantpy_generator.py index 2275ced9faee..c48476e766a2 100644 --- a/python_bindings/test/generators/addconstantpy_generator.py +++ b/python_bindings/test/generators/addconstantpy_generator.py @@ -52,6 +52,9 @@ class AddConstantGenerator: def generate(self): g = self + g.add_requirement(g.scalar_int32 != 0) # error_args omitted for this case + g.add_requirement(g.scalar_int32 > 0, "negative values are bad", g.scalar_int32) + g.output_uint8[x] = g.input_uint8[x] + g.scalar_uint8 g.output_uint16[x] = g.input_uint16[x] + g.scalar_uint16 g.output_uint32[x] = g.input_uint32[x] + g.scalar_uint32 diff --git a/src/Generator.cpp b/src/Generator.cpp index bc242061437c..2f4d12c71481 100644 --- a/src/Generator.cpp +++ b/src/Generator.cpp @@ -1554,6 +1554,11 @@ void GeneratorBase::pre_schedule() { void GeneratorBase::post_schedule() { } +void GeneratorBase::add_requirement(const Expr &condition, const std::vector &error_args) { + internal_assert(!pipeline.defined()); + requirements.push_back({condition, error_args}); +} + Pipeline GeneratorBase::get_pipeline() { check_min_phase(GenerateCalled); if (!pipeline.defined()) { @@ -1584,6 +1589,9 @@ Pipeline GeneratorBase::get_pipeline() { } } pipeline = Pipeline(funcs); + for (const auto &r : requirements) { + pipeline.add_requirement(r.condition, r.error_args); + } } return pipeline; } diff --git a/src/Generator.h b/src/Generator.h index 66c2442e05d4..18ba2414b8e4 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -3444,9 +3444,14 @@ class GeneratorBase : public NamesInterface, public AbstractGenerator { return p; } - template - HALIDE_NO_USER_CODE_INLINE void add_requirement(Expr condition, Args &&...args) { - get_pipeline().add_requirement(condition, std::forward(args)...); + void add_requirement(const Expr &condition, const std::vector &error_args); + + template::value>::type> + inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...error_args) { + std::vector collected_args; + Internal::collect_print_args(collected_args, std::forward(error_args)...); + add_requirement(condition, collected_args); } void trace_pipeline() { @@ -3636,6 +3641,12 @@ class GeneratorBase : public NamesInterface, public AbstractGenerator { std::string generator_registered_name, generator_stub_name; Pipeline pipeline; + struct Requirement { + Expr condition; + std::vector error_args; + }; + std::vector requirements; + // Return our GeneratorParamInfo. GeneratorParamInfo ¶m_info(); diff --git a/src/IROperator.h b/src/IROperator.h index d61362b5d13d..1342de00f6e6 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -322,6 +322,15 @@ Stmt remove_promises(const Stmt &s); * the tagged expression. If not, returns the expression. */ Expr unwrap_tags(const Expr &e); +template +struct is_printable_arg { + static constexpr bool value = std::is_convertible::value || + std::is_convertible::value; +}; + +template +struct all_are_printable_args : meta_and...> {}; + // Secondary args to print can be Exprs or const char * inline HALIDE_NO_USER_CODE_INLINE void collect_print_args(std::vector &args) { } diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index 76b28b7255ae..87753b79fdbb 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -790,7 +790,7 @@ Realization Pipeline::realize(JITUserContext *context, return r; } -void Pipeline::add_requirement(const Expr &condition, std::vector &error_args) { +void Pipeline::add_requirement(const Expr &condition, const std::vector &error_args) { user_assert(defined()) << "Pipeline is undefined\n"; // It is an error for a requirement to reference a Func or a Var diff --git a/src/Pipeline.h b/src/Pipeline.h index bb67391f4a44..0a7bc078904c 100644 --- a/src/Pipeline.h +++ b/src/Pipeline.h @@ -547,17 +547,20 @@ class Pipeline { * with the remaining arguments, and return * halide_error_code_requirement_failed. Requirements are checked * in the order added. */ - void add_requirement(const Expr &condition, std::vector &error); - - /** Generate begin_pipeline and end_pipeline tracing calls for this pipeline. */ - void trace_pipeline(); + // @{ + void add_requirement(const Expr &condition, const std::vector &error_args); - template - inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...args) { + template::value>::type> + inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...error_args) { std::vector collected_args; - Internal::collect_print_args(collected_args, std::forward(args)...); + Internal::collect_print_args(collected_args, std::forward(error_args)...); add_requirement(condition, collected_args); } + // @} + + /** Generate begin_pipeline and end_pipeline tracing calls for this pipeline. */ + void trace_pipeline(); private: std::string generate_function_name() const;