Support (nonsymmetric) np.linalg.eig on GPU #1259
Comments
Thanks for the ping! Are there open issues for these already? @hawkinsp is the expert on these things and can provide the best advice, but for GPU implementations of linalg my understanding is we set up some wrappers in jaxlib, then set up backend-specific translation rules for the appropriate primitives in lax_linalg.py. As for adding batching specifically, I think we just need to make sure batch dimensions are plumbed through properly, which, if the cusolver kernels themselves don't support batch dimensions, might mean adding some kind of loop over cusolver calls. It looks like Peter added batched triangular solve and LU decomposition for GPU in #1144, so that might provide hints for the plumbing needed. What do you think? Questions welcome! I can only provide high-level pointers to the right places, but if we sniff around there I bet we'll find things.
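Below is a rough, user-level sketch of the "loop over unbatched calls" fallback idea mentioned above. It is my own illustration, not the actual jaxlib/cusolver plumbing; `unbatched_svd` is a stand-in for an operation whose backing kernel only handles one matrix at a time.

```python
import jax
import jax.numpy as jnp

def unbatched_svd(x):
    # Stand-in for a primitive whose backing kernel has no batch support.
    return jnp.linalg.svd(x, compute_uv=False)

def batched_svd_via_loop(xs):
    # Emulate batching by looping over the leading axis with lax.map.
    return jax.lax.map(unbatched_svd, xs)

xs = jnp.ones((8, 3, 3))
print(batched_svd_via_loop(xs).shape)  # (8, 3)
```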
Thanks for your quick response! I think there are already some issues concerning linear ops, but not specifically eigendecomp or batched SVD. Also thanks a lot for the explanation! I'll try to get oriented and come here if I have questions.
@clemisch I can take a look at these if you aren't already working on them.
Hey @hawkinsp, thank you for getting back on this! Tbh I have not looked into this so far. It would be great if you could have a look too!
PR #1314 adds batched SVD on CPU and GPU. On CPU, or for large matrices on GPU, it merely calls the current code in a loop. On GPU for small matrices it calls the batched Jacobi kernel from cusolver. Unfortunately the batched Jacobi kernel only supports small matrices (up to 32x32), which is why larger matrices fall back to the loop.
I merged the PR that adds batched SVD support. You'll need to rebuild Jaxlib (or wait for us to make a release). I retitled the issue to reflect the open action item (nonsymmetric eigendecomposition on GPU).
GPU Eigendecomposition via MAGMA might fall into the "contributions welcome" category, unless it proves to be a popular request. |
Thank you @hawkinsp, this is great! Non-symmetric eigendecomposition is not very urgent for me, especially if it's so cumbersome to add to jax. Concerning batched SVD I have a question about speed: in this little test I only see a 4x speedup vs. single-core numpy. Is this expected?

```python
import jax
import jax.numpy as np
import numpy as onp

x_host = onp.random.rand(100000, 3, 3).astype(onp.float32)
x_gpu = np.array(x_host)

svd_batch = jax.jit(jax.vmap(np.linalg.svd, 0, 0))

u1, s1, v1 = onp.linalg.svd(x_host)
u2, s2, v2 = np.linalg.svd(x_gpu)
u3, s3, v3 = svd_batch(x_gpu)

%timeit onp.linalg.svd(x_host)                        # 495 ms
%timeit np.linalg.svd(x_gpu)[0].block_until_ready()   # 122 ms
%timeit svd_batch(x_gpu)[1].block_until_ready()       # 123 ms
```

(sorry about the repost, I deleted the original comment by mistake)
Bump 😺
I believe that's just how fast NVIDIA's cusolver batched Jacobi implementation is. On my GPU, it seems we spend 99.9% of the time in the batched Jacobi kernel.
The algorithm does have some tunable parameters that one might explore setting. If you wanted to try that, I think you just need to call the functions that modify the Jacobi parameters at this line and then rebuild Jaxlib.
Thank you very much for clarifying!
Hi! Just popping up to ask if there is any progress regarding this.
Hey, thought I'd also express my desire for this, my use case being finding the poles of many auto-regressive models in parallel with np.roots. Thanks to all the contributors to JAX for where it already is, it's amazing.
I'm curious how folks would feel about the following: suppose MAGMA were an optional dependency of JAX, i.e., we don't bundle it in jaxlib builds, but if you install it yourself (or perhaps via conda?) and JAX can find the shared library in your library path, then JAX would use it for nonsymmetric eigendecomposition on GPU. (I'm a bit reluctant to bundle it with jaxlib unconditionally for just one function!)
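As a rough illustration of that "use it only if it's installed" idea, here is a small sketch of runtime detection of a MAGMA shared library. This is my own illustration under stated assumptions (the probed library name and the fallback behaviour), not JAX's actual loading logic.

```python
import ctypes
import ctypes.util

def find_magma():
    # Probe the usual library search paths for a MAGMA shared library.
    path = ctypes.util.find_library("magma")
    if path is None:
        return None  # MAGMA not installed; caller falls back to the CPU implementation
    try:
        return ctypes.CDLL(path)
    except OSError:
        return None  # present but not loadable; also fall back

magma = find_magma()
print("MAGMA available:", magma is not None)
```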
I'd be totally fine with this. Could always be bundled in later down the line but as you say I feel the critical threshold for functional usage is perhaps a bit higher than one! :)
+1 that support for GPU-backed eig would be very useful.
+1 for GPU support for nonsymmetric eig, to allow GPU-enabled numpy.roots.
I also strongly support the implementation of this feature, in order to be able to use jnp.roots on GPU. I am training a network whose loss function requires computing the roots of a polynomial, and training on CPU is really too slow.
I developed a workaround for my use case, which involves using the host_callback module to run the eigendecomposition on the CPU.
A brief update to this: we have a slightly modified version which avoids the device specification in the call to jax.jit, using a jax.default_device scope instead:

```python
from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return host_callback.call(
        _eig_cpu,
        matrix.astype(complex),
        result_shape=(eigenvalues_shape, eigenvectors_shape),
    )
```
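A quick usage sketch (my own addition, not part of the original comment): the wrapper can be called from jitted code on a GPU machine, with the eigendecomposition itself running on the host CPU via the callback.

```python
a = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
eigvals, eigvecs = jax.jit(_eig_host)(a)  # jit is fine; eig itself runs on the CPU host
print(eigvals)  # roughly [-1.+0.j, -2.+0.j]
```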
Hi there! I was just wondering if there has been any progress made on this particular issue, since it is quite a common and essential function for scientific studies.
Note for anyone using the above workaround with the host callback nowadays: host_callback is deprecated, so you will want to use jax.pure_callback instead (as in the variants below).
I have implemented (matrix-free) eigs in JAX for scientific purposes in jaxeigs. I have borrowed some code from TensorNetwork and performed the Arnoldi decomposition on the GPU. However, I have kept the last step, solving the eigenproblem in the projected Krylov space, on the CPU (via callback), since that algorithm is divide-and-conquer and thus not efficient on GPU. I must admit that this code is currently extremely unstable and the documentation is incomplete. Despite these limitations, it is functional for my own use.
As noted above, host_callback is deprecated, so here is a variant based on jax.pure_callback; it splits the eigenvalues and eigenvectors into real and imaginary parts so that the callback outputs match the declared (real) dtypes. Looking forward to a native implementation of eig on GPU!

```python
from typing import Tuple

import jax
import jax.numpy as jnp


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], matrix.dtype)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)

    def _eig_cpu(matrix: jnp.ndarray):
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            val, vec = jax.jit(jnp.linalg.eig)(matrix)
            return (val.real, val.imag), (vec.real, vec.imag)

    val, vec = jax.pure_callback(
        _eig_cpu,
        ((eigenvalues_shape, eigenvalues_shape),
         (eigenvectors_shape, eigenvectors_shape)),
        matrix,
    )
    return val[0] + 1j * val[1], vec[0] + 1j * vec[1]
```
We don't seem to have issues supporting fp32 and fp64 with the following implementation in fmmax:

```python
from typing import Tuple

import jax
import jax.numpy as jnp


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return jax.pure_callback(
        _eig_cpu,
        (
            jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
            jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
        ),
        matrix.astype(complex),
        vectorized=True,
    )
```
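A usage sketch (my own, not from fmmax): because the callback is declared `vectorized=True`, the wrapper composes with `vmap` as well as `jit`; note that newer JAX versions replace the `vectorized` argument with `vmap_method`.

```python
mats = jnp.stack([jnp.eye(3) * (i + 1.0) for i in range(4)])  # batch of four 3x3 matrices
eigvals, eigvecs = jax.jit(jax.vmap(_eig_host))(mats)
print(eigvals.shape, eigvecs.shape)  # (4, 3) and (4, 3, 3)
```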
I did some tests comparing the eig implementations of jax and torch (among others). I also created a pip-installable package wrapping these so they can be used from jax. Here is an example of the performance difference I am seeing. This was generated on a CPU colab, but torch comes out ahead also on my Apple and Intel machines. I didn't investigate the origin of the difference, but presumably there's a different linear algebra library being used in each of these packages.
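For context, here is a rough sketch of the kind of CPU comparison described. This is my own illustration, not the original benchmark; the matrix size and timing harness are arbitrary, and results will depend heavily on which BLAS/LAPACK each library links against.

```python
import time

import numpy as np
import torch
import jax
import jax.numpy as jnp


def bench(fn, *args, repeats=3):
    # Crude wall-clock timing; good enough for an order-of-magnitude comparison.
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return best


a_np = np.random.rand(512, 512)
a_torch = torch.as_tensor(a_np)
a_jax = jnp.asarray(a_np)  # float32 unless jax_enable_x64 is set

# Assumes a CPU-only JAX setup (this very issue is that eig has no GPU kernel).
eig_jit = jax.jit(jnp.linalg.eig)
eig_jit(a_jax)[0].block_until_ready()  # compile once before timing

print("numpy:", bench(np.linalg.eig, a_np))
print("torch:", bench(torch.linalg.eig, a_torch))
print("jax:  ", bench(lambda x: eig_jit(x)[0].block_until_ready(), a_jax))
```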
JAX just calls scipy's copy of LAPACK. You can probably accelerate it by installing, e.g., Intel's MKL scipy. Torch, as far as I know, also just calls LAPACK. It may be linking it with a different BLAS library; JAX will just be using openblas from scipy.
… is not supported on gpus, see jax-ml/jax#1259
Hello, |
I'm also wondering if there are any updates. I have an application where I need to do eig inside jit-compiled code on GPU.
Is there a workaround that allows gradient computation? |
@hbmcmahan Check the fmmax code for an example custom eig implementation that supports gradient calculation. https://github.com/facebookresearch/fmmax/blob/6fd55729920b537225f54934ad9bb537900928b1/src/fmmax/utils.py#L112
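For reference, the reverse-mode rule that such a custom eig implementation is typically built on is the standard result for distinct eigenvalues (e.g. Giles 2008), stated here for the real case as a sketch; practical implementations also guard the denominator against near-degenerate eigenvalues.

```latex
% Eigendecomposition A = X D X^{-1}, D = \operatorname{diag}(d).
% Given output adjoints \bar{D}, \bar{X}, the input adjoint is
\bar{A} = X^{-\mathsf{T}} \left( \bar{D} + F \circ \left( X^{\mathsf{T}} \bar{X} \right) \right) X^{\mathsf{T}},
\qquad
F_{ij} =
\begin{cases}
  (d_j - d_i)^{-1}, & i \neq j, \\
  0, & i = j.
\end{cases}
```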
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable. For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now.

PiperOrigin-RevId: 691072237
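Going by the commit message above, enabling the MAGMA-backed path might look like the sketch below; the flag name, environment variable, and library path are taken from (or illustrative of) that message, and the exact public configuration API may differ between JAX versions.

```python
import os

# Only needed for a non-standard MAGMA install location (illustrative path).
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

import jax
import jax.numpy as jnp

# Opt in to the MAGMA implementation; flag name as given in the commit message above.
jax.config.update("jax_gpu_eig_magma", "on")

a = jnp.array([[0.0, 1.0], [-2.0, -3.0]])
eigvals, eigvecs = jnp.linalg.eig(a)  # now has a native GPU lowering
```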
Dear jax team,

this is just a friendly bump on the implementation of eigendecomposition and batched SVD on GPU. Are you planning on implementing these?

Should I want to implement it myself, would I be able to do it with the primitives in jax.lax, or would I have to hook up a new part of cuSolver? I am willing to spend the time as I would benefit a lot from these features, but I have no experience with expanding jax and would not know where to look.