-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ability to pass a user context in JIT mode #6313
Changes from 14 commits
0897d48
ea2dc06
ad7b47a
76cde70
0258ce8
0d5d63b
097f74a
9149827
59b7b8a
572c522
59561c6
0f3189c
5471942
aeb0c17
0b14ec0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -829,6 +829,14 @@ class Func { | |
Realization realize(std::vector<int32_t> sizes = {}, const Target &target = Target(), | ||
const ParamMap ¶m_map = ParamMap::empty_map()); | ||
|
||
/** Same as above, but takes a custom user-provided context to be | ||
* passed to runtime functions. This can be used to pass state to | ||
* runtime overrides in a thread-safe manner. */ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it legal to pass |
||
Realization realize(JITUserContext *context, | ||
std::vector<int32_t> sizes = {}, | ||
const Target &target = Target(), | ||
const ParamMap ¶m_map = ParamMap::empty_map()); | ||
|
||
/** Evaluate this function into an existing allocated buffer or | ||
* buffers. If the buffer is also one of the arguments to the | ||
* function, strange things may happen, as the pipeline isn't | ||
|
@@ -838,6 +846,14 @@ class Func { | |
void realize(Pipeline::RealizationArg outputs, const Target &target = Target(), | ||
const ParamMap ¶m_map = ParamMap::empty_map()); | ||
|
||
/** Same as above, but takes a custom user-provided context to be | ||
* passed to runtime functions. This can be used to pass state to | ||
* runtime overrides in a thread-safe manner. */ | ||
void realize(JITUserContext *context, | ||
Pipeline::RealizationArg outputs, | ||
const Target &target = Target(), | ||
const ParamMap ¶m_map = ParamMap::empty_map()); | ||
|
||
/** For a given size of output, or a given output buffer, | ||
* determine the bounds required of all unbound ImageParams | ||
* referenced. Communicates the result by allocating new buffers | ||
|
@@ -870,6 +886,18 @@ class Func { | |
const ParamMap ¶m_map = ParamMap::empty_map()); | ||
// @} | ||
|
||
/** Versions of infer_input_bounds that take a custom user context | ||
* to pass to runtime functions. */ | ||
// @{ | ||
void infer_input_bounds(JITUserContext *context, | ||
const std::vector<int32_t> &sizes, | ||
const Target &target = get_jit_target_from_environment(), | ||
const ParamMap ¶m_map = ParamMap::empty_map()); | ||
void infer_input_bounds(JITUserContext *context, | ||
Pipeline::RealizationArg outputs, | ||
const Target &target = get_jit_target_from_environment(), | ||
const ParamMap ¶m_map = ParamMap::empty_map()); | ||
// @} | ||
/** Statically compile this function to llvm bitcode, with the | ||
* given filename (which should probably end in .bc), type | ||
* signature, and C function name (which defaults to the same name | ||
|
@@ -1114,7 +1142,7 @@ class Func { | |
|
||
/** Get a struct containing the currently set custom functions | ||
* used by JIT. */ | ||
const Internal::JITHandlers &jit_handlers(); | ||
JITHandlers &jit_handlers(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This returns a mutable ref, so I assume it's ok to just mutate the contents... we should probably explicitly document the rules for doing so. (eg, when do changes I make take effect?) |
||
|
||
/** Add a custom pass to be used during lowering. It is run after | ||
* all other lowering passes. Can be used to verify properties of | ||
|
@@ -2585,28 +2613,40 @@ inline void assign_results(Realization &r, int idx, First first, Second second, | |
* expression. This can be thought of as a scalar version of | ||
* \ref Func::realize */ | ||
template<typename T> | ||
HALIDE_NO_USER_CODE_INLINE T evaluate(const Expr &e) { | ||
HALIDE_NO_USER_CODE_INLINE T evaluate(JITUserContext *ctx, const Expr &e) { | ||
user_assert(e.type() == type_of<T>()) | ||
<< "Can't evaluate expression " | ||
<< e << " of type " << e.type() | ||
<< " as a scalar of type " << type_of<T>() << "\n"; | ||
Func f; | ||
f() = e; | ||
Buffer<T> im = f.realize(); | ||
Buffer<T> im = f.realize(ctx); | ||
return im(); | ||
} | ||
|
||
/** evaluate with a default user context */ | ||
template<typename T> | ||
HALIDE_NO_USER_CODE_INLINE T evaluate(const Expr &e) { | ||
return evaluate<T>(nullptr, e); | ||
} | ||
|
||
/** JIT-compile and run enough code to evaluate a Halide Tuple. */ | ||
template<typename First, typename... Rest> | ||
HALIDE_NO_USER_CODE_INLINE void evaluate(Tuple t, First first, Rest &&...rest) { | ||
HALIDE_NO_USER_CODE_INLINE void evaluate(JITUserContext *ctx, Tuple t, First first, Rest &&...rest) { | ||
Internal::check_types<First, Rest...>(t, 0); | ||
|
||
Func f; | ||
f() = t; | ||
Realization r = f.realize(); | ||
Realization r = f.realize(ctx); | ||
Internal::assign_results(r, 0, first, rest...); | ||
} | ||
|
||
/** JIT-compile and run enough code to evaluate a Halide Tuple. */ | ||
template<typename First, typename... Rest> | ||
HALIDE_NO_USER_CODE_INLINE void evaluate(Tuple t, First first, Rest &&...rest) { | ||
evaluate<First, Rest...>(nullptr, std::move(t), std::forward<First>(first), std::forward<Rest...>(rest...)); | ||
} | ||
|
||
namespace Internal { | ||
|
||
inline void schedule_scalar(Func f) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Idly wondering if we could someday roll
param_map
into the JITUserContext...)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ParamMap holds the arguments to the realize call and thus is a completely separate concept from the user context. Arguably the user context could be passed in the ParamMap, but that doesn't seem an improvement to me. Normally Params are retrieved from global variables, but that is not thread safe. It is a bit of a silly design in the first place, but it is really convenient for the just hacking up Halide code so it is totally a thing... In order to pass an arbitrary set of arguments through realize to a JITted call, one has to use some sort of keyed dynamic data structure. That is what ParamMap is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, the runtime design I'm looking at does make the user context part of the arguments. But until the design doc is done, there's not much point in discussing it. If it goes through, it will reverse this change entirely. (The goal is to make the Halide compiler able to pass through a flexible contract from the outside caller to the runtime called from Halide generated code. This is done by possibly adding arguments to the runtime calls at codegen time.)