Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python_bindings/src/halide/halide_/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ void define_func(py::module &m) {
.def("type", &Func::type)
.def("types", &Func::types)

.def("split_vars", &Func::split_vars)

.def("bound", &Func::bound, py::arg("var"), py::arg("min"), py::arg("extent"))

.def("reorder_storage", (Func & (Func::*)(const std::vector<Var> &)) & Func::reorder_storage, py::arg("dims"))
Expand Down
15 changes: 15 additions & 0 deletions python_bindings/src/halide/halide_/PyStage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ void define_stage(py::module &m) {
.def("rfactor", static_cast<Func (Stage::*)(const RVar &, const Var &)>(&Stage::rfactor),
py::arg("r"), py::arg("v"))

.def("split_vars", [](const Stage &stage) -> py::list {
auto vars = stage.split_vars();
py::list result;
// Return a mixed-type list of Var and RVar objects, instead of
// a list of VarOrRVar objects.
for (const auto &v : vars) {
if (v.is_rvar) {
result.append(py::cast(v.rvar));
} else {
result.append(py::cast(v.var));
}
}
return result;
})

.def("unscheduled", &Stage::unscheduled);

py::implicitly_convertible<Func, Stage>();
Expand Down
27 changes: 26 additions & 1 deletion python_bindings/test/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def test_basics():
blur_x.compute_at(blur_y, x).vectorize(x, 8)
blur_y.compile_jit()


def test_basics2():
input = hl.ImageParam(hl.Float(32), 3, "input")
hl.Param(hl.Float(32), "r_sigma", 0.1)
Expand Down Expand Up @@ -597,6 +596,31 @@ def test_print_ir():
p = hl.Pipeline()
assert str(p) == "<halide.Pipeline Pipeline()>"

def test_split_vars():
f = hl.Func("f")
(x, xo, xi) = hl.vars("x xo xi")
f[x] = x
r = hl.RDom([(0, 10), (0, 10)], "r")
f[x] += x + r.x + r.y

f.split(x, xo, xi, 8)

vars = f.split_vars()
assert len(vars) == 3
assert vars[0].name() == xi.name()
assert vars[1].name() == xo.name()
assert vars[2].name() == hl.Var.outermost().name()

(rxo, rxi) = (hl.RVar("rxo"), hl.RVar("rxi"))
f.update().split(r.x, rxo, rxi, 4)

vars = f.update().split_vars()
assert len(vars) == 5
assert isinstance(vars[0], hl.RVar) and vars[0].name() == rxi.name()
assert isinstance(vars[1], hl.RVar) and vars[1].name() == rxo.name()
assert isinstance(vars[2], hl.RVar) and vars[2].name() == r.y.name()
assert isinstance(vars[3], hl.Var) and vars[3].name() == x.name()
assert isinstance(vars[4], hl.Var) and vars[4].name() == hl.Var.outermost().name()

if __name__ == "__main__":
test_compiletime_error()
Expand All @@ -622,3 +646,4 @@ def test_print_ir():
test_implicit_update_by_int()
test_implicit_update_by_float()
test_print_ir()
test_split_vars()
21 changes: 21 additions & 0 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,14 @@ Stage &Stage::reorder(const std::vector<VarOrRVar> &vars) {
return *this;
}

std::vector<VarOrRVar> Stage::split_vars() const {
std::vector<VarOrRVar> result;
for (const auto &d : definition.schedule().dims()) {
result.emplace_back(split_string(d.var, ".").back(), d.is_rvar());
}
return result;
}

Stage &Stage::gpu_threads(const VarOrRVar &tx, DeviceAPI device_api) {
set_dim_device_api(tx, device_api);
set_dim_type(tx, ForType::GPUThread);
Expand Down Expand Up @@ -2641,6 +2649,19 @@ Func &Func::reorder(const std::vector<VarOrRVar> &vars) {
return *this;
}

std::vector<Var> Func::split_vars() const {
std::vector<Var> result;
for (const auto &d : func.definition().schedule().dims()) {
// Pure stages can't have RVars
internal_assert(!d.is_rvar())
<< "The initial stage of Func " << name()
<< " unexpectedly has RVar " << d.var
<< "in the dims list. Initial stages aren't supposed to have RVars.";
result.emplace_back(split_string(d.var, ".").back());
}
return result;
}

Func &Func::gpu_threads(const VarOrRVar &tx, DeviceAPI device_api) {
invalidate_cache();
Stage(func, func.definition(), 0).gpu_threads(tx, device_api);
Expand Down
14 changes: 14 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,14 @@ class Stage {
}
// @}

/** Get the Vars and RVars of this definition, from innermost out, with
* splits applied. This represents all the potentially-valid compute_at
* sites for this Stage. The RVars returned will be symbolic and not tied to
* a particular reduction domain, like the naked RVar objects used as split
* outputs. Note that this list by default will end with the sentinel
* Var::outermost. */
std::vector<VarOrRVar> split_vars() const;

/** Assert that this stage has intentionally been given no schedule, and
* suppress the warning about unscheduled update definitions that would
* otherwise fire. This counts as a schedule, so calling this twice on the
Expand Down Expand Up @@ -1602,6 +1610,12 @@ class Func {
return reorder(collected_args);
}

/** Get the Vars of the pure definition, with splits applied. This
* represents all the potentially-valid compute_at sites for this stage of
* this Func. Note that this, by default, will end with the sentinel
* Var::outermost. */
std::vector<Var> split_vars() const;

/** Rename a dimension. Equivalent to split with a inner size of one. */
Func &rename(const VarOrRVar &old_name, const VarOrRVar &new_name);

Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ tests(GROUPS correctness
compile_to_bitcode.cpp
compile_to_lowered_stmt.cpp
compile_to_multitarget.cpp
compute_at_innermost.cpp
compute_at_reordered_update_stage.cpp
compute_at_split_rvar.cpp
compute_inside_guard.cpp
Expand Down
40 changes: 40 additions & 0 deletions test/correctness/compute_at_innermost.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include "Halide.h"

using namespace Halide;

int main(int argc, char **argv) {
// Say we have a whole bunch of producer-consumer pairs, scheduled
// differently, and we always want to compute the corresponding producer
// innermost, even though that's not the same Var for each consumer. We can
// write a generic schedule using Func::split_vars() to get the list of
// scheduling points for each g

std::vector<Func> producers, consumers;
Var x, xo, xi;

for (int i = 0; i < 4; i++) {
producers.emplace_back("f" + std::to_string(i));
producers.back()(x) = x + i;
consumers.emplace_back("g" + std::to_string(i));
consumers.back()(x) = producers.back()(x) + 1;
}

// And we want to schedule some of consumers differently than others:

for (int i = 0; i < 4; i++) {
consumers[i].compute_root();

if (i == 3) {
consumers[3].split(x, xo, xi, 8);
}

producers[i].compute_at(consumers[i], consumers[i].split_vars()[0]);

// Just check these schedules are all legal, by running each but not
// checking the output.
consumers[i].realize({10});
}

printf("Success!\n");
return 0;
}
Loading