
JIT changes the output of a function #10129

Closed
fbartolic opened this issue Apr 2, 2022 · 2 comments
Labels
bug Something isn't working


@fbartolic

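What's going on here? Wrapping `f1` in `jit` changes its output, whereas the similar functions `f2` and `f3` give identical results with and without `jit` (`f3` builds the same array as `f1`, just without the indexed update):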

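```python
import jax.numpy as jnp
from jax import jit, random

key = random.PRNGKey(42)
key, subkey = random.split(key)

# Random test input of shape (2, 100)
X = random.uniform(subkey, (2, 100))

def f1(X):
    # Roll row 0 by 5 positions and write it back with an indexed update
    X = X.at[0].set(jnp.roll(X[0], 5, axis=-1))
    return X

def f2(X):
    # Same indexed-update pattern, but with an elementwise op instead of roll
    X = X.at[0].set(X[0]**2)
    return X

def f3(X):
    # Same result as f1, built with stack instead of an indexed update
    x = jnp.roll(X[0], 5, axis=-1)
    return jnp.stack([x, X[1]])

print(jnp.all(f1(X) == jit(f1)(X)))  # evaluates to False
print(jnp.all(f2(X) == jit(f2)(X)))  # evaluates to True
print(jnp.all(f3(X) == jit(f3)(X)))  # evaluates to True
```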
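One way to narrow this down is to check both the eager and the jitted outputs of `f1` against an independent NumPy reference, so it is clear which of the two is wrong. This is a sketch assuming only `numpy` and `jax` are installed; the key `PRNGKey(0)` is arbitrary. Because `roll` and the indexed `.at[...].set(...)` only move values without doing arithmetic, exact equality is the appropriate test here.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, random

# Arbitrary test input; any (2, N) float array works
X = random.uniform(random.PRNGKey(0), (2, 100))

def f1(X):
    # Roll row 0 by 5 and write it back via an indexed update
    return X.at[0].set(jnp.roll(X[0], 5, axis=-1))

# Independent reference computed entirely in NumPy
X_np = np.asarray(X)
expected = X_np.copy()
expected[0] = np.roll(X_np[0], 5, axis=-1)

eager = np.asarray(f1(X))
jitted = np.asarray(jit(f1)(X))

# roll and .at[].set() only move values, so exact equality is appropriate
print("eager matches NumPy:", np.array_equal(eager, expected))
print("jit matches NumPy:  ", np.array_equal(jitted, expected))

# List up to 10 (row, col) positions where the jitted result differs
mismatches = np.argwhere(jitted != expected)
print("mismatched positions:", mismatches[:10].tolist())
```

If the jitted result is the one that disagrees, the printed positions show whether the error is confined to row 0 (the rolled row) or also touches row 1.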
@jakevdp
Collaborator

jakevdp commented Apr 2, 2022

Thanks for the report! This is quite reminiscent of #7461 and, like that issue, only appears to happen on CPU. I suspect the same mechanism is somehow at play.
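To check the platform dependence directly, the eager-vs-`jit` comparison can be pinned to each available backend in turn. This is a sketch assuming a JAX version in which `jax.devices(platform)` raises `RuntimeError` for a backend that is not present; committing the input with `jax.device_put` makes both the eager and the jitted computation run on that device.

```python
import numpy as np
import jax
import jax.numpy as jnp

def f1(X):
    # Roll row 0 by 5 and write it back via an indexed update
    return X.at[0].set(jnp.roll(X[0], 5, axis=-1))

# Materialize the input as a host NumPy array so every backend gets identical values
X_host = np.asarray(jax.random.uniform(jax.random.PRNGKey(42), (2, 100)))

results = {}
for platform in ["cpu", "gpu"]:
    try:
        device = jax.devices(platform)[0]
    except RuntimeError:
        continue  # this backend is not available in the current installation
    # A committed input makes eager and jitted execution both use this device
    X = jax.device_put(X_host, device)
    same = bool(jnp.all(f1(X) == jax.jit(f1)(X)))
    results[platform] = same
    print(platform, "eager == jit:", same)
```

On an affected installation, the report above suggests `cpu` would print `False` while `gpu` (if present) would print `True`.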

@jakevdp
Collaborator

jakevdp commented Jul 21, 2022

I tried with the most recent jax and jaxlib on a CPU backend, and it appears the issue has been fixed.

jakevdp closed this as completed Jul 21, 2022