Skip to content

Commit

Permalink
Allow PyPipeline and PyFunc to realize() scalar buffers (halide#6674)
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-johnson authored and ardier committed Mar 3, 2024
1 parent a0f6012 commit 1d2c8bf
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 17 deletions.
3 changes: 3 additions & 0 deletions python_bindings/correctness/complexstub_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class ComplexStub : public Halide::Generator<ComplexStub> {
Output<Buffer<float, 3>> typed_buffer_output{"typed_buffer_output"};
Output<Buffer<void, 3>> untyped_buffer_output{"untyped_buffer_output"};
Output<Buffer<uint8_t, 3>> static_compiled_buffer_output{"static_compiled_buffer_output"};
Output<float> scalar_output{"scalar_output"};

void configure() {
// Pointers returned by add_input() are managed by the Generator;
Expand Down Expand Up @@ -72,6 +73,8 @@ class ComplexStub : public Halide::Generator<ComplexStub> {
static_compiled_buffer_output = static_compiled_buffer;

(*extra_func_output)(x, y) = cast<double>((*extra_func_input)(x, y, 0) + 1);

scalar_output() = float_arg + int_arg;
}

void schedule() {
Expand Down
5 changes: 5 additions & 0 deletions python_bindings/correctness/pystub.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_complexstub():
typed_buffer_output,
untyped_buffer_output,
static_compiled_buffer_output,
scalar_output,
extra_func_output) = r

b = simple_output.realize([32, 32, 3], target)
Expand Down Expand Up @@ -249,6 +250,10 @@ def test_complexstub():
actual = b[x, y, c]
assert expected == actual, "Expected %s Actual %s" % (expected, actual)

b = scalar_output.realize([], target)
assert b.type() == hl.Float(32)
assert b[()] == 34.25

b = extra_func_output.realize([32, 32], target)
assert b.type() == hl.Float(64)
for x in range(32):
Expand Down
24 changes: 15 additions & 9 deletions python_bindings/src/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,12 @@ void define_func(py::module &m) {
},
py::arg("dst"), py::arg("target") = Target())

// This will actually allow a list-of-buffers as well as a tuple-of-buffers, but that's OK.
.def(
"realize",
[](Func &f, std::vector<Buffer<>> buffers, const Target &t) -> void {
py::gil_scoped_release release;
f.realize(Realization(buffers), t);
},
py::arg("dst"), py::arg("target") = Target())

// It's important to have this overload of realize() go first:
// passing an empty list [] is ambiguous in Python, and could match to
// either list-of-sizes or list-of-buffers... but the former is useful
// (it allows realizing a 0-dimensional/scalar buffer) and the former is
// not (it will always assert-fail). Putting this one first allows it to
// be the first one chosen by the bindings in this case.
.def(
"realize",
[](Func &f, const std::vector<int32_t> &sizes, const Target &target) -> py::object {
Expand All @@ -143,6 +140,15 @@ void define_func(py::module &m) {
},
py::arg("sizes") = std::vector<int32_t>{}, py::arg("target") = Target())

// This will actually allow a list-of-buffers as well as a tuple-of-buffers, but that's OK.
.def(
"realize",
[](Func &f, std::vector<Buffer<>> buffers, const Target &t) -> void {
py::gil_scoped_release release;
f.realize(Realization(buffers), t);
},
py::arg("dst"), py::arg("target") = Target())

.def("defined", &Func::defined)
.def("name", &Func::name)
.def("dimensions", &Func::dimensions)
Expand Down
22 changes: 14 additions & 8 deletions python_bindings/src/PyPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,12 @@ void define_pipeline(py::module &m) {
},
py::arg("dst"), py::arg("target") = Target())

// This will actually allow a list-of-buffers as well as a tuple-of-buffers, but that's OK.
.def(
"realize", [](Pipeline &p, std::vector<Buffer<>> buffers, const Target &t) -> void {
py::gil_scoped_release release;
p.realize(Realization(buffers), t);
},
py::arg("dst"), py::arg("target") = Target())

// It's important to have this overload of realize() go first:
// passing an empty list [] is ambiguous in Python, and could match to
// either list-of-sizes or list-of-buffers... but the former is useful
// (it allows realizing a 0-dimensional/scalar buffer) and the former is
// not (it will always assert-fail). Putting this one first allows it to
// be the first one chosen by the bindings in this case.
.def(
"realize", [](Pipeline &p, std::vector<int32_t> sizes, const Target &target) -> py::object {
std::optional<Realization> r;
Expand All @@ -120,6 +118,14 @@ void define_pipeline(py::module &m) {
},
py::arg("sizes") = std::vector<int32_t>{}, py::arg("target") = Target())

// This will actually allow a list-of-buffers as well as a tuple-of-buffers, but that's OK.
.def(
"realize", [](Pipeline &p, std::vector<Buffer<>> buffers, const Target &t) -> void {
py::gil_scoped_release release;
p.realize(Realization(buffers), t);
},
py::arg("dst"), py::arg("target") = Target())

.def(
"infer_input_bounds", [](Pipeline &p, const py::object &dst, const Target &target) -> void {
// dst could be Buffer<>, vector<Buffer>, or vector<int>
Expand Down

0 comments on commit 1d2c8bf

Please sign in to comment.