-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: specify jit static args via Static annotation #24705
base: main
Are you sure you want to change the base?
Conversation
How should this behave in the absence of an immediate @jax.jit
def f(x, square):
return g(x, square)
def g(x: jax.Array, square: Static[bool]):
return x ** 2 if square else x
print(f(2, True))
print(f(2, False)) |
My vision would be for the |
This looks great to me! It's definitely a step-up in readability from the way we currently annotate static arguments. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like it!
jax/_src/api_util.py
Outdated
def static_argnames_from_annotations(fun: Callable[..., Any]) -> tuple[str, ...]: | ||
try: | ||
hints = get_type_hints(fun, include_extras=True) | ||
except (TypeError, ValueError, NameError): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if we should suppress NameError
. IIUC this can only happen if a forward reference is local or if a type annotation references an undefined name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I got a NameError
when running the JAX test suite and this was my fix.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm just concerned that we would ignore Static
annotations if we handle NameError
unconditionally. Would it be a hack to check if __annotations__
doesn't contain Static[
as a substring on NameError
and re-raise if that's the case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I'll give that a try.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just pushed a commit that tries to do smarter error handling, with detailed comments. It's a bit messy to be honest... let me know what you think.
An idea inspired by chats with @superbobry.
Example:
The current way to define this would be