Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Backport #7352 to release/15.x) Fix Python error handling #7353

Merged
merged 1 commit into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions python_bindings/src/halide/halide_/PyError.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ namespace PythonBindings {

namespace {

void halide_python_error(JITUserContext *, const char *msg) {
throw Error(msg);
}

void halide_python_print(JITUserContext *, const char *msg) {
py::gil_scoped_acquire acquire;
py::print(msg, py::arg("end") = "");
Expand All @@ -21,22 +17,36 @@ class HalidePythonCompileTimeErrorReporter : public CompileTimeErrorReporter {
}

void error(const char *msg) override {
// This method is called *only* from the Compiler -- never from jitted
// code -- so throwing an Error here is the right thing to do.

throw Error(msg);

// This method must not return!
}
};

} // namespace

PyJITUserContext::PyJITUserContext()
: JITUserContext() {
handlers.custom_print = halide_python_print;
// No: we don't want a custom error function.
// If we leave it as the default, realize() and infer_input_bounds()
// will correctly propagate the final error message to halide_runtime_error,
// which will throw an exception at the end of the relevant call.
//
// (It's tempting to override custom_error to just do 'throw Error',
// but when called from jitted code, it likely won't be able to find
// an enclosing C++ try block, meaning it could call std::terminate.)
//
// handlers.custom_error = halide_python_error;
}

void define_error(py::module &m) {
static HalidePythonCompileTimeErrorReporter reporter;
set_custom_compile_time_error_reporter(&reporter);

Halide::JITHandlers handlers;
handlers.custom_error = halide_python_error;
handlers.custom_print = halide_python_print;
Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);

static py::exception<Error> halide_error(m, "HalideError");
py::register_exception_translator([](std::exception_ptr p) { // NOLINT
try {
Expand Down
4 changes: 4 additions & 0 deletions python_bindings/src/halide/halide_/PyError.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ namespace PythonBindings {

void define_error(py::module &m);

struct PyJITUserContext : public JITUserContext {
PyJITUserContext();
};

} // namespace PythonBindings
} // namespace Halide

Expand Down
24 changes: 17 additions & 7 deletions python_bindings/src/halide/halide_/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <utility>

#include "PyBuffer.h"
#include "PyError.h"
#include "PyExpr.h"
#include "PyFuncRef.h"
#include "PyLoopLevel.h"
Expand Down Expand Up @@ -92,7 +93,8 @@ py::object evaluate_impl(const py::object &expr, bool may_gpu) {
{
py::gil_scoped_release release;

r = f.realize();
PyJITUserContext juc;
r = f.realize(&juc);
}
if (r->size() == 1) {
return buffer_getitem_operator((*r)[0], {});
Expand Down Expand Up @@ -144,7 +146,9 @@ void define_func(py::module &m) {
"realize",
[](Func &f, Buffer<> buffer, const Target &target) -> void {
py::gil_scoped_release release;
f.realize(buffer, target);

PyJITUserContext juc;
f.realize(&juc, buffer, target);
},
py::arg("dst"), py::arg("target") = Target())

Expand All @@ -160,7 +164,9 @@ void define_func(py::module &m) {
std::optional<Realization> r;
{
py::gil_scoped_release release;
r = f.realize(sizes, target);

PyJITUserContext juc;
r = f.realize(&juc, sizes, target);
}
return realization_to_object(*r);
},
Expand All @@ -171,7 +177,9 @@ void define_func(py::module &m) {
"realize",
[](Func &f, std::vector<Buffer<>> buffers, const Target &t) -> void {
py::gil_scoped_release release;
f.realize(Realization(std::move(buffers)), t);

PyJITUserContext juc;
f.realize(&juc, Realization(std::move(buffers)), t);
},
py::arg("dst"), py::arg("target") = Target())

Expand Down Expand Up @@ -361,26 +369,28 @@ void define_func(py::module &m) {
.def(
"infer_input_bounds", [](Func &f, const py::object &dst, const Target &target) -> void {
const Target t = to_jit_target(target);
PyJITUserContext juc;

// dst could be Buffer<>, vector<Buffer>, or vector<int>
try {
Buffer<> b = dst.cast<Buffer<>>();
f.infer_input_bounds(b, t);
f.infer_input_bounds(&juc, b, t);
return;
} catch (...) {
// fall thru
}

try {
std::vector<Buffer<>> v = dst.cast<std::vector<Buffer<>>>();
f.infer_input_bounds(Realization(std::move(v)), t);
f.infer_input_bounds(&juc, Realization(std::move(v)), t);
return;
} catch (...) {
// fall thru
}

try {
std::vector<int32_t> v = dst.cast<std::vector<int32_t>>();
f.infer_input_bounds(v, t);
f.infer_input_bounds(&juc, v, t);
return;
} catch (...) {
// fall thru
Expand Down
21 changes: 15 additions & 6 deletions python_bindings/src/halide/halide_/PyPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <utility>

#include "PyError.h"
#include "PyTuple.h"

namespace Halide {
Expand Down Expand Up @@ -187,7 +188,9 @@ void define_pipeline(py::module &m) {
.def(
"realize", [](Pipeline &p, Buffer<> buffer, const Target &target) -> void {
py::gil_scoped_release release;
p.realize(Realization(std::move(buffer)), target);

PyJITUserContext juc;
p.realize(&juc, Realization(std::move(buffer)), target);
},
py::arg("dst"), py::arg("target") = Target())

Expand All @@ -202,7 +205,9 @@ void define_pipeline(py::module &m) {
std::optional<Realization> r;
{
py::gil_scoped_release release;
r = p.realize(std::move(sizes), target);

PyJITUserContext juc;
r = p.realize(&juc, std::move(sizes), target);
}
return realization_to_object(*r);
},
Expand All @@ -212,33 +217,37 @@ void define_pipeline(py::module &m) {
.def(
"realize", [](Pipeline &p, std::vector<Buffer<>> buffers, const Target &t) -> void {
py::gil_scoped_release release;
p.realize(Realization(std::move(buffers)), t);

PyJITUserContext juc;
p.realize(&juc, Realization(std::move(buffers)), t);
},
py::arg("dst"), py::arg("target") = Target())

.def(
"infer_input_bounds", [](Pipeline &p, const py::object &dst, const Target &target) -> void {
const Target t = to_jit_target(target);
PyJITUserContext juc;

// dst could be Buffer<>, vector<Buffer>, or vector<int>
try {
Buffer<> b = dst.cast<Buffer<>>();
p.infer_input_bounds(b, t);
p.infer_input_bounds(&juc, b, t);
return;
} catch (...) {
// fall thru
}

try {
std::vector<Buffer<>> v = dst.cast<std::vector<Buffer<>>>();
p.infer_input_bounds(Realization(std::move(v)), t);
p.infer_input_bounds(&juc, Realization(std::move(v)), t);
return;
} catch (...) {
// fall thru
}

try {
std::vector<int32_t> v = dst.cast<std::vector<int32_t>>();
p.infer_input_bounds(v, t);
p.infer_input_bounds(&juc, v, t);
return;
} catch (...) {
// fall thru
Expand Down
17 changes: 0 additions & 17 deletions python_bindings/stub/PyStubImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,6 @@ using StubInputBuffer = Internal::StubInputBuffer<void>;

namespace {

// This seems redundant to the code in PyError.cpp, but is necessary
// in case the Stub builder links in a separate copy of libHalide, rather
// sharing the same halide.so that is built by default.
void halide_python_error(JITUserContext *, const char *msg) {
throw Halide::Error(msg);
}

void halide_python_print(JITUserContext *, const char *msg) {
py::gil_scoped_acquire acquire;
py::print(msg, py::arg("end") = "");
}

class HalidePythonCompileTimeErrorReporter : public CompileTimeErrorReporter {
public:
void warning(const char *msg) override {
Expand All @@ -57,11 +45,6 @@ void install_error_handlers(py::module &m) {
static HalidePythonCompileTimeErrorReporter reporter;
set_custom_compile_time_error_reporter(&reporter);

Halide::JITHandlers handlers;
handlers.custom_error = halide_python_error;
handlers.custom_print = halide_python_print;
Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);

static py::object halide_error = py::module_::import("halide").attr("HalideError");
if (halide_error.is(py::none())) {
throw std::runtime_error("Could not find halide.HalideError");
Expand Down
3 changes: 2 additions & 1 deletion src/JITModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,8 @@ std::string JITErrorBuffer::str() const {

JITFuncCallContext::JITFuncCallContext(JITUserContext *context, const JITHandlers &pipeline_handlers)
: context(context) {
custom_error_handler = (context->handlers.custom_error != nullptr ||
custom_error_handler = ((context->handlers.custom_error != nullptr &&
context->handlers.custom_error != JITErrorBuffer::handler) ||
pipeline_handlers.custom_error != nullptr);
// Hook the error handler if not set
if (!custom_error_handler) {
Expand Down