diff --git a/python_bindings/src/halide/halide_/PyFunc.cpp b/python_bindings/src/halide/halide_/PyFunc.cpp
index 8c21a145d351..a1e222f5c8f0 100644
--- a/python_bindings/src/halide/halide_/PyFunc.cpp
+++ b/python_bindings/src/halide/halide_/PyFunc.cpp
@@ -197,6 +197,8 @@ void define_func(py::module &m) {
             .def("type", &Func::type)
             .def("types", &Func::types)
 
+            .def("split_vars", &Func::split_vars)
+
             .def("bound", &Func::bound, py::arg("var"), py::arg("min"), py::arg("extent"))
 
             .def("reorder_storage", (Func & (Func::*)(const std::vector<Var> &)) & Func::reorder_storage, py::arg("dims"))
diff --git a/python_bindings/src/halide/halide_/PyStage.cpp b/python_bindings/src/halide/halide_/PyStage.cpp
index e9ac740f0adb..9a9f6394d7cd 100644
--- a/python_bindings/src/halide/halide_/PyStage.cpp
+++ b/python_bindings/src/halide/halide_/PyStage.cpp
@@ -19,6 +19,21 @@ void define_stage(py::module &m) {
             .def("rfactor", static_cast<Func (Stage::*)(const RVar &, const Var &)>(&Stage::rfactor), py::arg("r"), py::arg("v"))
+            .def("split_vars", [](const Stage &stage) -> py::list {
+                auto vars = stage.split_vars();
+                py::list result;
+                // Return a mixed-type list of Var and RVar objects, instead of
+                // a list of VarOrRVar objects.
+                for (const auto &v : vars) {
+                    if (v.is_rvar) {
+                        result.append(py::cast(v.rvar));
+                    } else {
+                        result.append(py::cast(v.var));
+                    }
+                }
+                return result;
+            })
+
             .def("unscheduled", &Stage::unscheduled);
 
     py::implicitly_convertible<Func, Stage>();
diff --git a/python_bindings/test/correctness/basics.py b/python_bindings/test/correctness/basics.py
index 03d9d86220bb..a3f2c22053a4 100644
--- a/python_bindings/test/correctness/basics.py
+++ b/python_bindings/test/correctness/basics.py
@@ -103,7 +103,6 @@ def test_basics():
     blur_x.compute_at(blur_y, x).vectorize(x, 8)
     blur_y.compile_jit()
 
-
 def test_basics2():
     input = hl.ImageParam(hl.Float(32), 3, "input")
     hl.Param(hl.Float(32), "r_sigma", 0.1)
@@ -597,6 +596,31 @@ def test_print_ir():
     p = hl.Pipeline()
     assert str(p) == ""
 
+def test_split_vars():
+    f = hl.Func("f")
+    (x, xo, xi) = hl.vars("x xo xi")
+    f[x] = x
+    r = hl.RDom([(0, 10), (0, 10)], "r")
+    f[x] += x + r.x + r.y
+
+    f.split(x, xo, xi, 8)
+
+    vars = f.split_vars()
+    assert len(vars) == 3
+    assert vars[0].name() == xi.name()
+    assert vars[1].name() == xo.name()
+    assert vars[2].name() == hl.Var.outermost().name()
+
+    (rxo, rxi) = (hl.RVar("rxo"), hl.RVar("rxi"))
+    f.update().split(r.x, rxo, rxi, 4)
+
+    vars = f.update().split_vars()
+    assert len(vars) == 5
+    assert isinstance(vars[0], hl.RVar) and vars[0].name() == rxi.name()
+    assert isinstance(vars[1], hl.RVar) and vars[1].name() == rxo.name()
+    assert isinstance(vars[2], hl.RVar) and vars[2].name() == r.y.name()
+    assert isinstance(vars[3], hl.Var) and vars[3].name() == x.name()
+    assert isinstance(vars[4], hl.Var) and vars[4].name() == hl.Var.outermost().name()
 
 if __name__ == "__main__":
     test_compiletime_error()
@@ -622,3 +646,4 @@ def test_print_ir():
     test_implicit_update_by_int()
     test_implicit_update_by_float()
     test_print_ir()
+    test_split_vars()
diff --git a/src/Func.cpp b/src/Func.cpp
index cf8904ec2d31..0a7d570167e4 100644
--- a/src/Func.cpp
+++ b/src/Func.cpp
@@ -1875,6 +1875,14 @@ Stage &Stage::reorder(const 
std::vector<VarOrRVar> &vars) {
     return *this;
 }
 
+std::vector<VarOrRVar> Stage::split_vars() const {
+    std::vector<VarOrRVar> result;
+    for (const auto &d : definition.schedule().dims()) {
+        result.emplace_back(split_string(d.var, ".").back(), d.is_rvar());
+    }
+    return result;
+}
+
 Stage &Stage::gpu_threads(const VarOrRVar &tx, DeviceAPI device_api) {
     set_dim_device_api(tx, device_api);
     set_dim_type(tx, ForType::GPUThread);
@@ -2641,6 +2649,19 @@ Func &Func::reorder(const std::vector<VarOrRVar> &vars) {
     return *this;
 }
 
+std::vector<Var> Func::split_vars() const {
+    std::vector<Var> result;
+    for (const auto &d : func.definition().schedule().dims()) {
+        // Pure stages can't have RVars
+        internal_assert(!d.is_rvar())
+            << "The initial stage of Func " << name()
+            << " unexpectedly has RVar " << d.var
+            << " in the dims list. Initial stages aren't supposed to have RVars.";
+        result.emplace_back(split_string(d.var, ".").back());
+    }
+    return result;
+}
+
 Func &Func::gpu_threads(const VarOrRVar &tx, DeviceAPI device_api) {
     invalidate_cache();
     Stage(func, func.definition(), 0).gpu_threads(tx, device_api);
diff --git a/src/Func.h b/src/Func.h
index 517b725e4bb6..9fe4ba543efd 100644
--- a/src/Func.h
+++ b/src/Func.h
@@ -474,6 +474,14 @@ class Stage {
     }
     // @}
 
+    /** Get the Vars and RVars of this definition, from innermost out, with
+     * splits applied. This represents all the potentially-valid compute_at
+     * sites for this Stage. The RVars returned will be symbolic and not tied to
+     * a particular reduction domain, like the naked RVar objects used as split
+     * outputs. Note that this list by default will end with the sentinel
+     * Var::outermost. */
+    std::vector<VarOrRVar> split_vars() const;
+
     /** Assert that this stage has intentionally been given no schedule, and
      * suppress the warning about unscheduled update definitions that would
      * otherwise fire. This counts as a schedule, so calling this twice on the
@@ -1602,6 +1610,12 @@ class Func {
         return reorder(collected_args);
     }
 
+    /** Get the Vars of the pure definition, with splits applied. 
This
+     * represents all the potentially-valid compute_at sites for this stage of
+     * this Func. Note that this, by default, will end with the sentinel
+     * Var::outermost. */
+    std::vector<Var> split_vars() const;
+
     /** Rename a dimension. Equivalent to split with a inner size of one. */
     Func &rename(const VarOrRVar &old_name, const VarOrRVar &new_name);
 
diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt
index 586db8d1db8e..4bce8789875e 100644
--- a/test/correctness/CMakeLists.txt
+++ b/test/correctness/CMakeLists.txt
@@ -53,6 +53,7 @@ tests(GROUPS correctness
         compile_to_bitcode.cpp
         compile_to_lowered_stmt.cpp
         compile_to_multitarget.cpp
+        compute_at_innermost.cpp
        compute_at_reordered_update_stage.cpp
        compute_at_split_rvar.cpp
        compute_inside_guard.cpp
diff --git a/test/correctness/compute_at_innermost.cpp b/test/correctness/compute_at_innermost.cpp
new file mode 100644
index 000000000000..d9b29f884617
--- /dev/null
+++ b/test/correctness/compute_at_innermost.cpp
@@ -0,0 +1,40 @@
+#include "Halide.h"
+
+using namespace Halide;
+
+int main(int argc, char **argv) {
+    // Say we have a whole bunch of producer-consumer pairs, scheduled
+    // differently, and we always want to compute the corresponding producer
+    // innermost, even though that's not the same Var for each consumer. 
We can
+    // write a generic schedule using Func::split_vars() to get the list of
+    // scheduling points for each g
+
+    std::vector<Func> producers, consumers;
+    Var x, xo, xi;
+
+    for (int i = 0; i < 4; i++) {
+        producers.emplace_back("f" + std::to_string(i));
+        producers.back()(x) = x + i;
+        consumers.emplace_back("g" + std::to_string(i));
+        consumers.back()(x) = producers.back()(x) + 1;
+    }
+
+    // And we want to schedule some of consumers differently than others:
+
+    for (int i = 0; i < 4; i++) {
+        consumers[i].compute_root();
+
+        if (i == 3) {
+            consumers[3].split(x, xo, xi, 8);
+        }
+
+        producers[i].compute_at(consumers[i], consumers[i].split_vars()[0]);
+
+        // Just check these schedules are all legal, by running each but not
+        // checking the output.
+        consumers[i].realize({10});
+    }
+
+    printf("Success!\n");
+    return 0;
+}