From 91f4a56aa5b7121622644472d1ff50e5c9d3061d Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Wed, 13 Apr 2022 09:44:00 -0700 Subject: [PATCH] Python: make Func implicitly convertible to Stage (#6702) This allows for `compute_with` and `rfactor` to work more seamlessly in Python. Also: - Move two compute_with() variant bindings from PyFunc and PyStage to PyScheduleMethods, as they are identical between the two - drive-by removal of redundant `py::implicitly_convertible();` call --- python_bindings/src/PyFunc.cpp | 5 ----- python_bindings/src/PyScheduleMethods.h | 4 ++++ python_bindings/src/PyStage.cpp | 12 ++++++------ 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/python_bindings/src/PyFunc.cpp b/python_bindings/src/PyFunc.cpp index 8957f241d07e..f66d96763860 100644 --- a/python_bindings/src/PyFunc.cpp +++ b/python_bindings/src/PyFunc.cpp @@ -300,9 +300,6 @@ void define_func(py::module &m) { .def("fold_storage", &Func::fold_storage, py::arg("dim"), py::arg("extent"), py::arg("fold_forward") = true) - .def("compute_with", (Func & (Func::*)(LoopLevel, const std::vector> &)) & Func::compute_with, py::arg("loop_level"), py::arg("align")) - .def("compute_with", (Func & (Func::*)(LoopLevel, LoopAlignStrategy)) & Func::compute_with, py::arg("loop_level"), py::arg("align") = LoopAlignStrategy::Auto) - .def("infer_arguments", &Func::infer_arguments) .def("__repr__", [](const Func &func) -> std::string { @@ -353,8 +350,6 @@ void define_func(py::module &m) { add_schedule_methods(func_class); - py::implicitly_convertible(); - define_stage(m); } diff --git a/python_bindings/src/PyScheduleMethods.h b/python_bindings/src/PyScheduleMethods.h index 8280ccd99ec5..0a6c2b3dfb75 100644 --- a/python_bindings/src/PyScheduleMethods.h +++ b/python_bindings/src/PyScheduleMethods.h @@ -17,6 +17,10 @@ HALIDE_NEVER_INLINE void add_schedule_methods(PythonClass &class_instance) { py::arg("stage"), py::arg("var"), py::arg("align")) .def("compute_with", (T & (T::*)(const Stage &, const VarOrRVar &, LoopAlignStrategy)) & T::compute_with, py::arg("stage"), py::arg("var"), py::arg("align") = LoopAlignStrategy::Auto) + .def("compute_with", (T & (T::*)(LoopLevel, const std::vector> &)) & T::compute_with, + py::arg("loop_level"), py::arg("align")) + .def("compute_with", (T & (T::*)(LoopLevel, LoopAlignStrategy)) & T::compute_with, + py::arg("loop_level"), py::arg("align") = LoopAlignStrategy::Auto) .def("unroll", (T & (T::*)(const VarOrRVar &)) & T::unroll, py::arg("var")) diff --git a/python_bindings/src/PyStage.cpp b/python_bindings/src/PyStage.cpp index 2efda5fd816a..e84c6fcc7189 100644 --- a/python_bindings/src/PyStage.cpp +++ b/python_bindings/src/PyStage.cpp @@ -8,19 +8,19 @@ namespace PythonBindings { void define_stage(py::module &m) { auto stage_class = py::class_(m, "Stage") + // for implicitly_convertible + .def(py::init([](const Func &f) -> Stage { return f; })) + .def("dump_argument_list", &Stage::dump_argument_list) .def("name", &Stage::name) .def("rfactor", (Func(Stage::*)(std::vector>)) & Stage::rfactor, py::arg("preserved")) .def("rfactor", (Func(Stage::*)(const RVar &, const Var &)) & Stage::rfactor, - py::arg("r"), py::arg("v")) + py::arg("r"), py::arg("v")); + + py::implicitly_convertible(); - // These two variants of compute_with are specific to Stage - .def("compute_with", (Stage & (Stage::*)(LoopLevel, const std::vector> &)) & Stage::compute_with, - py::arg("loop_level"), py::arg("align")) - .def("compute_with", (Stage & (Stage::*)(LoopLevel, LoopAlignStrategy)) & Stage::compute_with, - py::arg("loop_level"), py::arg("align") = LoopAlignStrategy::Auto); add_schedule_methods(stage_class); }