Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to pass a user context in JIT mode #6313

Merged
merged 15 commits into from
Oct 20, 2021

Conversation

abadams
Copy link
Member

@abadams abadams commented Oct 12, 2021

JIT mode has long hijacked user_context for its own purposes. It uses it to store overrides of runtime functions, and it uses it to store an error buffer for accumulating error messages. This means user_context isn't available in JIT mode, which leaves no good way to pass state to custom overrides in a thread-safe manner.

This PR promotes the existing JITUserContext struct to the public namespace, and lets you pass one per call to realize, instead of hiding it inside the realize implementation. By inheriting from this struct and passing a subclass in as the context pointer, you can pass additional state to your runtime overrides! This is demonstrated in the new test.

Runtime overrides in JIT mode will now be expected to take a JITUserContext * as the first arg instead of a void *. The existing set_custom_foo methods that expect function pointers with void * first args are left in place, but I think we should deprecate them and add new ones that expect a JITUserContext *.

I didn't make any of the base class (JITUserContext) state private or anything, so as a side-effect, this means that handlers or error buffers can be overridden per call to realize now in a thread-safe manner.

@steven-johnson steven-johnson added the release_notes For changes that may warrant a note in README for official releases. label Oct 12, 2021
Copy link
Contributor

@steven-johnson steven-johnson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about evaluate(), compile_jit(), and anything else that does jitting implicitly?

this means that handlers or error buffers can be overridden per call to realize now in a thread-safe manner

This should be explicitly documented.

This PR almost certainly will need testing for WebAssembly, since it's a weird JIT case.

On the whole I don't see anything wrong with this, and I should be happy to expand the ability of JIT in this way, but part of me feels like if we're going to be attacking the limitations of the JIT runtime situation, it would be nice to find a way to re-hook everything in a generic way, but that's definitely scope creep...

(Withholding approval pending green buildbots)

@@ -349,6 +349,8 @@ class Pipeline {
*/
void compile_jit(const Target &target = get_jit_target_from_environment());

// TODO: deprecate all of these and replace with versions that take a JITUserContext
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to land this PR, we absolutely deprecate these as part of it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about making new ones that accept any function that takes a T * as the first arg, where T * must be implicitly convertible to a JITUserContext *? Then you could have signatures that accept a derived class.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need a template for that. Given struct Derived : public JITUserContext, passing a Derived* should be acceptable anywhere a JITUserContext* is.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, but I don't want to pass a Derived, I want to pass a function pointer that takes a Derived * as the first argument.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right. Well... meh. Adding more template usage to our public API is something I'm not wild about -- what would it look like?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm honestly not too wild about the idea now, because it would only be correct if the user passes a context object of the matching type to the next realize call. If they don't it's an implicit downcast to the wrong type. Let's just add versions that take function pointers that take JITUserContext *

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Builds are green. I'm inclined to deprecate these as a follow-up PR, because that will involve a lot of code tweaking in tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm inclined to deprecate these as a follow-up PR

SGTM, re-reviewing now

@abadams
Copy link
Member Author

abadams commented Oct 12, 2021

compile_jit doesn't need to know, because this just changes the handlers struct that gets used at runtime, but evaluate and infer_input_bounds conceivably do. If there's agreement that this approach is sound, I'll do that.

I agree it would be good to make hooking generic instead of there being a blessed set of overridable functions, but I wanted to defer that to a later PR because the design is a little tricky. This PR is just exposing an existing thing.

@mzient
Copy link

mzient commented Oct 13, 2021

Now, if we customize the JITUserContext and provide custom handlers for things like obtaining a custom device context/stream/..., how do we do the same in Buffer::copy_to_device and similar functions?

@abadams
Copy link
Member Author

abadams commented Oct 13, 2021

Buffer::copy_to_device and similar all take a user_context argument as the first arg. Passing something derived from JITUserContext should work. I'll write a quick test.

@abadams
Copy link
Member Author

abadams commented Oct 13, 2021

Yeah, that seems to work fine.

Copy link
Contributor

@steven-johnson steven-johnson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with nits regarding documentation.

(If the followup deprecation work isn't going to happen right away, we should create a tracking issue so it isn't overlooked.)

Realization Func::realize(JITUserContext *context,
std::vector<int32_t> sizes,
const Target &target,
const ParamMap &param_map) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Idly wondering if we could someday roll param_map into the JITUserContext...)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ParamMap holds the arguments to the realize call and thus is a completely separate concept from the user context. Arguably the user context could be passed in the ParamMap, but that doesn't seem an improvement to me. Normally Params are retrieved from global variables, but that is not thread safe. It is a bit of a silly design in the first place, but it is really convenient for the just hacking up Halide code so it is totally a thing... In order to pass an arbitrary set of arguments through realize to a JITted call, one has to use some sort of keyed dynamic data structure. That is what ParamMap is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the runtime design I'm looking at does make the user context part of the arguments. But until the design doc is done, there's not much point in discussing it. If it goes through, it will reverse this change entirely. (The goal is to make the Halide compiler able to pass through a flexible contract from the outside caller to the runtime called from Halide generated code. This is done by possibly adding arguments to the runtime calls at codegen time.)

@@ -1114,7 +1142,7 @@ class Func {

/** Get a struct containing the currently set custom functions
* used by JIT. */
const Internal::JITHandlers &jit_handlers();
JITHandlers &jit_handlers();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns a mutable ref, so I assume it's ok to just mutate the contents... we should probably explicitly document the rules for doing so. (eg, when do changes I make take effect?)

src/Func.h Outdated
@@ -829,6 +829,14 @@ class Func {
Realization realize(std::vector<int32_t> sizes = {}, const Target &target = Target(),
const ParamMap &param_map = ParamMap::empty_map());

/** Same as above, but takes a custom user-provided context to be
* passed to runtime functions. This can be used to pass state to
* runtime overrides in a thread-safe manner. */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it legal to pass nullptr for context? Should be documented.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
release_notes For changes that may warrant a note in README for official releases.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants