Type specification/inference now fails in @jax.jit #9380
Thanks for the report. I can reproduce with jaxlib 0.1.76; the jax version doesn't appear to matter.
Shorter repro:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def f():
    return jnp.exp(np.complex64(1j))

f()
```
I think the issue is that complex dtypes were left out of the list here: https://github.com/google/jax/blob/0382a6a04eddd7506a4ef6bb0c93f0f660ee3df6/jax/interpreters/mlir.py#L243-L247
We're going to make a
Thanks for the quick turnaround!
The following test case has worked for a while, but has recently started failing. It seems that there isn't a constant handler for `complex64` anymore? The following reproducer demonstrates the issue, but replacing `out_dtype.type(1j)` with `1j` fixes the problem. This works on jax 0.2.26 and jaxlib 0.1.75, but fails on jax 0.2.27 and jaxlib 0.1.76.
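A sketch of the workaround described above. The original reproducer is not shown here, so this assumes `out_dtype` was a NumPy `complex64` dtype object. It swaps the NumPy scalar `out_dtype.type(1j)` for a Python `1j` literal, as the report suggests; the explicit `astype` cast is an addition to keep the output dtype pinned, not part of the report.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Assumption: the reporter's out_dtype was a NumPy complex64 dtype object.
out_dtype = np.dtype(np.complex64)

@jax.jit
def f_workaround():
    # Use a Python complex literal rather than the NumPy scalar out_dtype.type(1j),
    # which is the substitution the report says avoids the failure.
    z = jnp.exp(1j)
    # Cast explicitly so the result dtype still matches out_dtype.
    return z.astype(out_dtype)

print(f_workaround())
```

The cast mainly matters when 64-bit mode is enabled: there, `jnp.exp(1j)` would otherwise produce `complex128` rather than `complex64`.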