diff --git a/python_bindings/src/halide/_generator_helpers.py b/python_bindings/src/halide/_generator_helpers.py index 5349f7409e0b..f7a939224569 100644 --- a/python_bindings/src/halide/_generator_helpers.py +++ b/python_bindings/src/halide/_generator_helpers.py @@ -394,6 +394,10 @@ def using_autoscheduler(self): def natural_vector_size(self, type: Type) -> int: return self.target().natural_vector_size(type) + def add_requirement(self, condition: Expr, *args) -> None: + assert self._stage < _Stage.pipeline_built + self._pipeline_requirements.append((condition, [*args])) + @classmethod def call(cls, *args, **kwargs): generator = cls() @@ -475,6 +479,7 @@ def __init__(self, generator_params: dict = {}): self._requirements = {} self._replacements = {} self._in_configure = 0 + self._pipeline_requirements = [] self._advance_to_gp_created() if generator_params: @@ -699,6 +704,8 @@ def _build_pipeline(self) -> Pipeline: funcs.append(f) self._pipeline = Pipeline(funcs) + for condition, error_args in self._pipeline_requirements: + self._pipeline.add_requirement(condition, *error_args) self._stage = _Stage.pipeline_built return self._pipeline diff --git a/python_bindings/src/halide/halide_/PyHalide.cpp b/python_bindings/src/halide/halide_/PyHalide.cpp index 4b9a14bf07aa..26a816e76b14 100644 --- a/python_bindings/src/halide/halide_/PyHalide.cpp +++ b/python_bindings/src/halide/halide_/PyHalide.cpp @@ -101,5 +101,21 @@ Expr double_to_expr_check(double v) { return Expr(f); } +std::vector collect_print_args(const py::args &args) { + std::vector v; + v.reserve(args.size()); + for (size_t i = 0; i < args.size(); ++i) { + // No way to see if a cast will work: just have to try + // and fail. Normally we don't want string to be convertible + // to Expr, but in this unusual case we do. + try { + v.emplace_back(args[i].cast()); + } catch (...) { + v.push_back(args[i].cast()); + } + } + return v; +} + } // namespace PythonBindings } // namespace Halide diff --git a/python_bindings/src/halide/halide_/PyHalide.h b/python_bindings/src/halide/halide_/PyHalide.h index 7e325c4952b3..2eefb1f463bf 100644 --- a/python_bindings/src/halide/halide_/PyHalide.h +++ b/python_bindings/src/halide/halide_/PyHalide.h @@ -32,6 +32,7 @@ std::vector args_to_vector(const py::args &args, size_t start_offset = 0, siz return v; } +std::vector collect_print_args(const py::args &args); Expr double_to_expr_check(double v); } // namespace PythonBindings diff --git a/python_bindings/src/halide/halide_/PyIROperator.cpp b/python_bindings/src/halide/halide_/PyIROperator.cpp index 74eedd9e723e..bcf6a18bfee9 100644 --- a/python_bindings/src/halide/halide_/PyIROperator.cpp +++ b/python_bindings/src/halide/halide_/PyIROperator.cpp @@ -7,30 +7,6 @@ namespace Halide { namespace PythonBindings { -namespace { - -// TODO: clever template usage could generalize this to list-of-types-to-try. -std::vector args_to_vector_for_print(const py::args &args, size_t start_offset = 0) { - if (args.size() < start_offset) { - throw py::value_error("Not enough arguments"); - } - std::vector v; - v.reserve(args.size() - (start_offset)); - for (size_t i = start_offset; i < args.size(); ++i) { - // No way to see if a cast will work: just have to try - // and fail. Normally we don't want string to be convertible - // to Expr, but in this unusual case we do. - try { - v.emplace_back(args[i].cast()); - } catch (...) { - v.push_back(args[i].cast()); - } - } - return v; -} - -} // namespace - void define_operators(py::module &m) { m.def("max", [](const py::args &args) -> Expr { if (args.size() < 2) { @@ -149,11 +125,11 @@ void define_operators(py::module &m) { m.def("reinterpret", (Expr(*)(Type, Expr)) & reinterpret); m.def("cast", (Expr(*)(Type, Expr)) & cast); m.def("print", [](const py::args &args) -> Expr { - return print(args_to_vector_for_print(args)); + return print(collect_print_args(args)); }); m.def( "print_when", [](const Expr &condition, const py::args &args) -> Expr { - return print_when(condition, args_to_vector_for_print(args)); + return print_when(condition, collect_print_args(args)); }, py::arg("condition")); m.def( diff --git a/python_bindings/src/halide/halide_/PyPipeline.cpp b/python_bindings/src/halide/halide_/PyPipeline.cpp index 18d932f01649..eaa6faf03b27 100644 --- a/python_bindings/src/halide/halide_/PyPipeline.cpp +++ b/python_bindings/src/halide/halide_/PyPipeline.cpp @@ -196,6 +196,13 @@ void define_pipeline(py::module &m) { .def("defined", &Pipeline::defined) .def("invalidate_cache", &Pipeline::invalidate_cache) + .def( + "add_requirement", [](Pipeline &p, const Expr &condition, const py::args &error_args) -> void { + auto v = collect_print_args(error_args); + p.add_requirement(condition, v); + }, + py::arg("condition")) + .def("__repr__", [](const Pipeline &p) -> std::string { std::ostringstream o; o << " 0, "negative values are bad", delta) + + delta.set(1) + p.realize([10]) + + try: + delta.set(0) + p.realize([10]) + except hl.HalideError as e: + assert 'Requirement Failed: (false)' in str(e) + else: + assert False, 'Did not see expected exception!' + + try: + delta.set(-1) + p.realize([10]) + except hl.HalideError as e: + assert 'Requirement Failed: (false) negative values are bad -1' in str(e) + else: + assert False, 'Did not see expected exception!' + if __name__ == "__main__": test_compiletime_error() test_runtime_error() @@ -402,3 +432,4 @@ def test_typed_funcs(): test_basics5() test_scalar_funcs() test_bool_conversion() + test_requirements() diff --git a/python_bindings/test/generators/addconstantcpp_generator.cpp b/python_bindings/test/generators/addconstantcpp_generator.cpp index 24b1a892d8ae..577de4dab890 100644 --- a/python_bindings/test/generators/addconstantcpp_generator.cpp +++ b/python_bindings/test/generators/addconstantcpp_generator.cpp @@ -48,6 +48,9 @@ class AddConstantGenerator : public Halide::Generator { Var x, y, z; void generate() { + add_requirement(scalar_int32 != 0); // error_args omitted for this case + add_requirement(scalar_int32 > 0, "negative values are bad", scalar_int32); + output_uint8(x) = input_uint8(x) + scalar_uint8; output_uint16(x) = input_uint16(x) + scalar_uint16; output_uint32(x) = input_uint32(x) + scalar_uint32; diff --git a/python_bindings/test/generators/addconstantpy_generator.py b/python_bindings/test/generators/addconstantpy_generator.py index 2275ced9faee..c48476e766a2 100644 --- a/python_bindings/test/generators/addconstantpy_generator.py +++ b/python_bindings/test/generators/addconstantpy_generator.py @@ -52,6 +52,9 @@ class AddConstantGenerator: def generate(self): g = self + g.add_requirement(g.scalar_int32 != 0) # error_args omitted for this case + g.add_requirement(g.scalar_int32 > 0, "negative values are bad", g.scalar_int32) + g.output_uint8[x] = g.input_uint8[x] + g.scalar_uint8 g.output_uint16[x] = g.input_uint16[x] + g.scalar_uint16 g.output_uint32[x] = g.input_uint32[x] + g.scalar_uint32 diff --git a/src/Generator.cpp b/src/Generator.cpp index bc242061437c..2f4d12c71481 100644 --- a/src/Generator.cpp +++ b/src/Generator.cpp @@ -1554,6 +1554,11 @@ void GeneratorBase::pre_schedule() { void GeneratorBase::post_schedule() { } +void GeneratorBase::add_requirement(const Expr &condition, const std::vector &error_args) { + internal_assert(!pipeline.defined()); + requirements.push_back({condition, error_args}); +} + Pipeline GeneratorBase::get_pipeline() { check_min_phase(GenerateCalled); if (!pipeline.defined()) { @@ -1584,6 +1589,9 @@ Pipeline GeneratorBase::get_pipeline() { } } pipeline = Pipeline(funcs); + for (const auto &r : requirements) { + pipeline.add_requirement(r.condition, r.error_args); + } } return pipeline; } diff --git a/src/Generator.h b/src/Generator.h index 66c2442e05d4..18ba2414b8e4 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -3444,9 +3444,14 @@ class GeneratorBase : public NamesInterface, public AbstractGenerator { return p; } - template - HALIDE_NO_USER_CODE_INLINE void add_requirement(Expr condition, Args &&...args) { - get_pipeline().add_requirement(condition, std::forward(args)...); + void add_requirement(const Expr &condition, const std::vector &error_args); + + template::value>::type> + inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...error_args) { + std::vector collected_args; + Internal::collect_print_args(collected_args, std::forward(error_args)...); + add_requirement(condition, collected_args); } void trace_pipeline() { @@ -3636,6 +3641,12 @@ class GeneratorBase : public NamesInterface, public AbstractGenerator { std::string generator_registered_name, generator_stub_name; Pipeline pipeline; + struct Requirement { + Expr condition; + std::vector error_args; + }; + std::vector requirements; + // Return our GeneratorParamInfo. GeneratorParamInfo ¶m_info(); diff --git a/src/IROperator.h b/src/IROperator.h index d61362b5d13d..1342de00f6e6 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -322,6 +322,15 @@ Stmt remove_promises(const Stmt &s); * the tagged expression. If not, returns the expression. */ Expr unwrap_tags(const Expr &e); +template +struct is_printable_arg { + static constexpr bool value = std::is_convertible::value || + std::is_convertible::value; +}; + +template +struct all_are_printable_args : meta_and...> {}; + // Secondary args to print can be Exprs or const char * inline HALIDE_NO_USER_CODE_INLINE void collect_print_args(std::vector &args) { } diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index 76b28b7255ae..87753b79fdbb 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -790,7 +790,7 @@ Realization Pipeline::realize(JITUserContext *context, return r; } -void Pipeline::add_requirement(const Expr &condition, std::vector &error_args) { +void Pipeline::add_requirement(const Expr &condition, const std::vector &error_args) { user_assert(defined()) << "Pipeline is undefined\n"; // It is an error for a requirement to reference a Func or a Var diff --git a/src/Pipeline.h b/src/Pipeline.h index bb67391f4a44..0a7bc078904c 100644 --- a/src/Pipeline.h +++ b/src/Pipeline.h @@ -547,17 +547,20 @@ class Pipeline { * with the remaining arguments, and return * halide_error_code_requirement_failed. Requirements are checked * in the order added. */ - void add_requirement(const Expr &condition, std::vector &error); - - /** Generate begin_pipeline and end_pipeline tracing calls for this pipeline. */ - void trace_pipeline(); + // @{ + void add_requirement(const Expr &condition, const std::vector &error_args); - template - inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...args) { + template::value>::type> + inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...error_args) { std::vector collected_args; - Internal::collect_print_args(collected_args, std::forward(args)...); + Internal::collect_print_args(collected_args, std::forward(error_args)...); add_requirement(condition, collected_args); } + // @} + + /** Generate begin_pipeline and end_pipeline tracing calls for this pipeline. */ + void trace_pipeline(); private: std::string generate_function_name() const;