Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a GPU implementation of lax.linalg.eig. #24663

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Commits on Nov 6, 2024

  1. Add a GPU implementation of lax.linalg.eig.

    This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately).
    
    This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)
    
    We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.
    
    For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now.
    
    PiperOrigin-RevId: 691072237
    dfm authored and Google-ML-Automation committed Nov 6, 2024
    Configuration menu
    Copy the full SHA
    35199e1 View commit details
    Browse the repository at this point in the history