
Adding jnp.identity and jnp.matmul raises XLA warning and affects performance. #19885

Closed
tillahoffmann opened this issue Feb 20, 2024 · 4 comments
Labels
bug Something isn't working


@tillahoffmann
Contributor

tillahoffmann commented Feb 20, 2024

Description

Under specific circumstances, my jitted function emits the following warning from XLA's oneDNN matmul code: W external/xla/xla/service/cpu/onednn_matmul.cc:172] [Perf]: MatMul reference implementation being executed.

>>> import jax
>>> from jax import numpy as jnp

>>> @jax.jit
... def func_with_warning(y):
...     return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)

>>> func_with_warning(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
2024-02-20 02:31:15.805823: W external/xla/xla/service/cpu/onednn_matmul.cc:172] [Perf]: MatMul reference implementation being executed
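One way to dig further (a sketch, not part of the original report) is to dump what gets compiled. `lower(...).as_text()` returns the StableHLO that JAX emits, and `lower(...).compile().as_text()` returns the optimized HLO after XLA's CPU passes. Comparing these for the warning and no-warning variants may show how the identity addition and the batched dot are fused; whether the oneDNN rewrite is visible in the text depends on the jaxlib build, which is an assumption here.

```python
import jax
from jax import numpy as jnp

@jax.jit
def func_with_warning(y):
    return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)

y = jnp.ones((2, 100, 100))

# StableHLO emitted by JAX, before any XLA optimization.
lowered = func_with_warning.lower(y)
stablehlo_text = lowered.as_text()

# Optimized HLO after XLA's CPU passes (fusions, library-call rewrites).
compiled_text = lowered.compile().as_text()

print(stablehlo_text)
print(compiled_text)
```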

The warning appears only for this particular setup; changing any one of several knobs eliminates it. For example, a batch dimension of size 1 works fine, even if the trailing two dimensions are made larger.

>>> func_with_warning(jnp.ones((1, 1000, 1000))).shape
(1, 1000, 1000)
<no warning>

Replacing identity with ones works fine.

>>> @jax.jit
... def fine_func(y):
...     return jnp.ones((y.shape[-1], y.shape[-1])) + jnp.matmul(y, y)

>>> fine_func(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
<no warning>

Summing the identity to a scalar before the addition works fine.

>>> @jax.jit
... def fine_func(y):
...     return jnp.identity(y.shape[-1]).sum() + jnp.matmul(y, y)

>>> fine_func(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
<no warning>

Subtracting rather than adding works fine.

>>> @jax.jit
... def fine_func(y):
...     x = jnp.identity(y.shape[-1])
...     return x - jnp.matmul(y, y)

>>> fine_func(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
<no warning>

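This does appear to affect performance. On a batch of 10, the jitted function is more than 200 times slower than a Python list comprehension that calls it on each 100 x 100 matrix separately.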

>>> batch = jnp.ones((10, 100, 100))
>>> # Run the jitted function on the batch.
>>> %timeit func_with_warning(batch).block_until_ready()
81.4 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> # List comprehension in Python.
>>> %timeit [func_with_warning(y).block_until_ready() for y in batch]
345 µs ± 7.57 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
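For readers outside IPython, the same comparison can be sketched with the standard-library `timeit` module. This sketch assumes a CPU backend; it warms up both call patterns first so that compilation time is excluded, and the absolute numbers will depend on the machine and jaxlib build.

```python
import timeit

import jax
from jax import numpy as jnp

@jax.jit
def func_with_warning(y):
    return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)

batch = jnp.ones((10, 100, 100))

# Warm up both input shapes so compilation is excluded from the timings.
func_with_warning(batch).block_until_ready()
func_with_warning(batch[0]).block_until_ready()

n = 10
# One jitted call on the whole (10, 100, 100) batch.
batched_s = timeit.timeit(lambda: func_with_warning(batch).block_until_ready(), number=n) / n
# Ten jitted calls, one per (100, 100) slice.
looped_s = timeit.timeit(
    lambda: [func_with_warning(y).block_until_ready() for y in batch], number=n
) / n

print(f"batched call: {batched_s * 1e3:.3f} ms per call")
print(f"Python loop:  {looped_s * 1e3:.3f} ms per loop")
```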

Oddly, a function that uses double negation (subtracting the product of -y and y instead of adding the product of y and y) avoids the warning and runs a lot faster.

>>> @jax.jit
... def weird_func(y):
...     return jnp.identity(y.shape[-1]) - jnp.matmul(-y, y)

>>> weird_func(jnp.ones((2, 100, 100))).shape
(2, 100, 100)
<no warning>

>>> batch = jnp.ones((10, 100, 100))
>>> %timeit weird_func(batch).block_until_ready()
277 µs ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
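Since negation is exact in floating point, `I - (-Y) @ Y` and `I + Y @ Y` are mathematically the same. A quick sanity check (a sketch; the random input and tolerances are arbitrary choices) should show the two functions agree up to rounding, so the speed-up does not come from computing something different:

```python
import jax
from jax import numpy as jnp

@jax.jit
def func_with_warning(y):
    return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)

@jax.jit
def weird_func(y):
    # I - (-y) @ y equals I + y @ y; only the sign bookkeeping differs.
    return jnp.identity(y.shape[-1]) - jnp.matmul(-y, y)

# Random input, so that a sign error would not go unnoticed (unlike all-ones).
key = jax.random.PRNGKey(0)
y = jax.random.normal(key, (2, 100, 100))

max_abs_diff = jnp.max(jnp.abs(func_with_warning(y) - weird_func(y)))
print(max_abs_diff)
```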

Notebook for reproducing the above is here.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.24
jaxlib: 0.4.24
numpy:  1.26.4
python: 3.10.10 (main, Mar  3 2023, 16:31:35) [GCC 9.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
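The system info above has the format produced by `jax.print_environment_info()`. Because the behavior turned out to differ between machines, it may help to also record the CPU architecture alongside it; a minimal sketch:

```python
import platform

import jax

# The same report as above: jax/jaxlib/numpy/python versions and devices.
info = jax.print_environment_info(return_string=True)
print(info)

# CPU architecture, typically "x86_64" on Linux CI runners and "arm64" on Apple silicon Macs.
print("machine:", platform.machine())
```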
@tillahoffmann tillahoffmann added the bug Something isn't working label Feb 20, 2024
tillahoffmann added a commit to tillahoffmann/numpyro that referenced this issue Feb 23, 2024
@tillahoffmann
Contributor Author

It turns out I cannot reproduce this on an M1 MacBook Pro, but it is reproducible on GitHub Actions. See https://github.com/tillahoffmann/google-jax-19885 and an example GitHub Actions run here.

tillahoffmann added a commit to tillahoffmann/numpyro that referenced this issue Feb 27, 2024
fehiepsi pushed a commit to pyro-ppl/numpyro that referenced this issue Feb 28, 2024
* Address jax-ml/jax#19885 for numpyro.

* Implement function to add constant or batch of vectors to diagonal.

* Use `add_diag` helper function in `distributions` module.

* Move `add_diag` to `distributions.util` module.
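The commit message describes adding a constant or a batch of vectors to the diagonal instead of adding an explicit identity matrix. A minimal sketch of that idea follows; `add_diag` and `func_with_add_diag` here are hypothetical reimplementations, not numpyro's actual code, and whether this form avoids the warning on an affected jaxlib build was not verified here.

```python
import jax
from jax import numpy as jnp

def add_diag(matrix, diag):
    """Hypothetical helper: add `diag` to the diagonal of `matrix`.

    `matrix` has shape (..., n, n); `diag` is a scalar or broadcastable to (..., n).
    No n x n identity matrix is materialized.
    """
    n = matrix.shape[-1]
    idx = jnp.arange(n)
    # matrix[..., idx, idx] selects the diagonal(s), with shape (..., n).
    return matrix.at[..., idx, idx].add(diag)

@jax.jit
def func_with_add_diag(y):
    # Same result as jnp.identity(n) + jnp.matmul(y, y).
    return add_diag(jnp.matmul(y, y), 1.0)
```

A scalar such as `1.0` reproduces the identity case, while an array of shape `(batch, n)` adds a different vector to each matrix in the batch.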
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this issue May 6, 2024
@selamw1
Contributor

selamw1 commented May 8, 2024

Hi @tillahoffmann
It appears this issue has been resolved in the latest JAX versions. I ran the code above in Colab with JAX v0.4.26 and v0.4.27 on CPU, GPU, and TPU backends, and it executed without any warnings.

Below is the output of the code when running on CPU:

>>> import jax
>>> from jax import numpy as jnp
>>> print(jax.__version__)

>>> @jax.jit
... def func_with_warning(y):
...     return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)

>>> func_with_warning(jnp.ones((2, 100, 100))).shape

Output:

0.4.26
(2, 100, 100)

I've included a Gist here for your reference.
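Downstream code that still supports older releases could gate a workaround (such as the diagonal helper from the numpyro commits) on the installed version. This is a sketch under stated assumptions: the `(0, 4, 26)` threshold comes only from the observation above, it checks `jaxlib` because jaxlib bundles the compiled XLA backend, and `version_tuple` and `needs_identity_workaround` are hypothetical names.

```python
import re
from importlib.metadata import version

def version_tuple(version_string):
    """Parse the leading 'major.minor.patch' of a version string into ints."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_string!r}")
    return tuple(int(part) for part in match.groups())

# Assumption: no warning observed from this release onward (see the comment above).
FIXED_IN = (0, 4, 26)

def needs_identity_workaround():
    """True if the installed jaxlib predates the release where the warning went away."""
    return version_tuple(version("jaxlib")) < FIXED_IN

print(version("jaxlib"), needs_identity_workaround())
```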

@jakevdp
Collaborator

jakevdp commented May 8, 2024

Thanks for following up!

@jakevdp jakevdp closed this as completed May 8, 2024
@tillahoffmann
Contributor Author

That's great, thank you @selamw1!
