
Support (nonsymmetric) np.linalg.eig on GPU #1259

Open
clemisch opened this issue Aug 28, 2019 · 31 comments
Labels
enhancement New feature or request NVIDIA GPU Issues specific to NVIDIA GPUs P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR

Comments

@clemisch
Contributor

Dear jax team,

this is just a friendly bump on the implementation of eigendecomposition and batched SVD on GPU. Are you planning on implementing these?

Should I want to implement it myself, would I be able to do it with the primitives in jax.lax, or would I have to hook up a new part of cuSolver? I am willing to spend the time as I would benefit a lot from these features, but I have no experience with expanding jax and would not know where to look.

@mattjj
Collaborator

mattjj commented Aug 28, 2019

Thanks for the ping! Are there open issues for these already?

@hawkinsp is the expert on these things and can provide the best advice, but for GPU implementations of linalg my understanding is that we set up some wrappers in jaxlib, then set up backend-specific translation rules for the appropriate primitives in lax_linalg.py. As for adding batching specifically, I think we just need to make sure batch dimensions are plumbed through properly, which, if the cusolver kernels themselves don't support batch dimensions, might mean adding some kind of loop over cusolver calls. It looks like Peter added batched triangular solve and LU decomposition for GPU in #1144, so that might provide hints for the plumbing needed.

What do you think? Questions welcome! I can only provide high-level pointers to the right places, but if we sniff around there I bet we'll find things.
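To make the two batching strategies concrete at the Python level, here is a minimal sketch (not how jaxlib wires up cuSOLVER): `jax.lax.map` runs an unbatched function once per matrix in a sequential loop, which corresponds to the "loop over calls" fallback, while `jax.vmap` hands the batch dimension to the primitive's own batching rule. The helper names `singular_values`, `svd_looped`, and `svd_vmapped` are illustrative only.

```python
import jax
import jax.numpy as jnp


def singular_values(m):
    # Singular values of a single (unbatched) matrix.
    return jnp.linalg.svd(m, compute_uv=False)


def svd_looped(batch):
    # Sequential loop over the leading axis: one unbatched call per matrix.
    return jax.lax.map(singular_values, batch)


def svd_vmapped(batch):
    # Let vmap push the batch dimension into the SVD primitive itself.
    return jax.vmap(singular_values)(batch)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 4))
looped = svd_looped(x)
vmapped = svd_vmapped(x)
```

Both produce the same values; the difference is only in how the work is scheduled on the device.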

@clemisch
Contributor Author

Thanks for your quick response! I think there are already some issues concerning linear ops, but not specifically eigendecomp or batched SVD.

Also thanks a lot for the explanation! I'll try to get oriented and come here if I have questions.

@mattjj mattjj added the enhancement New feature or request label Aug 29, 2019
@hawkinsp
Collaborator

hawkinsp commented Sep 3, 2019

@clemisch I can take a look at these if you aren't already working on them.

@clemisch
Contributor Author

clemisch commented Sep 4, 2019

Hey @hawkinsp, thank you for getting back on this! Tbh I have not looked into this so far. It would be great if you could have a look too!

@hawkinsp
Collaborator

hawkinsp commented Sep 5, 2019

PR #1314 adds batched SVD on CPU and GPU. On CPU or for large matrices on GPU it merely calls the current code in a loop. On GPU for small matrices it calls the batched Jacobi kernel from Cusolver.

Unfortunately np.linalg.eig is a little harder. I can add a "batched" implementation on CPU (simply looping over the batch elements). However, there is no support for non-symmetric eigendecomposition in Cusolver (batched or unbatched). If you really need this, then we'd need to add another dependency (probably MAGMA), which is a bunch more work. Do SVD and symmetric eigendecomposition satisfy your needs for now?
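For inputs that are symmetric (or Hermitian), `jnp.linalg.eigh` is the routine that avoids the nonsymmetric solver entirely. A small sketch, assuming the matrix is symmetric or can be symmetrized; the NumPy `eigvals` call on the host is only there as a cross-check:

```python
import jax
import jax.numpy as jnp
import numpy as np

b = jax.random.normal(jax.random.PRNGKey(0), (5, 5))
sym = (b + b.T) / 2  # symmetrize so that eigh applies

# eigh assumes a symmetric/Hermitian input and returns real eigenvalues in ascending order.
w_sym, v_sym = jnp.linalg.eigh(sym)

# Cross-check with NumPy's general (nonsymmetric) solver on the host.
w_gen = np.sort(np.linalg.eigvals(np.asarray(sym)).real)
```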

@hawkinsp hawkinsp changed the title np.linalg.eig and batched SVD Support (nonsymmetric) np.linalg.eig on GPU Sep 6, 2019
@hawkinsp
Collaborator

hawkinsp commented Sep 6, 2019

I merged the PR that adds batched SVD support. You'll need to rebuild Jaxlib (or wait for us to make a release).

I retitled the issue to reflect the open action item (nonsymmetric eigendecomposition on GPU).

@hawkinsp
Collaborator

hawkinsp commented Sep 6, 2019

GPU Eigendecomposition via MAGMA might fall into the "contributions welcome" category, unless it proves to be a popular request.

@clemisch
Contributor Author

Thank you @hawkinsp, this is great! Non-symmetric eigendecomposition is not very urgent for me, especially if it's so cumbersome to add to jax.

Concerning batched SVD I have a question about speed: In this little test I only see x4 speedup vs. single-core numpy. Is this expected?

```python
import jax
import jax.numpy as np
import numpy as onp

x_host = onp.random.rand(100000, 3, 3).astype(onp.float32)
x_gpu = np.array(x_host)

svd_batch = jax.jit(jax.vmap(np.linalg.svd, 0, 0))

u1, s1, v1 = onp.linalg.svd(x_host)
u2, s2, v2 = np.linalg.svd(x_gpu)
u3, s3, v3 = svd_batch(x_gpu)

%timeit onp.linalg.svd(x_host)                       # 495 ms
%timeit np.linalg.svd(x_gpu)[0].block_until_ready()  # 122 ms
%timeit svd_batch(x_gpu)[1].block_until_ready()      # 123 ms
```

(sorry about the repost, I deleted the original comment by mistake)
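Since `jnp.linalg.svd` already accepts a leading batch dimension, the `vmap` in the test above is not strictly needed. A sketch for checking that the native batched call, the `vmap`-ed call, and NumPy agree before comparing timings; only singular values are compared, because singular vectors are determined only up to sign:

```python
import jax
import jax.numpy as jnp
import numpy as onp

onp.random.seed(0)
x_host = onp.random.rand(1000, 3, 3).astype(onp.float32)
x_dev = jnp.asarray(x_host)

s_native = jnp.linalg.svd(x_dev, compute_uv=False)  # batch dimension handled natively
s_vmap = jax.jit(jax.vmap(lambda m: jnp.linalg.svd(m, compute_uv=False)))(x_dev)
s_numpy = onp.linalg.svd(x_host, compute_uv=False)  # host reference
```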

@clemisch
Contributor Author

Bump 😺

> In this little test I only see x4 speedup vs. single-core numpy. Is this expected?

@hawkinsp
Collaborator

I believe that's just how fast NVIDIA's cuSOLVER batched Jacobi implementation is. On my GPU, it seems we spend 99.9% of the time in the batched Jacobi kernel:

```
 GPU activities:
99.90%  3.20600s        12  267.17ms  233.68ms  305.82ms  void batched_svd_parallel_jacobi_32x16<float, float>(int, int, int, int, float*, unsigned long, int, float*, float*, unsigned long, int, float*, unsigned long, int, float, int, int*, float, int, int*, int, float)
```

The algorithm does have some tunable parameters that one might explore setting:
https://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-gesvdjbatch

If you wanted to try that, I think you just need to call the functions that modify the Jacobi parameters at this line and then rebuild Jaxlib.
https://github.com/google/jax/blob/master/jaxlib/cusolver.cc#L731

@clemisch
Contributor Author

Thank you very much for clarifying!

@mganahl

mganahl commented Sep 18, 2020

Hi! Just popping up to ask if there is any progress regarding eig.
I'm currently preparing a JAX implementation of implicitly restarted Arnoldi (for non-symmetric operators). The working CPU implementation relies on jax.numpy.linalg.eig to compute the eigenvalues of the Hessenberg matrix returned by Arnoldi. It would be great to have this run on GPU eventually.
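For context, a minimal (non-restarted) Arnoldi sketch: k steps of modified Gram-Schmidt produce an orthonormal basis `Q` and an upper Hessenberg `H` with `A @ Q[:, :k] == Q @ H`, and the eigenvalues of the square block `H[:k, :k]` (the Ritz values) are the small nonsymmetric eigenproblem that needs `eig`. This is a plain Python loop written for readability; it does no restarting, reorthogonalization, or breakdown handling, and it is not the implementation referred to above.

```python
import jax
import jax.numpy as jnp
import numpy as np


def arnoldi(A, v0, k):
    """k steps of Arnoldi with modified Gram-Schmidt.

    Returns Q (n x (k+1), orthonormal columns) and H ((k+1) x k, upper Hessenberg)
    with A @ Q[:, :k] == Q @ H up to rounding. No restarts, no reorthogonalization,
    and no handling of breakdown (H[j+1, j] == 0)."""
    n = A.shape[0]
    Q = jnp.zeros((n, k + 1), dtype=A.dtype)
    H = jnp.zeros((k + 1, k), dtype=A.dtype)
    Q = Q.at[:, 0].set(v0 / jnp.linalg.norm(v0))
    for j in range(k):
        w = A @ Q[:, j]
        # Orthogonalize the new Krylov vector against all previous basis vectors.
        for i in range(j + 1):
            h = jnp.vdot(Q[:, i], w)
            H = H.at[i, j].set(h)
            w = w - h * Q[:, i]
        h_next = jnp.linalg.norm(w)
        H = H.at[j + 1, j].set(h_next)
        Q = Q.at[:, j + 1].set(w / h_next)
    return Q, H


A = jax.random.normal(jax.random.PRNGKey(0), (50, 50))
v0 = jnp.ones(50)
Q, H = arnoldi(A, v0, k=10)
# Ritz values: eigenvalues of the square k x k Hessenberg block, computed on the host.
ritz = np.linalg.eigvals(np.asarray(H[:10, :10]))
```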

@joncarter1

Hey, thought I'd also express my desire for this, my use case being finding the poles of many auto-regressive models in parallel with np.roots. Thanks to all the contributors to JAX for where it already is, it's amazing.
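NumPy documents that `np.roots` computes roots as the eigenvalues of the polynomial's companion matrix, which is not symmetric; that is why GPU support for `roots` depends on this issue. A sketch for AR poles, where the AR(2) coefficients are made up for illustration, `companion` is a hypothetical helper, and the eigenvalues are computed with NumPy on the host:

```python
import jax.numpy as jnp
import numpy as np


def companion(coeffs):
    """Companion matrix of c[0] z^n + c[1] z^(n-1) + ... + c[n] (highest power first, as in np.roots)."""
    c = jnp.asarray(coeffs, dtype=jnp.float32)
    c = c / c[0]  # normalize to a monic polynomial
    n = c.shape[0] - 1
    # First row holds -c[1:], with ones on the subdiagonal.
    return jnp.zeros((n, n), dtype=c.dtype).at[0, :].set(-c[1:]).at[1:, :-1].set(jnp.eye(n - 1))


# Hypothetical AR(2): x_t = 0.5 x_{t-1} - 0.06 x_{t-2}, with characteristic
# polynomial z^2 - 0.5 z + 0.06 = (z - 0.2)(z - 0.3).
coeffs = [1.0, -0.5, 0.06]
poles = np.linalg.eigvals(np.asarray(companion(coeffs)))  # general (nonsymmetric) eig
```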

@hawkinsp
Collaborator

I'm curious how folks would feel about the following: suppose MAGMA were an optional dependency of JAX. i.e., we don't bundle it in jaxlib builds, but if you install it yourself (or perhaps via conda?) and JAX can find the shared library in your library path, then jnp.linalg.eig works on GPU.

(I'm a bit reluctant to bundle it with jaxlib unconditionally for just one function!)

@joncarter1

I'd be totally fine with this. Could always be bundled in later down the line but as you say I feel the critical threshold for functional usage is perhaps a bit higher than one! :)

@apaszke apaszke added the P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR label Jun 2, 2021
@ianwilliamson

+1 that support for GPU-backed eig would be great.

janEbert added a commit to janEbert/sdc-gym that referenced this issue Aug 25, 2021
@sudhakarsingh27 sudhakarsingh27 added the NVIDIA GPU Issues specific to NVIDIA GPUs label Aug 10, 2022
@drscook

drscook commented Sep 29, 2022

+1 for GPU support for nonsymmetric eig, to allow GPU-enabled numpy.roots

@melsophos

melsophos commented Oct 4, 2022

I also strongly support implementing this feature, so that jnp.roots can be used on GPU. I am training a network whose loss function requires computing the roots of a polynomial, and training on CPU is really too slow.

@mfschubert

I developed a workaround for my use case, which involves using the jax.experimental.host_callback module. Just sharing it in case it's useful.

```python
from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)
    return host_callback.call(
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        jax.jit(jnp.linalg.eig, device=jax.devices("cpu")[0]),
        matrix.astype(complex),
        result_shape=[eigenvalues_shape, eigenvectors_shape],
    )


m = jnp.eye(4)  # example input matrix
jax.jit(_eig_host, device=jax.devices("gpu")[0])(m)  # This works, we can jit on GPU.
```

@mfschubert

A brief update to this: we have a slightly modified version of this which avoids the device specification in the call to jax.jit, which is the new recommended practice:

```python
from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return host_callback.call(
        _eig_cpu,
        matrix.astype(complex),
        result_shape=(eigenvalues_shape, eigenvectors_shape),
    )
```

@tsunhopang

Hi there! I was just wondering if there has been any progress on this issue, since eig is quite a common and essential function in scientific work.

tsunhopang added a commit to tsunhopang/jester that referenced this issue Mar 27, 2024
@ju-kreber

Note for anyone using the above workaround with the host callback nowadays: host_callback has been deprecated; use external callbacks instead (most probably pure_callback()). These also work nicely under function transformations.

@qiyang-ustc

qiyang-ustc commented May 5, 2024

I have implemented (matrix-free) eigs in JAX for scientific purposes in jaxeigs. I have borrowed some code from TensorNetwork and perform the Arnoldi decomposition on the GPU. However, I have kept the last step, which solves the eigenproblem in the projected Krylov space, on the CPU (via a callback), since that algorithm is divide-and-conquer and thus not efficient on GPU.

I must admit that this code is currently extremely unstable, and the documentation is incomplete. Despite these limitations, it is functional for my own use.
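The split described above (Krylov iteration on the accelerator, small projected eigenproblem on the host) can be sketched with `jax.pure_callback` calling NumPy directly inside a jitted function. This is a minimal sketch, not the jaxeigs code; `ritz_values` is a hypothetical helper, and the output dtype is hard-coded to complex64:

```python
import jax
import jax.numpy as jnp
import numpy as np


def _eigvals_on_host(h):
    # Plain NumPy (LAPACK) on the host; no JAX calls inside the callback.
    # Cast explicitly, since NumPy returns a real dtype when all eigenvalues are real.
    return np.linalg.eigvals(np.asarray(h)).astype(np.complex64)


@jax.jit
def ritz_values(h):
    """Eigenvalues of a small square (Hessenberg) matrix, computed on the host."""
    out = jax.ShapeDtypeStruct(h.shape[:-1], jnp.complex64)
    return jax.pure_callback(_eigvals_on_host, out, h)


h = jnp.array([[1.0, 2.0, 0.5],
               [3.0, 1.0, 1.0],
               [0.0, 0.5, 2.0]])
vals = ritz_values(h)
```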

@moskomule

moskomule commented May 15, 2024

> Note for anyone using the above workaround with the host callback nowadays: host_callback has been deprecated, use external callbacks instead (most probably pure_callback()). These also work nicely under function transformations.

As pure_callback does not seem to support fp64 at the moment, you need additional tricks (in case you are using fp32).

Looking forward to the implementation of eig on GPU.

```python
from typing import Tuple

import jax
import jax.numpy as jnp


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], matrix.dtype)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            val, vec = jax.jit(jnp.linalg.eig)(matrix)
            # Return real and imaginary parts separately (same dtype as the input).
            return (val.real, val.imag), (vec.real, vec.imag)

    val, vec = jax.pure_callback(_eig_cpu,
                                 ((eigenvalues_shape, eigenvalues_shape),
                                  (eigenvectors_shape, eigenvectors_shape)),
                                 matrix)
    return val[0] + 1j * val[1], vec[0] + 1j * vec[1]
```

@mfschubert

> As pure_callback does not seem to support fp64 at the moment, you need additional tricks (in case you are using fp32).

We don't seem to have issues supporting fp32 and fp64 with the following implementation in fmmax:

```python
from typing import Tuple

import jax
import jax.numpy as jnp


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return jax.pure_callback(
        _eig_cpu,
        (
            jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
            jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
        ),
        matrix.astype(complex),
        vectorized=True,
    )
```

@mfschubert

I did some tests comparing eig performance for scipy, numpy, jax, and torch and found that they can differ quite a bit, with torch generally being the fastest. In lieu of a GPU-accelerated eig, simply using the torch version may be of benefit.

I also created a pip-installable jeig package which wraps all of these for use with jax. All implementations can be jit-compiled, including on machines with GPUs.

Here is an example of the performance difference I am seeing. This was generated on CPU colab, but torch comes out ahead also on my Apple and Intel machines. I didn't investigate the origin of the difference, but presumably there's a different linear algebra library being used in each of these packages.

[Image: eig performance comparison across scipy, numpy, jax, and torch]

@hawkinsp
Collaborator

JAX just calls scipy's copy of LAPACK. You can probably accelerate it by installing, e.g., Intel's MKL build of SciPy.

Torch, as far as I know, also just calls LAPACK. It may be linking it with a different BLAS library; JAX will just be using openblas from scipy.
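To compare on your own machine, a minimal timing sketch for NumPy versus JAX's CPU `eig`; `time_eig` is a hypothetical helper, the matrix size is arbitrary, and no particular result is implied. Separately, `np.show_config()` prints which BLAS/LAPACK NumPy was built against.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def time_eig(n=200, repeats=3, seed=0):
    """Best-of-`repeats` wall time in seconds for NumPy and JAX (CPU) eig on an n x n matrix."""
    a = np.random.default_rng(seed).normal(size=(n, n)).astype(np.float32)
    # Commit the input to the CPU device so the jitted eig runs there.
    a_cpu = jax.device_put(a, jax.devices("cpu")[0])
    jax_eig = jax.jit(jnp.linalg.eig)
    jax_eig(a_cpu)[0].block_until_ready()  # compile once, outside the timed region

    def best(fn):
        times = []
        for _ in range(repeats):
            start = time.perf_counter()
            fn()
            times.append(time.perf_counter() - start)
        return min(times)

    return {
        "numpy": best(lambda: np.linalg.eig(a)),
        "jax_cpu": best(lambda: jax_eig(a_cpu)[0].block_until_ready()),
    }
```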

@gautierronan

Hi, bumping this to ask if there is any plan from the jax team to implement this feature? @jakevdp

We'd also need this feature for dynamiqs, for the simulation of quantum systems in the so-called Floquet basis (time-periodic quantum systems).

Thanks!
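For context on the Floquet use case: quasi-energies come from the eigenvalues of the one-period propagator, which is unitary but generally not Hermitian, so `eigh` does not apply. A sketch assuming a two-step piecewise-constant Hamiltonian and hbar = 1; `floquet_quasienergies` is hypothetical and not part of dynamiqs, and the eigenvalues are computed with NumPy on the host:

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl
import numpy as np

SX = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)   # Pauli X
SZ = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)  # Pauli Z


def floquet_quasienergies(h1, h2, t1, t2):
    """Quasi-energies for H(t) = h1 on [0, t1) and h2 on [t1, t1 + t2), repeated periodically (hbar = 1)."""
    period = t1 + t2
    # One-period propagator; the later segment acts on the left.
    u = jsl.expm(-1j * t2 * h2) @ jsl.expm(-1j * t1 * h1)
    # u is unitary but generally not Hermitian, so a general eigensolver is needed
    # (NumPy on the host here).
    lam = np.linalg.eigvals(np.asarray(u))
    # Quasi-energies are only defined modulo 2*pi/period; np.angle picks angles in (-pi, pi].
    return -np.angle(lam) / period, lam
```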

@cwoolfo1

Hello,
I was wondering if there has been any progress on this issue?

@hbmcmahan

I'm also wondering if there are any updates. I have an application where I need to do jnp.linalg.eigvals on a small matrix (< 10x10) in a jitted context where I need GPUs for performance on the rest of the loss computation. I am computing gradients of this loss, and for this use it appears the pure_callback approach mentioned above does not work. I get an error like:

ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.

Is there a workaround that allows gradient computation?
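Following the error message's suggestion to use `jax.custom_jvp`, a minimal sketch (not the fmmax implementation): the primal runs NumPy's `eig` on the host through `pure_callback`, and the tangent uses the first-order perturbation result for a diagonalizable matrix A = V diag(w) V^{-1}, namely dw = diag(V^{-1} dA V). It assumes distinct eigenvalues (the derivative is ill-defined at repeated ones) and only differentiates eigenvalues, not eigenvectors. `eigvals_host` and `_eig_host` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _eig_host(a):
    """General eig via NumPy on the host; returns (eigenvalues, right eigenvectors)."""
    out_dtype = jnp.promote_types(a.dtype, jnp.complex64)
    shapes = (
        jax.ShapeDtypeStruct(a.shape[:-1], out_dtype),
        jax.ShapeDtypeStruct(a.shape, out_dtype),
    )

    def callback(x):
        w, v = np.linalg.eig(np.asarray(x))
        # NumPy returns real arrays when all eigenvalues are real, so cast explicitly.
        return w.astype(out_dtype), v.astype(out_dtype)

    return jax.pure_callback(callback, shapes, a)


@jax.custom_jvp
def eigvals_host(a):
    return _eig_host(a)[0]


@eigvals_host.defjvp
def _eigvals_host_jvp(primals, tangents):
    (a,), (da,) = primals, tangents
    w, v = _eig_host(a)
    # dw = diag(V^{-1} dA V); this is linear in da, so reverse mode (jax.grad) also works.
    vinv = jnp.linalg.inv(v)
    dw = jnp.diagonal(vinv @ da.astype(v.dtype) @ v, axis1=-2, axis2=-1)
    return w, dw
```

As a sanity check, the sum of the eigenvalues equals the trace, so the gradient of `jnp.sum(eigvals_host(a).real)` with respect to `a` should be the identity matrix.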

@dfm dfm self-assigned this Oct 17, 2024
@mfschubert

@hbmcmahan Check the fmmax code for an example custom eig implementation that supports gradient calculation. https://github.com/facebookresearch/fmmax/blob/6fd55729920b537225f54934ad9bb537900928b1/src/fmmax/utils.py#L112

copybara-service bot pushed a commit that referenced this issue Nov 1, 2024
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_experimental_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_MAGMA_LIB_PATH` environment variable.

For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now.

PiperOrigin-RevId: 691072237
copybara-service bot pushed a commit that referenced this issue Nov 6, 2024
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now.

PiperOrigin-RevId: 691072237
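The optional-dependency loading described in this commit message can be sketched with `ctypes`. This is a hypothetical illustration, not the code in the commit: `load_optional_library` prefers the path in `JAX_GPU_MAGMA_PATH` when it is set, otherwise asks the dynamic loader for `libmagma.so`, and returns `None` when the library cannot be loaded so that a caller could fall back to LAPACK on the host.

```python
import ctypes
import os


def load_optional_library(env_var="JAX_GPU_MAGMA_PATH", default_name="libmagma.so"):
    """Try to dlopen an optional shared library.

    Uses the path in `env_var` if set, otherwise lets the dynamic loader search
    for `default_name`. Returns a ctypes.CDLL handle, or None if loading fails."""
    path = os.environ.get(env_var) or default_name
    try:
        return ctypes.CDLL(path)
    except OSError:
        # Library not installed or not on the search path: treat as unavailable.
        return None
```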