Pure callback with jax types hangs with jax > 0.4.31 #24255

Open
mfschubert opened this issue Oct 11, 2024 · 6 comments
Labels
bug Something isn't working

Comments

@mfschubert

Description

I am experiencing hangs with jax versions newer than 0.4.31, as referenced in an earlier issue that I created and subsequently closed (#24219). I have managed to simplify the reproduction.

The issue seems to be related to jax calculations within a function called by pure_callback. The code below reproduces the issue.

import jax
print(f"jax_version={jax.__version__}")
import jax.numpy as jnp
import numpy as onp

def _eig_jax(matrix):
    """Eigendecomposition using `jax.numpy.linalg.eig`."""
    eigval, eigvec = jax.pure_callback(
        _eig_cpu,
        (
            jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
            jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
        ),
        matrix.astype(complex),
        vectorized=True,
    )
    return jnp.asarray(eigval), jnp.asarray(eigvec)

with jax.default_device(jax.devices("cpu")[0]):
    _eig_jax_cpu = jax.jit(jnp.linalg.eig)

def _eig_cpu(matrix):
    eigvals, eigvecs = _eig_jax_cpu(matrix)
    return onp.asarray(eigvals), onp.asarray(eigvecs)

# This loop hangs, typically at < 10 steps on a colab CPU runtime. Larger matrices
# cause the loop to hang at earlier steps.
for i in range(100):
    print(i)
    _eig_jax(jnp.ones((500, 500)))

This method of wrapping jnp.linalg.eig has worked with jax 0.4.31 and earlier, and has come up in discussions several times (#23079, #1259).

System info (python version, jaxlib version, accelerator, etc.)

Python 3.10
Jax 0.4.33
Colab CPU runtime

@mfschubert mfschubert added the bug Something isn't working label Oct 11, 2024
@mfschubert
Author

OK, I am currently working around this as follows:

from typing import Tuple

def _eig_jax(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Eigendecomposition using `jax.numpy.linalg.eig`."""
    if jax.devices()[0] == jax.devices("cpu")[0]:
        return jnp.linalg.eig(matrix)
    else:
        eigvals, eigvecs = jax.pure_callback(
            _eig_jax_cpu,
            (
                jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
                jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
            ),
            matrix.astype(complex),
            vectorized=True,
        )
        return jnp.asarray(eigvals), jnp.asarray(eigvecs)

This is fine for me, but it may be nice to have a graceful failure mode other than simply hanging.

@dfm
Collaborator

dfm commented Oct 13, 2024

Thanks for reporting this! I'm not totally surprised that things don't go so well if we nest JAX functions inside of a pure_callback, but I agree that hanging isn't a good outcome!

I'll take a look into this because I have some ideas about what might be causing it. But, regardless, I would recommend avoiding the use of JAX functions within pure_callback. A better workaround than the one you came up with for this specific example would be to use numpy.linalg.eig instead of the JAX version:

def _eig_jax(matrix):
  matrix = matrix.astype(complex)
  return jax.pure_callback(
      np.linalg.eig,  # <-- np instead of jnp
      ...  # The rest of the arguments are the same as before
  )

This works on both GPU and CPU without relying on querying the first device. I expect the performance will be equivalent, without causing any hangs.
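
As a minimal sketch, the elided arguments can be filled in along these lines (using jax.ShapeDtypeStruct for the result spec so that no placeholder arrays need to be built; the vectorized flag mirrors the original reproduction):

import jax
import jax.numpy as jnp
import numpy as np

def _eig_jax(matrix):
    matrix = matrix.astype(complex)
    # Describe the callback outputs: eigenvalues drop the trailing axis,
    # eigenvectors keep the full matrix shape.
    eigvals_spec = jax.ShapeDtypeStruct(matrix.shape[:-1], matrix.dtype)
    eigvecs_spec = jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)
    return jax.pure_callback(
        np.linalg.eig,  # numpy runs on the host, outside of JAX
        (eigvals_spec, eigvecs_spec),
        matrix,
        vectorized=True,
    )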

In the long run, I'm not sure we'll want to support including JAX code in the callback executed by pure_callback for reasons that are probably outside the scope of this discussion. jax.experimental.compute_on should eventually provide the needed API, but unfortunately it doesn't yet work on GPU.
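
Purely as an illustration of that direction (the usage sketched here is an assumption, and as noted it does not yet help on GPU), the current API is a context manager along these lines:

import jax
import jax.numpy as jnp
from jax.experimental.compute_on import compute_on

@jax.jit
def eig_on_host(matrix):
    # Request that this region of the computation be placed on the host CPU.
    with compute_on("device_host"):
        eigvals, eigvecs = jnp.linalg.eig(matrix.astype(complex))
    return eigvals, eigvecs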

@mfschubert
Author

Thanks for the suggestion; I'll look forward to the new API.

Unfortunately, I have found performance to be quite different (also when using scipy.linalg.eig). This is the case both on my machine and on colab, so I suspect something other than my specific setup is responsible. But I suppose that is a different issue entirely.

@dfm
Collaborator

dfm commented Oct 15, 2024

Unfortunately, I have found performance to be quite different (also when using scipy.linalg.eig).

Interesting! Can you say more about what performance differences you're seeing? I believe that the JAX CPU implementation of eig calls exactly the same LAPACK function that the scipy version does (JAX actually uses scipy to find LAPACK!), so I'm surprised that you would find significant performance differences. When I compare the performance of JAX on CPU with the version that uses a pure callback to scipy I actually get exactly the same performance.

Sample code run on CPU
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg

@jax.jit
def eig_jax(x):
  x = x.astype(np.complex64)
  return jnp.linalg.eig(x)

@jax.jit
def eig_scipy(x):
  x = x.astype(np.complex64)
  eigvals = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
  return jax.pure_callback(scipy.linalg.eig, (eigvals, x), x)
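
As a minimal sketch, the two can be compared like this (timing the second call so that compilation is excluded):

import time

x = np.random.randn(500, 500).astype(np.float32)
for name, fn in [("jax", eig_jax), ("scipy callback", eig_scipy)]:
  jax.block_until_ready(fn(x))  # warm up / compile
  start = time.time()
  jax.block_until_ready(fn(x))
  print(name, time.time() - start)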

Regardless, it will be great when jax.experimental.compute_on works on GPU, but I'm just surprised that you're finding significant performance differences!

@mfschubert
Author

Sure, here is some benchmarking code that I ran on colab CPU. I avoided using %%timeit (which reports mean time) since free tier colab seems to be quite noisy.

import jax
import jax.numpy as jnp
import numpy as onp
import scipy
import time

onp.random.seed(0)
matrix = onp.random.randn(500, 500).astype(onp.float32)

times = {"scipy": onp.inf, "numpy": onp.inf, "jax": onp.inf}
for _ in range(20):
  start = time.time()
  onp.linalg.eig(matrix)
  times["numpy"] = min(time.time() - start, times["numpy"])

  start = time.time()
  scipy.linalg.eig(matrix)
  times["scipy"] = min(time.time() - start, times["scipy"])

  start = time.time()
  jax.block_until_ready(jnp.linalg.eig(matrix))
  times["jax"] = min(time.time() - start, times["jax"])

print(times)
# {'scipy': 0.2719097137451172, 'numpy': 0.41986513137817383, 'jax': 0.227125883102417}

I also observe speed differences on my local machine, but I figure that the odds of operator error on the setup are lower with colab. :-)

@kcdodd

kcdodd commented Oct 17, 2024

I think this is possibly related to some issues we are having with the new CPU client's async dispatch. One issue is that, when the client chooses async dispatch, the initial call returns immediately (in the main thread) and the callback runs in a separate thread created by the client. If the function does not return anything, or its return value is not used in a way that blocks the main thread (print, converting to an array, block_until_ready, etc.), nothing stops the main thread from running ahead to the end of the program and exiting.
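
As a minimal sketch of that point, consuming the callback's result, for example with jax.block_until_ready, is what keeps the main thread from reaching the end of the program before the asynchronously dispatched callback has finished (host_fn and f below are illustrative names):

import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
    # Runs on the host, in a thread chosen by the runtime.
    return np.asarray(x) * 2.0

@jax.jit
def f(x):
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

out = f(jnp.ones(4, jnp.float32))
# Without this, the main thread can reach the end of the script and begin
# interpreter teardown while the callback is still pending.
jax.block_until_ready(out)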

According to https://docs.python.org/3/library/threading.html

A thread can be flagged as a “daemon thread”. The significance of this flag is that the entire Python program exits when only daemon threads are left.

It also mentions (though it is not clear whether this applies here):

Daemon threads are abruptly stopped at shutdown.

Using the threading module in the callback, e.g.

thread = threading.current_thread()
print(f"{type(thread)}: {thread!r}")

will give something like

<class 'threading._DummyThread'>: <_DummyThread(Dummy-3, started daemon 128068297426496)>

The expected behavior, which seems to match what happens, is that when the main thread reaches the end it starts tearing down and calls everything registered with atexit, including jax._src.api.clean_up and jax._src.dispatch.wait_for_tokens. In my case at least it calls clean_up first, and when the callback then happens, JAX starts trying to re-initialize itself. I'm not exactly sure what all of that means, but I can imagine an inconsistent state that is a mixture of teardown and startup.

If this is not what is happening in this issue, I can start a new one to document this information.

copybara-service bot pushed a commit that referenced this issue Nov 1, 2024
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_experimental_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_MAGMA_LIB_PATH` environment variable.

For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now.

PiperOrigin-RevId: 691072237
copybara-service bot pushed a commit that referenced this issue Nov 6, 2024
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now.

PiperOrigin-RevId: 691072237
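
As a sketch of how that opt-in might look (assuming the configuration variable and environment variable names above are what ships, that the flag can be set through jax.config.update, and with an example path for a non-standard MAGMA install):

import os
# Only needed when MAGMA is installed in a non-standard location.
os.environ["JAX_GPU_MAGMA_PATH"] = "/opt/magma/lib/libmagma.so"

import jax
import jax.numpy as jnp

jax.config.update("jax_gpu_eig_magma", "on")  # opt in to the MAGMA-backed eig

# complex64 only for now; complex128 is disabled as described above.
matrix = jnp.ones((4096, 4096), dtype=jnp.complex64)
eigvals, eigvecs = jnp.linalg.eig(matrix)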