Generic support for JIT compilation with custom NumPy ops #766

Closed
shoyer opened this issue May 23, 2019 · 5 comments
Labels
enhancement New feature or request

Comments

@shoyer
Collaborator

shoyer commented May 23, 2019

It would be great to be able to jit functions that make use of custom CPU operations, i.e., operations implemented with NumPy arrays. This would be a really valuable extension point for integrating JAX with existing code and algorithms, and might cover the last remaining use cases for autograd.

Right now, you can use custom CPU operations if you don't jit, but that adds a very large amount of dispatch overhead.

My understanding is this could be possible by making use of XLA's CustomCall support.

@shoyer shoyer changed the title jit with custom CPU ops JIT compilation with custom CPU ops May 23, 2019
@shoyer shoyer changed the title JIT compilation with custom CPU ops JIT compilation with custom NumPy ops May 24, 2019
@shoyer
Collaborator Author

shoyer commented May 24, 2019

To clarify, the use case here is avoiding the need to write my own low-level wrappers for compiled code, e.g., by reusing a wrapper from a library such as SciPy. In many cases SciPy's wrappers span hundreds of lines of code (ARPACK, for example). I certainly could write a Cython XLA wrapper for the particular routines that I need, but that's a lot of boilerplate: C/Fortran interfaces tend to be much less clean than the wrappers SciPy exposes.
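As a rough illustration of this use case, here is how SciPy's ARPACK wrapper (`scipy.sparse.linalg.eigsh`) could be reused under `jit` via `jax.pure_callback`, a callback API that JAX added after this comment was written. This is a sketch assuming a recent JAX release; `top_eigenvalue` and `top_eigenvalue_host` are illustrative names. The callback always runs on the host with data copied to and from the device, and JAX cannot differentiate through it.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.sparse.linalg import eigsh


def top_eigenvalue_host(a):
    # Runs on the host with a NumPy array; SciPy dispatches to ARPACK.
    vals = eigsh(np.asarray(a, dtype=np.float64), k=1, which="LA",
                 return_eigenvectors=False)
    # The returned array must match the declared shape and dtype exactly.
    return np.asarray(vals, dtype=np.float32)


@jax.jit
def top_eigenvalue(a):
    # Declare the output signature so XLA can compile around the callback.
    spec = jax.ShapeDtypeStruct((1,), np.float32)
    return jax.pure_callback(top_eigenvalue_host, spec, a)[0]


a = jnp.diag(jnp.arange(1.0, 11.0))  # symmetric 10x10 matrix
print(top_eigenvalue(a))
```

The SciPy wrapper is used unchanged; the only JAX-specific piece is the declared output shape and dtype.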

@hawkinsp hawkinsp added the enhancement New feature or request label Jun 25, 2019
@jonasrauber
Contributor

Any update on this?
Is there maybe a workaround?

@shoyer
Collaborator Author

shoyer commented Dec 3, 2019

The work-around for now is to not compile your entire model with jit. That works fine in many cases, but isn't entirely satisfactory.
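A minimal sketch of that workaround: jit the JAX-traceable pieces separately and call the NumPy code from plain Python between them. `preprocess`, `numpy_op` and `postprocess` are hypothetical names, and `np.sort` is only a stand-in for an arbitrary untraceable routine. The cost is a device-to-host round trip per call, and `jax.grad` cannot differentiate through `model` because the NumPy step converts a traced value to a concrete array.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def preprocess(x):
    return jnp.tanh(x) * 2.0


def numpy_op(x):
    # Stand-in for an arbitrary NumPy/SciPy routine that JAX cannot trace.
    return np.sort(x)


@jax.jit
def postprocess(y):
    return jnp.sum(y ** 2)


def model(x):
    # The model as a whole is NOT jitted; only the traceable pieces are.
    h = preprocess(x)
    h = numpy_op(np.asarray(h))  # device -> host
    return postprocess(jnp.asarray(h))  # host -> device


print(model(jnp.linspace(-1.0, 1.0, 5)))
```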

As for the particular approach, lapack.pyx is probably a good model here, but this could also be written in pure C++:
https://github.com/google/jax/blob/5b6c9325ed47b29d9182b0480206ba15b5787500/jaxlib/lapack.pyx

The main challenge would be piping array shape/dtype information and Python functions into XLA's interface, and then reconstructing them on the other side. For keeping track of custom functions in particular, it would probably make sense to store Python functions in a global (weak reference?) dictionary and only pass around their id() into the XLA CustomCall.

@shoyer shoyer changed the title JIT compilation with custom NumPy ops Generic support for JIT compilation with custom NumPy ops Aug 4, 2020
@chaserileyroberts
Contributor

Bump on this.

We would like to be able to call Python code the way tf.py_func works. I understand that type/shape inference is one of the blockers here, but some kind of "manual" typing would work for us.

So something like:

def f(a: np.ndarray) -> np.ndarray:
  res = ...  # some non-JAX math
  return res

# Proposed API: jax.give_signature and Shape are hypothetical, not part of JAX
typed_f = jax.give_signature(
  function=f,
  input=Shape((1,), jnp.float32),
  output=Shape((1,), jnp.float32))

@jax.jit
def g(a):
  return typed_f(a)  # Should work, assuming the typed wrapper above works
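Something close to this proposal can be written with `jax.pure_callback` and `jax.ShapeDtypeStruct` in later JAX releases. The sketch below assumes such a release; `with_signature` is a hypothetical helper mirroring the proposed `jax.give_signature`, not a JAX API. Only the output shape and dtype need to be declared, since input shapes come from the traced arguments.

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(a: np.ndarray) -> np.ndarray:
    # Some non-JAX math, done with NumPy on the host.
    return np.sin(a) ** 2


def with_signature(fn, out_shape, out_dtype):
    """Hypothetical helper: make a NumPy function callable inside jax.jit."""
    result_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)

    def host_fn(a):
        # Cast so the returned array matches the declared dtype exactly.
        return np.asarray(fn(np.asarray(a)), dtype=out_dtype)

    def wrapped(a):
        return jax.pure_callback(host_fn, result_spec, a)

    return wrapped


typed_f = with_signature(f, (3,), np.float32)


@jax.jit
def g(a):
    return typed_f(a) + 1.0


x = jnp.array([0.0, 1.0, 2.0], dtype=jnp.float32)
print(g(x))
```

Declaring the output explicitly is what lets XLA allocate the result buffer without running the Python function at trace time.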

@shoyer
Collaborator Author

shoyer commented Jan 31, 2021

I think we can consider this fixed by the new (experimental) version of host_callback.call, e.g.,

import jax.experimental.host_callback as hcb
import jax.numpy as jnp
import jax

def myprint(x):
  print('inside myprint:', type(x), x)
  return x

@jax.jit
def device_fun(x):
  return hcb.call(myprint, x, result_shape=x)

device_fun(jnp.arange(10))

Prints: inside myprint: <class 'numpy.ndarray'> [0 1 2 3 4 5 6 7 8 9]

Please give this a try, and if you run into problems, file new issues and CC @gnecula!
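In later JAX releases, `jax.experimental.host_callback` was deprecated in favor of `jax.pure_callback`, `jax.experimental.io_callback` and `jax.debug.callback`. Below is a sketch of the example above ported to `jax.pure_callback`, assuming a recent JAX version; its `result_shape_dtypes` argument plays the role of `result_shape`. For callbacks that only print and return nothing, `jax.debug.callback` is the more natural fit, since JAX may drop a pure callback whose result is unused.

```python
import jax
import jax.numpy as jnp
import numpy as np


def myprint(x):
    print('inside myprint:', type(x), x)
    return x


@jax.jit
def device_fun(x):
    # Declare the result's shape and dtype (here identical to the input's).
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(myprint, spec, x)


out = device_fun(jnp.arange(10))
```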


4 participants