Skip to content

Commit

Permalink
Python: make Func implicitly convertible to Stage (#6702)
Browse files Browse the repository at this point in the history
This allows for `compute_with` and `rfactor` to work more seamlessly in Python.

Also:
- Move two compute_with() variant bindings from PyFunc and PyStage to PyScheduleMethods, as they are identical between the two
- drive-by removal of redundant `py::implicitly_convertible<ImageParam, Func>();` call
  • Loading branch information
steven-johnson committed Apr 14, 2022
1 parent 63d24fb commit 91f4a56
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 11 deletions.
5 changes: 0 additions & 5 deletions python_bindings/src/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,6 @@ void define_func(py::module &m) {

.def("fold_storage", &Func::fold_storage, py::arg("dim"), py::arg("extent"), py::arg("fold_forward") = true)

.def("compute_with", (Func & (Func::*)(LoopLevel, const std::vector<std::pair<VarOrRVar, LoopAlignStrategy>> &)) & Func::compute_with, py::arg("loop_level"), py::arg("align"))
.def("compute_with", (Func & (Func::*)(LoopLevel, LoopAlignStrategy)) & Func::compute_with, py::arg("loop_level"), py::arg("align") = LoopAlignStrategy::Auto)

.def("infer_arguments", &Func::infer_arguments)

.def("__repr__", [](const Func &func) -> std::string {
Expand Down Expand Up @@ -353,8 +350,6 @@ void define_func(py::module &m) {

add_schedule_methods(func_class);

py::implicitly_convertible<ImageParam, Func>();

define_stage(m);
}

Expand Down
4 changes: 4 additions & 0 deletions python_bindings/src/PyScheduleMethods.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ HALIDE_NEVER_INLINE void add_schedule_methods(PythonClass &class_instance) {
py::arg("stage"), py::arg("var"), py::arg("align"))
.def("compute_with", (T & (T::*)(const Stage &, const VarOrRVar &, LoopAlignStrategy)) & T::compute_with,
py::arg("stage"), py::arg("var"), py::arg("align") = LoopAlignStrategy::Auto)
.def("compute_with", (T & (T::*)(LoopLevel, const std::vector<std::pair<VarOrRVar, LoopAlignStrategy>> &)) & T::compute_with,
py::arg("loop_level"), py::arg("align"))
.def("compute_with", (T & (T::*)(LoopLevel, LoopAlignStrategy)) & T::compute_with,
py::arg("loop_level"), py::arg("align") = LoopAlignStrategy::Auto)

.def("unroll", (T & (T::*)(const VarOrRVar &)) & T::unroll,
py::arg("var"))
Expand Down
12 changes: 6 additions & 6 deletions python_bindings/src/PyStage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@ namespace PythonBindings {
void define_stage(py::module &m) {
auto stage_class =
py::class_<Stage>(m, "Stage")
// for implicitly_convertible
.def(py::init([](const Func &f) -> Stage { return f; }))

.def("dump_argument_list", &Stage::dump_argument_list)
.def("name", &Stage::name)

.def("rfactor", (Func(Stage::*)(std::vector<std::pair<RVar, Var>>)) & Stage::rfactor,
py::arg("preserved"))
.def("rfactor", (Func(Stage::*)(const RVar &, const Var &)) & Stage::rfactor,
py::arg("r"), py::arg("v"))
py::arg("r"), py::arg("v"));

py::implicitly_convertible<Func, Stage>();

// These two variants of compute_with are specific to Stage
.def("compute_with", (Stage & (Stage::*)(LoopLevel, const std::vector<std::pair<VarOrRVar, LoopAlignStrategy>> &)) & Stage::compute_with,
py::arg("loop_level"), py::arg("align"))
.def("compute_with", (Stage & (Stage::*)(LoopLevel, LoopAlignStrategy)) & Stage::compute_with,
py::arg("loop_level"), py::arg("align") = LoopAlignStrategy::Auto);
add_schedule_methods(stage_class);
}

Expand Down

0 comments on commit 91f4a56

Please sign in to comment.