[FFI]: Add JEP for FFI #12632
Conversation
@sharadmv — I haven't had a chance to really dig into this in detail, but I wanted to say that from my perspective this would be a huge quality of life improvement!
Thank you for your feedback. Your guide has been and will continue to be a huge benefit to the community!
@dfm we've been referring people to your tutorial for years (thank you!), and as good as it is, one founding objective of this JEP was to make such tutorials unnecessary (i.e. to make your life much easier). 😁
Unlike in the dfm guide, users are not constructing JAX primitives and therefore don’t have the opportunity to register transformation rules for those primitives. Do we want to expose them and if so, how?
For automatic differentiation, users have the option of wrapping their `FFICall` with a `jax.custom_jvp` or `jax.custom_vjp`. Alternatively we could expose additional methods on `FFICall` that do something similar. The `jax.custom_*` (`custom_vmap`, `custom_transpose`, etc.) API, in principle, could also handle any custom behavior users want from FFI calls. However, this API has not been fully built out yet.
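For illustration, here is a minimal sketch of what wrapping an FFI call with `jax.custom_vjp` could look like. `ffi_mul` is a hypothetical stand-in for whatever callable the proposed `FFICall` API would produce; it is simulated with plain `jnp` code here so the snippet runs as-is.

```python
import jax
import jax.numpy as jnp


def ffi_mul(x, y):
    # Placeholder for a foreign call; imagine this dispatching to an
    # XLA custom call registered through the proposed FFI API.
    return jnp.multiply(x, y)


@jax.custom_vjp
def mul(x, y):
    return ffi_mul(x, y)


def mul_fwd(x, y):
    # Save the inputs as residuals for the backward pass.
    return ffi_mul(x, y), (x, y)


def mul_bwd(res, g):
    # The backward pass could itself be another FFI call.
    x, y = res
    return (g * y, g * x)


mul.defvjp(mul_fwd, mul_bwd)

print(jax.grad(lambda x: mul(x, 3.0))(2.0))  # 3.0
```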
What about multi-GPU and sharding?
I think this case should be supported, as we know this is a current issue.
This is an important use case, but it can arguably be solved orthogonally via the WIP custom partitioning API. We're still working on its design, but presumably we can surface a `jax.custom_sharding` or something like it that will work with FFI calls.
That would be great.
On CPU (for now), the `function_ptr` and `user_static_info` are passed by reference as the first argument to the custom call. On GPU, they can be passed by reference via the opaque string.
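As a rough illustration of the GPU path, the sketch below packs a function pointer and a small `user_static_info` payload into an opaque byte string. The field layout is purely illustrative (an assumption, not the encoding this JEP specifies); the real FFI implementation would define the format and unpack it again on the C/C++ side of the custom call.

```python
import struct


def make_opaque(function_ptr: int, user_static_info: bytes) -> bytes:
    # Illustrative layout: 8-byte little-endian pointer, 8-byte payload
    # length, then the raw user_static_info bytes.
    return struct.pack("<QQ", function_ptr, len(user_static_info)) + user_static_info


# Fake pointer value, just to show the round trip.
opaque = make_opaque(0x7F0000001234, b"\x01\x02\x03")
ptr, n = struct.unpack_from("<QQ", opaque)
assert ptr == 0x7F0000001234
assert opaque[16:16 + n] == b"\x01\x02\x03"
```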
Note that XLA custom calls support custom layouts for operands and results. Here we’ll generate MHLO that uses default layouts, which technically limits what users can express.
Why don't you try to cover this? Or is this planned for a future version?
Our thinking is:
- We can't support this easily without either implementing our own wrapper around XLA's layout API or exposing details of XLA itself
- Users can work around this by either changing layouts manually or forgoing the FFI API and using the Custom Call API directly
In the long run, we might want to support custom layouts; that can come in a follow-up version.
On the other hand, we could just provide an API to specify custom major-to-minor orders in Python (like [0, 1] or [1, 0]) and that would probably be sufficient.
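As background (this is NumPy terminology, not part of any existing JAX API), here is what those major-to-minor orders mean in terms of memory layout, shown via strides:

```python
import numpy as np

a = np.arange(6, dtype=np.float32).reshape(2, 3)  # C order
b = np.asfortranarray(a)                          # Fortran order

# In a C-ordered array, dimension 0 varies slowest ("major") and
# dimension 1 varies fastest ("minor"): major-to-minor order [0, 1].
print(a.strides)  # (12, 4)

# In a Fortran-ordered array the roles flip: major-to-minor order [1, 0].
print(b.strides)  # (4, 8)
```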
I think so. Asking users not to use the FFI just for this seems like a very high cost, so allowing it would be great.
Peter pointed out that jaxlib uses custom layouts frequently, so we should support this. Thankfully, this only needs to exist in Python and doesn't require changing the C API.
Co-authored-by: Kuangyuan Chen <chky@google.com>
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
If I wanted to start integrating a custom op (such as those from cuSolver, Magma, etc.) into JAX now (late 2023) and define autodiff rules for it, should I wait for this PR to land?
The most up-to-date doc for JAX custom operations on GPU is:
Now that #21925 is merged, I think we can close this: we have an FFI now!
Proposes a new API for calling out to foreign functions that wraps XLA's CustomCall API and JAX's primitive API. You can check out a prototype implementation of this JEP in #12396.
You can read the HTML version of this proposal here: https://jax--12632.org.readthedocs.build/en/12632/jep/12535-ffi.html.
cc: @dfm
Tracker: #12535