Adding `jnp.identity` and `jnp.matmul` raises XLA warning and affects performance #19885
Comments
It turns out I cannot reproduce this on an M1 MacBook Pro, but it is reproducible on GitHub Actions. See https://github.com/tillahoffmann/google-jax-19885 and an example GitHub Actions run here.
* Address jax-ml/jax#19885 for numpyro.
* Implement function to add a constant or batch of vectors to the diagonal.
* Use `add_diag` helper function in `distributions` module.
* Move `add_diag` to `distributions.util` module.
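For context, here is a minimal sketch of what such a diagonal-adding helper might look like. The name `add_diag` comes from the PR bullets above, but the signature and implementation here are illustrative and may differ from the actual numpyro helper:

```python
import jax.numpy as jnp

def add_diag(matrix, diag):
    """Add a scalar or a batch of vectors to the diagonal of a batch of
    square matrices, without materializing an explicit identity matrix.

    Illustrative sketch only; the numpyro implementation may differ.
    """
    # Index the diagonal of the trailing two dimensions and add `diag`
    # via JAX's functional indexed-update syntax.
    idx = jnp.arange(matrix.shape[-1])
    return matrix.at[..., idx, idx].add(diag)

# Example: same result as jnp.identity(100) + jnp.matmul(y, y), but it
# sidesteps the explicit `identity + matmul` pattern from this issue.
y = jnp.ones((2, 100, 100))
out = add_diag(jnp.matmul(y, y), 1.0)
```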
Hi @tillahoffmann

Below is the output of the code when running on CPU:

```python
>>> import jax
>>> from jax import numpy as jnp
>>> print(jax.__version__)
>>> @jax.jit
... def func_with_warning(y):
...     return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)
...
>>> func_with_warning(jnp.ones((2, 100, 100))).shape
```

Output:

I've included a Gist here for your reference.
Thanks for following up!
That's great, thank you @selamw1!
Description
Under specific circumstances, my jitted function raises a warning:
W external/xla/xla/service/cpu/onednn_matmul.cc:172] [Perf]: MatMul reference implementation being executed
The warning is only raised for this specific setup. Turning one of many knobs eliminates the warning; the variants are sketched in code after this list:

* Having a batch dimension of size 1 works fine, even if we increase the size of the trailing two dimensions.
* Replacing `identity` with `ones` works just fine.
* Summing the identity before addition works fine.
* Subtracting rather than adding works fine.
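A rough sketch of some of the variants listed above; the shapes follow the reproduction in the comments, and whether the warning disappears is as reported in this issue rather than something these snippets verify by themselves:

```python
import jax
from jax import numpy as jnp

y = jnp.ones((2, 100, 100))

# Batch of size 1: same expression, but reportedly no warning.
@jax.jit
def batch_of_one(y):
    return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)

batch_of_one(jnp.ones((1, 100, 100)))

# `ones` instead of `identity`: reportedly no warning (different result,
# of course; it only narrows down the trigger).
@jax.jit
def with_ones(y):
    return jnp.ones((y.shape[-1], y.shape[-1])) + jnp.matmul(y, y)

with_ones(y)

# Subtracting instead of adding: reportedly no warning.
@jax.jit
def with_subtraction(y):
    return jnp.identity(y.shape[-1]) - jnp.matmul(y, y)

with_subtraction(y)
```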
This does indeed seem to affect performance.
We can even create a weird function that runs a lot faster using double negation (sketched below).
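A plausible sketch of the double-negation trick; the exact form used in the notebook may differ, but the idea is to express the addition through two negations and a subtraction:

```python
import jax
from jax import numpy as jnp

@jax.jit
def func_with_double_negation(y):
    # Mathematically identical to jnp.identity(n) + jnp.matmul(y, y),
    # but written via subtraction, which reportedly avoids the slow
    # oneDNN reference-implementation path.
    return -(-jnp.identity(y.shape[-1]) - jnp.matmul(y, y))

func_with_double_negation(jnp.ones((2, 100, 100)))
```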
A notebook for reproducing the above is here.
System info (python version, jaxlib version, accelerator, etc.)