Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to pass a user context in JIT mode #6313

Merged
merged 15 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions python_bindings/src/PyBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,11 +552,28 @@ void define_buffer(py::module &m) {
return b.device_sync(nullptr);
})

.def("copy_to_device", (int (Buffer<>::*)(const Target &)) & Buffer<>::copy_to_device, py::arg("target") = get_jit_target_from_environment())
.def("copy_to_device", (int (Buffer<>::*)(const DeviceAPI &, const Target &)) & Buffer<>::copy_to_device, py::arg("device_api"), py::arg("target") = get_jit_target_from_environment())
.def(
"copy_to_device", [](Buffer<> &b, const Target &t) -> int {
return b.copy_to_device(t);
},
py::arg("target") = get_jit_target_from_environment())

.def("device_malloc", (int (Buffer<>::*)(const Target &)) & Buffer<>::device_malloc, py::arg("target") = get_jit_target_from_environment())
.def("device_malloc", (int (Buffer<>::*)(const DeviceAPI &, const Target &)) & Buffer<>::device_malloc, py::arg("device_api"), py::arg("target") = get_jit_target_from_environment())
.def(
"copy_to_device", [](Buffer<> &b, const DeviceAPI &d, const Target &t) -> int {
return b.copy_to_device(d, t);
},
py::arg("device_api"), py::arg("target") = get_jit_target_from_environment())
.def(
"device_malloc", [](Buffer<> &b, const Target &t) -> int {
return b.device_malloc(t);
},
py::arg("target") = get_jit_target_from_environment())

.def(
"device_malloc", [](Buffer<> &b, const DeviceAPI &d, const Target &t) -> int {
return b.device_malloc(d, t);
},
py::arg("device_api"), py::arg("target") = get_jit_target_from_environment())

.def(
"set_min", [](Buffer<> &b, const std::vector<int> &mins) -> void {
Expand Down
6 changes: 3 additions & 3 deletions python_bindings/src/PyError.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ namespace PythonBindings {

namespace {

void halide_python_error(void *, const char *msg) {
void halide_python_error(JITUserContext *, const char *msg) {
throw Error(msg);
}

void halide_python_print(void *, const char *msg) {
void halide_python_print(JITUserContext *, const char *msg) {
py::print(msg, py::arg("end") = "");
}

Expand All @@ -31,7 +31,7 @@ void define_error(py::module &m) {
static HalidePythonCompileTimeErrorReporter reporter;
set_custom_compile_time_error_reporter(&reporter);

Halide::Internal::JITHandlers handlers;
Halide::JITHandlers handlers;
handlers.custom_error = halide_python_error;
handlers.custom_print = halide_python_print;
Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);
Expand Down
6 changes: 3 additions & 3 deletions python_bindings/stub/PyStubImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ namespace {
// This seems redundant to the code in PyError.cpp, but is necessary
// in case the Stub builder links in a separate copy of libHalide, rather
// sharing the same halide.so that is built by default.
void halide_python_error(void *, const char *msg) {
void halide_python_error(JITUserContext *, const char *msg) {
throw Error(msg);
}

void halide_python_print(void *, const char *msg) {
void halide_python_print(JITUserContext *, const char *msg) {
py::print(msg, py::arg("end") = "");
}

Expand All @@ -59,7 +59,7 @@ void install_error_handlers(py::module &m) {
static HalidePythonCompileTimeErrorReporter reporter;
set_custom_compile_time_error_reporter(&reporter);

Halide::Internal::JITHandlers handlers;
Halide::JITHandlers handlers;
handlers.custom_error = halide_python_error;
handlers.custom_print = halide_python_print;
Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);
Expand Down
22 changes: 12 additions & 10 deletions src/Buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ namespace Halide {
template<typename T = void>
class Buffer;

struct JITUserContext;

namespace Internal {

struct BufferContents {
Expand Down Expand Up @@ -577,32 +579,32 @@ class Buffer {
// @}

/** Copy to the GPU, using the device API that is the default for the given Target. */
int copy_to_device(const Target &t = get_jit_target_from_environment()) {
return copy_to_device(DeviceAPI::Default_GPU, t);
int copy_to_device(const Target &t = get_jit_target_from_environment(), JITUserContext *context = nullptr) {
return copy_to_device(DeviceAPI::Default_GPU, t, context);
}

/** Copy to the GPU, using the given device API */
int copy_to_device(const DeviceAPI &d, const Target &t = get_jit_target_from_environment()) {
return contents->buf.copy_to_device(get_device_interface_for_device_api(d, t, "Buffer::copy_to_device"));
int copy_to_device(const DeviceAPI &d, const Target &t = get_jit_target_from_environment(), JITUserContext *context = nullptr) {
return contents->buf.copy_to_device(get_device_interface_for_device_api(d, t, "Buffer::copy_to_device"), context);
}

/** Allocate on the GPU, using the device API that is the default for the given Target. */
int device_malloc(const Target &t = get_jit_target_from_environment()) {
return device_malloc(DeviceAPI::Default_GPU, t);
int device_malloc(const Target &t = get_jit_target_from_environment(), JITUserContext *context = nullptr) {
return device_malloc(DeviceAPI::Default_GPU, t, context);
}

/** Allocate storage on the GPU, using the given device API */
int device_malloc(const DeviceAPI &d, const Target &t = get_jit_target_from_environment()) {
return contents->buf.device_malloc(get_device_interface_for_device_api(d, t, "Buffer::device_malloc"));
int device_malloc(const DeviceAPI &d, const Target &t = get_jit_target_from_environment(), JITUserContext *context = nullptr) {
return contents->buf.device_malloc(get_device_interface_for_device_api(d, t, "Buffer::device_malloc"), context);
}

/** Wrap a native handle, using the given device API.
* It is a bad idea to pass DeviceAPI::Default_GPU to this routine
* as the handle argument must match the API that the default
* resolves to and it is clearer and more reliable to pass the
* resolved DeviceAPI explicitly. */
int device_wrap_native(const DeviceAPI &d, uint64_t handle, const Target &t = get_jit_target_from_environment()) {
return contents->buf.device_wrap_native(get_device_interface_for_device_api(d, t, "Buffer::device_wrap_native"), handle);
int device_wrap_native(const DeviceAPI &d, uint64_t handle, const Target &t = get_jit_target_from_environment(), JITUserContext *context = nullptr) {
return contents->buf.device_wrap_native(get_device_interface_for_device_api(d, t, "Buffer::device_wrap_native"), handle, context);
}
};

Expand Down
6 changes: 3 additions & 3 deletions src/DeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ bool host_supports_target_device(const Target &t) {
temp.fill(0);
temp.set_host_dirty();

Halide::Internal::JITHandlers handlers;
handlers.custom_error = [](void *user_context, const char *msg) {
Halide::JITHandlers handlers;
handlers.custom_error = [](JITUserContext *user_context, const char *msg) {
debug(1) << "host_supports_device_api: saw error (" << msg << ")\n";
};
Halide::Internal::JITHandlers old_handlers = Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);
Halide::JITHandlers old_handlers = Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);

int result = temp.copy_to_device(i);

Expand Down
39 changes: 35 additions & 4 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3089,17 +3089,32 @@ Realization Func::realize(std::vector<int32_t> sizes, const Target &target,
return pipeline().realize(std::move(sizes), target, param_map);
}

Realization Func::realize(JITUserContext *context,
std::vector<int32_t> sizes,
const Target &target,
const ParamMap &param_map) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Idly wondering if we could someday roll param_map into the JITUserContext...)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ParamMap holds the arguments to the realize call and thus is a completely separate concept from the user context. Arguably the user context could be passed in the ParamMap, but that doesn't seem like an improvement to me. Normally Params are retrieved from global variables, but that is not thread-safe. It is a bit of a silly design in the first place, but it is really convenient for just hacking up Halide code, so it is totally a thing... In order to pass an arbitrary set of arguments through realize to a JITted call, one has to use some sort of keyed dynamic data structure. That is what ParamMap is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the runtime design I'm looking at does make the user context part of the arguments. But until the design doc is done, there's not much point in discussing it. If it goes through, it will reverse this change entirely. (The goal is to make the Halide compiler able to pass through a flexible contract from the outside caller to the runtime called from Halide generated code. This is done by possibly adding arguments to the runtime calls at codegen time.)

user_assert(defined()) << "Can't realize undefined Func.\n";
return pipeline().realize(context, std::move(sizes), target, param_map);
}

// Convenience overload that takes no user context: forwards to the
// JITUserContext overload, passing nullptr as the context.
void Func::infer_input_bounds(const std::vector<int32_t> &sizes,
const Target &target,
const ParamMap &param_map) {
infer_input_bounds(nullptr, sizes, target, param_map);
}

void Func::infer_input_bounds(JITUserContext *context,
const std::vector<int32_t> &sizes,
const Target &target,
const ParamMap &param_map) {
user_assert(defined()) << "Can't infer input bounds on an undefined Func.\n";
vector<Buffer<>> outputs(func.outputs());
for (size_t i = 0; i < outputs.size(); i++) {
Buffer<> im(func.output_types()[i], nullptr, sizes);
outputs[i] = std::move(im);
}
Realization r(outputs);
infer_input_bounds(r, target, param_map);
infer_input_bounds(context, r, target, param_map);
}

OutputImageParam Func::output_buffer() const {
Expand Down Expand Up @@ -3280,20 +3295,36 @@ const vector<CustomLoweringPass> &Func::custom_lowering_passes() {
return pipeline().custom_lowering_passes();
}

const Internal::JITHandlers &Func::jit_handlers() {
JITHandlers &Func::jit_handlers() {
return pipeline().jit_handlers();
}

void Func::realize(Pipeline::RealizationArg outputs, const Target &target,
void Func::realize(Pipeline::RealizationArg outputs,
const Target &target,
const ParamMap &param_map) {
pipeline().realize(std::move(outputs), target, param_map);
}

void Func::infer_input_bounds(Pipeline::RealizationArg outputs, const Target &target,
// Realize into the caller-supplied outputs. Delegates to
// pipeline().realize, handing the context pointer through unchanged
// and moving `outputs` into the call.
void Func::realize(JITUserContext *context,
Pipeline::RealizationArg outputs,
const Target &target,
const ParamMap &param_map) {
pipeline().realize(context, std::move(outputs), target, param_map);
}

// Infer input bounds for the given outputs without a user context.
// Delegates to pipeline().infer_input_bounds, moving `outputs` into
// the call.
void Func::infer_input_bounds(Pipeline::RealizationArg outputs,
const Target &target,
const ParamMap &param_map) {
pipeline().infer_input_bounds(std::move(outputs), target, param_map);
}

// Infer input bounds for the given outputs with a user context.
// Delegates to pipeline().infer_input_bounds, handing the context
// pointer through unchanged and moving `outputs` into the call.
void Func::infer_input_bounds(JITUserContext *context,
Pipeline::RealizationArg outputs,
const Target &target,
const ParamMap &param_map) {
pipeline().infer_input_bounds(context, std::move(outputs), target, param_map);
}

// JIT-compile for the given target by delegating to
// pipeline().compile_jit.
void Func::compile_jit(const Target &target) {
pipeline().compile_jit(target);
}
Expand Down
50 changes: 45 additions & 5 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,14 @@ class Func {
Realization realize(std::vector<int32_t> sizes = {}, const Target &target = Target(),
const ParamMap &param_map = ParamMap::empty_map());

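/** Same as above, but takes a custom user-provided context to be
* passed to runtime functions. This can be used to pass state to
* runtime overrides in a thread-safe manner. Passing nullptr for
* the context is permitted; the evaluate() overloads that take no
* context realize with a null context. */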
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it legal to pass nullptr for context? Should be documented.

Realization realize(JITUserContext *context,
std::vector<int32_t> sizes = {},
const Target &target = Target(),
const ParamMap &param_map = ParamMap::empty_map());

/** Evaluate this function into an existing allocated buffer or
* buffers. If the buffer is also one of the arguments to the
* function, strange things may happen, as the pipeline isn't
Expand All @@ -838,6 +846,14 @@ class Func {
void realize(Pipeline::RealizationArg outputs, const Target &target = Target(),
const ParamMap &param_map = ParamMap::empty_map());

/** Same as above, but takes a custom user-provided context to be
* passed to runtime functions. This can be used to pass state to
* runtime overrides in a thread-safe manner. */
void realize(JITUserContext *context,
Pipeline::RealizationArg outputs,
const Target &target = Target(),
const ParamMap &param_map = ParamMap::empty_map());

/** For a given size of output, or a given output buffer,
* determine the bounds required of all unbound ImageParams
* referenced. Communicates the result by allocating new buffers
Expand Down Expand Up @@ -870,6 +886,18 @@ class Func {
const ParamMap &param_map = ParamMap::empty_map());
// @}

/** Versions of infer_input_bounds that take a custom user context
* to pass to runtime functions. */
// @{
void infer_input_bounds(JITUserContext *context,
const std::vector<int32_t> &sizes,
const Target &target = get_jit_target_from_environment(),
const ParamMap &param_map = ParamMap::empty_map());
void infer_input_bounds(JITUserContext *context,
Pipeline::RealizationArg outputs,
const Target &target = get_jit_target_from_environment(),
const ParamMap &param_map = ParamMap::empty_map());
// @}
/** Statically compile this function to llvm bitcode, with the
* given filename (which should probably end in .bc), type
* signature, and C function name (which defaults to the same name
Expand Down Expand Up @@ -1114,7 +1142,7 @@ class Func {

/** Get a struct containing the currently set custom functions
* used by JIT. */
const Internal::JITHandlers &jit_handlers();
JITHandlers &jit_handlers();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns a mutable ref, so I assume it's ok to just mutate the contents... we should probably explicitly document the rules for doing so. (eg, when do changes I make take effect?)


/** Add a custom pass to be used during lowering. It is run after
* all other lowering passes. Can be used to verify properties of
Expand Down Expand Up @@ -2585,28 +2613,40 @@ inline void assign_results(Realization &r, int idx, First first, Second second,
* expression. This can be thought of as a scalar version of
* \ref Func::realize */
template<typename T>
HALIDE_NO_USER_CODE_INLINE T evaluate(const Expr &e) {
HALIDE_NO_USER_CODE_INLINE T evaluate(JITUserContext *ctx, const Expr &e) {
user_assert(e.type() == type_of<T>())
<< "Can't evaluate expression "
<< e << " of type " << e.type()
<< " as a scalar of type " << type_of<T>() << "\n";
Func f;
f() = e;
Buffer<T> im = f.realize();
Buffer<T> im = f.realize(ctx);
return im();
}

/** Evaluate a scalar expression without supplying a user context.
 * Forwards to the overload that takes a JITUserContext, passing
 * nullptr as the context, and returns that overload's result. */
template<typename T>
HALIDE_NO_USER_CODE_INLINE T evaluate(const Expr &e) {
return evaluate<T>(nullptr, e);
}

/** JIT-compile and run enough code to evaluate a Halide Tuple. */
template<typename First, typename... Rest>
HALIDE_NO_USER_CODE_INLINE void evaluate(Tuple t, First first, Rest &&...rest) {
HALIDE_NO_USER_CODE_INLINE void evaluate(JITUserContext *ctx, Tuple t, First first, Rest &&...rest) {
Internal::check_types<First, Rest...>(t, 0);

Func f;
f() = t;
Realization r = f.realize();
Realization r = f.realize(ctx);
Internal::assign_results(r, 0, first, rest...);
}

/** JIT-compile and run enough code to evaluate a Halide Tuple, without
 * supplying a user context. Forwards every argument to the overload
 * that takes a JITUserContext, passing nullptr as the context.
 *
 * \param t      the Tuple to evaluate (moved into the forwarded call)
 * \param first  first trailing argument, passed through unchanged
 * \param rest   remaining trailing arguments, perfectly forwarded */
template<typename First, typename... Rest>
HALIDE_NO_USER_CODE_INLINE void evaluate(Tuple t, First first, Rest &&...rest) {
    // Forward each pack element individually. Expanding the pack inside
    // a single std::forward call (std::forward<Rest...>(rest...)) is only
    // well-formed when the pack holds exactly one element, so it would
    // fail to instantiate for zero, or for two or more, trailing arguments.
    evaluate<First, Rest...>(nullptr, std::move(t), std::forward<First>(first), std::forward<Rest>(rest)...);
}

namespace Internal {

inline void schedule_scalar(Func f) {
Expand Down
Loading