Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to pass a user context in JIT mode #6313

Merged
merged 15 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python_bindings/src/PyError.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void define_error(py::module &m) {
static HalidePythonCompileTimeErrorReporter reporter;
set_custom_compile_time_error_reporter(&reporter);

Halide::Internal::JITHandlers handlers;
Halide::JITHandlers handlers;
handlers.custom_error = halide_python_error;
handlers.custom_print = halide_python_print;
Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);
Expand Down
2 changes: 1 addition & 1 deletion python_bindings/stub/PyStubImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void install_error_handlers(py::module &m) {
static HalidePythonCompileTimeErrorReporter reporter;
set_custom_compile_time_error_reporter(&reporter);

Halide::Internal::JITHandlers handlers;
Halide::JITHandlers handlers;
handlers.custom_error = halide_python_error;
handlers.custom_print = halide_python_print;
Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);
Expand Down
6 changes: 3 additions & 3 deletions src/DeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ bool host_supports_target_device(const Target &t) {
temp.fill(0);
temp.set_host_dirty();

Halide::Internal::JITHandlers handlers;
handlers.custom_error = [](void *user_context, const char *msg) {
Halide::JITHandlers handlers;
handlers.custom_error = [](JITUserContext *user_context, const char *msg) {
debug(1) << "host_supports_device_api: saw error (" << msg << ")\n";
};
Halide::Internal::JITHandlers old_handlers = Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);
Halide::JITHandlers old_handlers = Halide::Internal::JITSharedRuntime::set_default_handlers(handlers);

int result = temp.copy_to_device(i);

Expand Down
15 changes: 14 additions & 1 deletion src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3089,6 +3089,14 @@ Realization Func::realize(std::vector<int32_t> sizes, const Target &target,
return pipeline().realize(std::move(sizes), target, param_map);
}

Realization Func::realize(JITUserContext *context,
std::vector<int32_t> sizes,
const Target &target,
const ParamMap &param_map) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Idly wondering if we could someday roll param_map into the JITUserContext...)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ParamMap holds the arguments to the realize call and thus is a completely separate concept from the user context. Arguably the user context could be passed in the ParamMap, but that doesn't seem an improvement to me. Normally Params are retrieved from global variables, but that is not thread safe. It is a bit of a silly design in the first place, but it is really convenient for the just hacking up Halide code so it is totally a thing... In order to pass an arbitrary set of arguments through realize to a JITted call, one has to use some sort of keyed dynamic data structure. That is what ParamMap is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the runtime design I'm looking at does make the user context part of the arguments. But until the design doc is done, there's not much point in discussing it. If it goes through, it will reverse this change entirely. (The goal is to make the Halide compiler able to pass through a flexible contract from the outside caller to the runtime called from Halide generated code. This is done by possibly adding arguments to the runtime calls at codegen time.)

user_assert(defined()) << "Can't realize undefined Func.\n";
return pipeline().realize(context, std::move(sizes), target, param_map);
}

void Func::infer_input_bounds(const std::vector<int32_t> &sizes,
const Target &target,
const ParamMap &param_map) {
Expand Down Expand Up @@ -3280,7 +3288,7 @@ const vector<CustomLoweringPass> &Func::custom_lowering_passes() {
return pipeline().custom_lowering_passes();
}

const Internal::JITHandlers &Func::jit_handlers() {
JITHandlers &Func::jit_handlers() {
return pipeline().jit_handlers();
}

Expand All @@ -3289,6 +3297,11 @@ void Func::realize(Pipeline::RealizationArg outputs, const Target &target,
pipeline().realize(std::move(outputs), target, param_map);
}

void Func::realize(JITUserContext *context, Pipeline::RealizationArg outputs, const Target &target,
const ParamMap &param_map) {
pipeline().realize(context, std::move(outputs), target, param_map);
}

void Func::infer_input_bounds(Pipeline::RealizationArg outputs, const Target &target,
const ParamMap &param_map) {
pipeline().infer_input_bounds(std::move(outputs), target, param_map);
Expand Down
16 changes: 15 additions & 1 deletion src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,13 @@ class Func {
Realization realize(std::vector<int32_t> sizes = {}, const Target &target = Target(),
const ParamMap &param_map = ParamMap::empty_map());

/** Same as above, but takes a custom user-provided context to be
* passed to runtime functions. */
Realization realize(JITUserContext *context,
std::vector<int32_t> sizes = {},
const Target &target = Target(),
const ParamMap &param_map = ParamMap::empty_map());

/** Evaluate this function into an existing allocated buffer or
* buffers. If the buffer is also one of the arguments to the
* function, strange things may happen, as the pipeline isn't
Expand All @@ -838,6 +845,13 @@ class Func {
void realize(Pipeline::RealizationArg outputs, const Target &target = Target(),
const ParamMap &param_map = ParamMap::empty_map());

/** Same as above, but takes a custom user-provided context to be
* passed to runtime functions. */
void realize(JITUserContext *context,
Pipeline::RealizationArg outputs,
const Target &target = Target(),
const ParamMap &param_map = ParamMap::empty_map());

/** For a given size of output, or a given output buffer,
* determine the bounds required of all unbound ImageParams
* referenced. Communicates the result by allocating new buffers
Expand Down Expand Up @@ -1114,7 +1128,7 @@ class Func {

/** Get a struct containing the currently set custom functions
* used by JIT. */
const Internal::JITHandlers &jit_handlers();
JITHandlers &jit_handlers();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns a mutable ref, so I assume it's ok to just mutate the contents... we should probably explicitly document the rules for doing so. (eg, when do changes I make take effect?)


/** Add a custom pass to be used during lowering. It is run after
* all other lowering passes. Can be used to verify properties of
Expand Down
57 changes: 24 additions & 33 deletions src/JITModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ void JITModule::compile_module(std::unique_ptr<llvm::Module> m, const string &fu
listeners.push_back(llvm::JITEventListener::createIntelJITEventListener());
}
// TODO: If this ever works in LLVM, this would allow profiling of JIT code with symbols with oprofile.
//listeners.push_back(llvm::createOProfileJITEventListener());
// listeners.push_back(llvm::createOProfileJITEventListener());

for (auto &listener : listeners) {
ee->RegisterJITEventListener(listener);
Expand Down Expand Up @@ -508,66 +508,59 @@ void merge_handlers(JITHandlers &base, const JITHandlers &addins) {
}
}

void print_handler(void *context, const char *msg) {
void print_handler(JITUserContext *context, const char *msg) {
if (context) {
JITUserContext *jit_user_context = (JITUserContext *)context;
(*jit_user_context->handlers.custom_print)(context, msg);
(*context->handlers.custom_print)(context, msg);
} else {
return (*active_handlers.custom_print)(context, msg);
}
}

void *malloc_handler(void *context, size_t x) {
void *malloc_handler(JITUserContext *context, size_t x) {
if (context) {
JITUserContext *jit_user_context = (JITUserContext *)context;
return (*jit_user_context->handlers.custom_malloc)(context, x);
return (*context->handlers.custom_malloc)(context, x);
} else {
return (*active_handlers.custom_malloc)(context, x);
}
}

void free_handler(void *context, void *ptr) {
void free_handler(JITUserContext *context, void *ptr) {
if (context) {
JITUserContext *jit_user_context = (JITUserContext *)context;
(*jit_user_context->handlers.custom_free)(context, ptr);
(*context->handlers.custom_free)(context, ptr);
} else {
(*active_handlers.custom_free)(context, ptr);
}
}

int do_task_handler(void *context, halide_task f, int idx,
int do_task_handler(JITUserContext *context, halide_task_t f, int idx,
uint8_t *closure) {
if (context) {
JITUserContext *jit_user_context = (JITUserContext *)context;
return (*jit_user_context->handlers.custom_do_task)(context, f, idx, closure);
return (*context->handlers.custom_do_task)(context, f, idx, closure);
} else {
return (*active_handlers.custom_do_task)(context, f, idx, closure);
}
}

int do_par_for_handler(void *context, halide_task f,
int do_par_for_handler(JITUserContext *context, halide_task_t f,
int min, int size, uint8_t *closure) {
if (context) {
JITUserContext *jit_user_context = (JITUserContext *)context;
return (*jit_user_context->handlers.custom_do_par_for)(context, f, min, size, closure);
return (*context->handlers.custom_do_par_for)(context, f, min, size, closure);
} else {
return (*active_handlers.custom_do_par_for)(context, f, min, size, closure);
}
}

void error_handler_handler(void *context, const char *msg) {
void error_handler_handler(JITUserContext *context, const char *msg) {
if (context) {
JITUserContext *jit_user_context = (JITUserContext *)context;
(*jit_user_context->handlers.custom_error)(context, msg);
(*context->handlers.custom_error)(context, msg);
} else {
(*active_handlers.custom_error)(context, msg);
}
}

int32_t trace_handler(void *context, const halide_trace_event_t *e) {
int32_t trace_handler(JITUserContext *context, const halide_trace_event_t *e) {
if (context) {
JITUserContext *jit_user_context = (JITUserContext *)context;
return (*jit_user_context->handlers.custom_trace)(context, e);
return (*context->handlers.custom_trace)(context, e);
} else {
return (*active_handlers.custom_trace)(context, e);
}
Expand All @@ -581,7 +574,7 @@ void *load_library_handler(const char *name) {
return (*active_handlers.custom_load_library)(name);
}

void *get_library_symbol_handler(void *lib, const char *name) {
void *get_library_symbol_handler(JITUserContext *lib, const char *name) {
return (*active_handlers.custom_get_library_symbol)(lib, name);
}

Expand Down Expand Up @@ -884,16 +877,14 @@ std::vector<JITModule> JITSharedRuntime::get(llvm::Module *for_module, const Tar
return result;
}

// TODO: Either remove user_context argument figure out how to make
// caller provided user context work with JIT. (At present, this
// cascaded handler calls cannot work with the right context as
// JITModule needs its context to be passed in case the called handler
// calls another callback which is not overriden by the caller.)
void JITSharedRuntime::init_jit_user_context(JITUserContext &jit_user_context,
void *user_context, const JITHandlers &handlers) {
jit_user_context.handlers = active_handlers;
jit_user_context.user_context = user_context;
merge_handlers(jit_user_context.handlers, handlers);
void JITSharedRuntime::populate_jit_handlers(JITUserContext *jit_user_context, const JITHandlers &handlers) {
// Take the active global handlers
JITHandlers merged = active_handlers;
// Clobber with any custom handlers set on the pipeline
merge_handlers(merged, handlers);
// Clobber with any custom handlers set on the call
merge_handlers(merged, jit_user_context->handlers);
jit_user_context->handlers = merged;
}

void JITSharedRuntime::release_all() {
Expand Down
50 changes: 29 additions & 21 deletions src/JITModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,34 @@ struct JITExtern;
struct Target;
class Module;

struct JITUserContext;

/** A set of custom overrides of runtime functions */
struct JITHandlers {
void (*custom_print)(JITUserContext *, const char *){nullptr};
void *(*custom_malloc)(JITUserContext *, size_t){nullptr};
void (*custom_free)(JITUserContext *, void *){nullptr};
int (*custom_do_task)(JITUserContext *, halide_task_t, int, uint8_t *){nullptr};
int (*custom_do_par_for)(JITUserContext *, halide_task_t, int, int, uint8_t *){nullptr};
void (*custom_error)(JITUserContext *, const char *){nullptr};
int32_t (*custom_trace)(JITUserContext *, const halide_trace_event_t *){nullptr};
void *(*custom_get_symbol)(const char *name){nullptr};
void *(*custom_load_library)(const char *name){nullptr};
void *(*custom_get_library_symbol)(JITUserContext *lib, const char *name){nullptr};
};

namespace Internal {
struct JITErrorBuffer;
}

/** A context to be passed to Pipeline::realize. Inherit from this to
* pass your own custom context object. Modify the handlers field to
* override runtime functions per-call to realize. */
struct JITUserContext {
Internal::JITErrorBuffer *error_buffer{nullptr};
JITHandlers handlers;
};

namespace Internal {

class JITModuleContents;
Expand Down Expand Up @@ -137,31 +165,11 @@ struct JITModule {
bool compiled() const;
};

typedef int (*halide_task)(void *user_context, int, uint8_t *);

struct JITHandlers {
void (*custom_print)(void *, const char *){nullptr};
void *(*custom_malloc)(void *, size_t){nullptr};
void (*custom_free)(void *, void *){nullptr};
int (*custom_do_task)(void *, halide_task, int, uint8_t *){nullptr};
int (*custom_do_par_for)(void *, halide_task, int, int, uint8_t *){nullptr};
void (*custom_error)(void *, const char *){nullptr};
int32_t (*custom_trace)(void *, const halide_trace_event_t *){nullptr};
void *(*custom_get_symbol)(const char *name){nullptr};
void *(*custom_load_library)(const char *name){nullptr};
void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};
};

struct JITUserContext {
void *user_context;
JITHandlers handlers;
};

class JITSharedRuntime {
public:
// Note only the first llvm::Module passed in here is used. The same shared runtime is used for all JIT.
static std::vector<JITModule> get(llvm::Module *m, const Target &target, bool create = true);
static void init_jit_user_context(JITUserContext &jit_user_context, void *user_context, const JITHandlers &handlers);
static void populate_jit_handlers(JITUserContext *jit_user_context, const JITHandlers &handlers);
static JITHandlers set_default_handlers(const JITHandlers &handlers);

/** Set the maximum number of bytes used by memoization caching.
Expand Down
Loading