Add a GPU implementation of lax.linalg.eig
.
#24663
Open
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Add a GPU implementation of
lax.linalg.eig
.This feature has been in the queue for a long time (see #1259), and some folks have found that they can use
pure_callback
to call the CPU version as a workaround. It has recently come up that there can be issues when usingpure_callback
with JAX calls in the body (#24255; this should be investigated separately).This change adds a native solution for computing
lax.linalg.eig
on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on MAGMA can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the
jax_gpu_eig_magma
configuration variable is set to"on"
. By default, we try to dlopenlibmagma.so
, but the path to a non-standard installation location can be specified using theJAX_GPU_MAGMA_PATH
environment variable.For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now.