Pure callback with jax types hangs with jax > 0.4.31 #24255
OK, I am currently working around this via:

```python
def _eig_jax(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Eigendecomposition using `jax.numpy.linalg.eig`."""
    if jax.devices()[0] == jax.devices("cpu")[0]:
        return jnp.linalg.eig(matrix)
    else:
        eigvals, eigvecs = jax.pure_callback(
            _eig_jax_cpu,
            (
                jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
                jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
            ),
            matrix.astype(complex),
            vectorized=True,
        )
        return jnp.asarray(eigvals), jnp.asarray(eigvecs)
```

This is fine for me, but it may be nice to have a graceful failure mode other than simply hanging.
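The helper `_eig_jax_cpu` is not defined in the snippet above. A minimal sketch of what such a helper might look like, assuming it is simply `jnp.linalg.eig` pinned to the host CPU device (my assumption, not the reporter's actual code):

```python
import jax
import jax.numpy as jnp


def _eig_jax_cpu(matrix):
    # Hypothetical helper (not from the original report): run the JAX eig
    # implementation on the CPU device so it can serve as the callback target.
    with jax.default_device(jax.devices("cpu")[0]):
        return jnp.linalg.eig(matrix)
```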
Thanks for reporting this! I'm not totally surprised that things don't go so well if we nest JAX functions inside of a `pure_callback`. I'll take a look into this because I have some ideas about what might be causing it. But, regardless, I would recommend avoiding the use of JAX functions within the callback, and instead using something like:

```python
def _eig_jax(matrix):
    matrix = matrix.astype(complex)
    return jax.pure_callback(
        np.linalg.eig,  # <-- np instead of jnp
        ...  # The rest of the arguments are the same as before
    )
```

This works on both GPU and CPU without relying on querying the first device. I expect the performance will be equivalent, without causing any hangs. In the long run, I'm not sure we'll want to support including JAX code in the callback executed by `pure_callback`.
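Filling in the elided arguments, a complete version of this suggestion might look like the following sketch; using `jax.ShapeDtypeStruct` result templates in place of the `jnp.ones` placeholders from the earlier workaround, and keeping `vectorized=True`, are my assumptions rather than part of the original comment:

```python
import jax
import jax.numpy as jnp
import numpy as np


def _eig_jax(matrix: jnp.ndarray):
    matrix = matrix.astype(complex)
    # Describe the output shapes/dtypes: eigenvalues drop the last axis,
    # eigenvectors have the same shape as the input matrix.
    eigvals = jax.ShapeDtypeStruct(matrix.shape[:-1], matrix.dtype)
    eigvecs = jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)
    return jax.pure_callback(
        np.linalg.eig,  # plain NumPy, so no JAX code runs inside the callback
        (eigvals, eigvecs),
        matrix,
        vectorized=True,
    )
```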
Thanks for your suggestion, and I'll look forward to the new API. Unfortunately, I have found performance to be quite different (also when using `np.linalg.eig` in the callback, as suggested).
Interesting! Can you say more about what performance differences you're seeing? I believe that the JAX CPU implementation of `eig` also just calls into LAPACK, so I wouldn't expect a large difference. Sample code, run on CPU:

```python
@jax.jit
def eig_jax(x):
    x = x.astype(np.complex64)
    return jnp.linalg.eig(x)


@jax.jit
def eig_scipy(x):
    x = x.astype(np.complex64)
    eigvals = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
    # `x` itself serves as the shape/dtype template for the eigenvectors.
    return jax.pure_callback(scipy.linalg.eig, (eigvals, x), x)
```

Regardless, it will be great when we have native GPU support for `eig`.
Sure, here is some benchmarking code that I ran on Colab CPU. I avoided wrapping the calls in `jit` or `pure_callback`, timing the underlying functions directly:

```python
import jax
import jax.numpy as jnp
import numpy as onp
import scipy.linalg
import time

onp.random.seed(0)
matrix = onp.random.randn(500, 500).astype(onp.float32)

times = {"scipy": onp.inf, "numpy": onp.inf, "jax": onp.inf}
for _ in range(20):
    start = time.time()
    onp.linalg.eig(matrix)
    times["numpy"] = min(time.time() - start, times["numpy"])

    start = time.time()
    scipy.linalg.eig(matrix)
    times["scipy"] = min(time.time() - start, times["scipy"])

    start = time.time()
    jax.block_until_ready(jnp.linalg.eig(matrix))
    times["jax"] = min(time.time() - start, times["jax"])

print(times)
# {'scipy': 0.2719097137451172, 'numpy': 0.41986513137817383, 'jax': 0.227125883102417}
```

I also observe speed differences on my local machine, but I figure that the odds of operator error on the setup are lower with Colab. :-)
I am thinking this is possibly related to some issues we are having with the new CPU client's async dispatch. One issue is that, when the client chooses async, the initial call returns immediately (in the main thread) and the callback happens in a separate/new thread created by the client. If the function does not return anything, or its return value is not used in some way that would block the main thread (`print`, conversion to an array, `block_until_ready`, etc.), nothing stops the main thread from running ahead to the end and exiting. The shutdown behavior described at https://docs.python.org/3/library/threading.html is relevant here; the docs also mention some related caveats, though it's not clear whether they apply.
Using the threading module in the callback, e.g.

```python
thread = threading.current_thread()
print(f"{type(thread)}: {thread!r}")
```

will report which thread the callback body is actually running on.
The expected behavior, which seems to match what happens, is that when the main thread reaches the end it will start tearing down and call everything registered with `atexit`. If this is not what is happening in this issue, I can start a new one to document this information.
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately). This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.) We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_experimental_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_MAGMA_LIB_PATH` environment variable. For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now. PiperOrigin-RevId: 691072237
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately). This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.) We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable. For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now. PiperOrigin-RevId: 691072237
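Going by the flag and environment-variable names in the commit description above (taken from that description only; the names may differ in released JAX versions), opting in to the MAGMA-backed path might look like:

```python
import os

# Hypothetical usage based on the commit message above; the names are not
# verified against a released JAX version.
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"  # non-standard install location

import jax

jax.config.update("jax_gpu_eig_magma", "on")  # enable MAGMA-backed eig; LAPACK on host is the default
```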
Description
I am experiencing hangs with jax versions newer than 0.4.31, as referenced in an earlier issue that I created and subsequently closed (#24219). I managed to simplify the reproduction.
The issue seems to be related to JAX calculations within a function called by `pure_callback`. The code below reproduces the issue. The method of wrapping `jnp.linalg.eig` is one that has been successful for jax 0.4.31 and earlier, and has been brought up in discussions several times (#23079, #1259).
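The original reproduction snippet is not preserved in this extract; the following is a minimal sketch of the reported pattern (a JAX function such as `jnp.linalg.eig` called inside the body of `jax.pure_callback`), reconstructed from the workaround shown earlier in the thread rather than copied from the report:

```python
import jax
import jax.numpy as jnp
import numpy as np


def _eig_host(matrix):
    # Calling a JAX function inside the callback is what appears to trigger
    # the hang on jax > 0.4.31.
    return jnp.linalg.eig(matrix)


@jax.jit
def eig(matrix):
    matrix = matrix.astype(complex)
    eigvals = jax.ShapeDtypeStruct(matrix.shape[:-1], matrix.dtype)
    eigvecs = jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)
    return jax.pure_callback(_eig_host, (eigvals, eigvecs), matrix)


# Reported to hang on jax 0.4.33 (Colab CPU runtime):
eig(np.random.randn(8, 8).astype(np.float32))
```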
System info (python version, jaxlib version, accelerator, etc.)
Python 3.10
Jax 0.4.33
Colab CPU runtime