Generic support for JIT compilation with custom NumPy ops #766
Comments
To clarify, the use case here is avoiding the need to write my own low-level wrappers for compiled code, so that I can reuse wrappers from a library such as SciPy. In many cases SciPy's wrappers span hundreds of lines of code, e.g., for ARPACK. I certainly could write a Cython XLA wrapper for the particular routines that I need, but that's a lot of boilerplate -- C/Fortran interfaces tend to be much less clean than the wrappers SciPy exposes.

---
Any update on this?

---
The work-around for now is to not compile your entire model with `jit`.

As for the particular approach, the main challenge would be piping array shape/dtype information and Python functions into XLA's interface, and then reconstructing them on the other side. For keeping track of custom functions in particular, it would probably make sense to store Python functions in a global (weak reference?) dictionary and only pass around their keys.
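As an illustration of that bookkeeping, here is a minimal sketch only -- `register` and `lookup` are hypothetical helper names, not JAX API:

```python
import weakref

# Hypothetical global registry: integer key -> Python callback.
# Weak references let callbacks be garbage-collected once nothing
# else refers to them.
_callback_registry = weakref.WeakValueDictionary()

def register(fn):
    # id(fn) gives a cheap, hashable token that can be threaded
    # through XLA as opaque metadata instead of the function itself.
    key = id(fn)
    _callback_registry[key] = fn
    return key

def lookup(key):
    # Recover the original Python function on the host side.
    return _callback_registry[key]
```

The compiled computation would then carry only `key`, and the runtime would call `lookup(key)` on the actual arrays when control returns to Python.

---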
Bump on this. We would like to be able to call Python code from inside jitted functions. So something like:

```python
def f(a: np.ndarray) -> np.ndarray:
    res = ...  # some non-JAX math
    return res

typed_f = jax.give_signature(
    function=f,
    input=Shape((1,), jnp.float32),
    output=Shape((1,), jnp.float32))

@jax.jit
def g(a):
    return typed_f(a)  # should work, assuming our earlier code worked
```
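For reference, later JAX releases expose essentially this pattern as `jax.pure_callback`, with `jax.ShapeDtypeStruct` playing the role of the proposed `Shape`. A minimal sketch, with the callback body as an arbitrary stand-in:

```python
import numpy as np
import jax
import jax.numpy as jnp

def f(a):
    # Any non-JAX math; runs on the host with a concrete NumPy array.
    return np.asarray(a) ** 2

@jax.jit
def g(a):
    # Declare the output shape/dtype so tracing can proceed
    # without actually running f.
    out = jax.ShapeDtypeStruct(a.shape, a.dtype)
    return jax.pure_callback(f, out, a)

print(g(jnp.ones((1,), jnp.float32)))  # [1.]
```

---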
I think we can consider this fixed by the new (experimental) version of `jax.experimental.host_callback`:

```python
import jax.experimental.host_callback as hcb
import jax.numpy as jnp
import jax

def myprint(x):
  print('inside myprint:', type(x), x)
  return x

@jax.jit
def device_fun(x):
  return hcb.call(myprint, x, result_shape=x)

device_fun(jnp.arange(10))
```

Prints:

```
inside myprint: <class 'numpy.ndarray'> [0 1 2 3 4 5 6 7 8 9]
```

Please give this a try and file new issues, CCing @gnecula, if you encounter any problems!
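If the result's shape differs from the input's (closer to the SciPy wrapper use case above), `result_shape` can be described with a `jax.ShapeDtypeStruct`. A minimal sketch, using `np.linalg.eigvalsh` as a stand-in for an arbitrary compiled host routine:

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.experimental.host_callback as hcb

def host_eigvals(x):
    # Runs on the host with a concrete NumPy array, so any
    # non-JAX code (NumPy, SciPy, ...) works here.
    return np.linalg.eigvalsh(x)

@jax.jit
def device_eigvals(x):
    # The output is 1-D while the input is 2-D, so describe it
    # explicitly rather than reusing x as the result shape.
    out = jax.ShapeDtypeStruct((x.shape[0],), x.dtype)
    return hcb.call(host_eigvals, x, result_shape=out)

print(device_eigvals(jnp.eye(3)))  # [1. 1. 1.]
```

---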
It would be great to be able to `jit` functions that make use of custom CPU operations, i.e., implemented with NumPy arrays. This would be a really valuable extension point for integrating JAX with existing codes/algorithms, and would possibly solve the final remaining use-cases for autograd.

Right now, you can use custom CPU operations if you don't `jit`, but that adds a very large amount of dispatch overhead.

My understanding is this could be possible by making use of XLA's CustomCall support.
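To make the failure mode concrete, here is a minimal illustration (the op itself is an arbitrary stand-in): a NumPy-based function works eagerly, but under `jit` it receives an abstract tracer that NumPy cannot turn into a concrete array.

```python
import numpy as np
import jax
import jax.numpy as jnp

def custom_op(x):
    # Arbitrary non-JAX code that needs a concrete NumPy array.
    return np.asarray(x) ** 2

# Works eagerly, with per-call dispatch overhead.
print(custom_op(jnp.arange(3.0)))  # [0. 1. 4.]

# Fails under jit: tracing passes an abstract tracer, which
# np.asarray cannot convert to a concrete array.
try:
    jax.jit(custom_op)(jnp.arange(3.0))
except Exception as e:
    print(type(e).__name__)
```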