jit does not respect backend type when the function has no arguments #1431
Comments
Functions without arguments can't meaningfully be jitted: with no inputs to trace, the work gets constant-folded op-by-op at trace time, and we don't yet have a way to specify the device used for op-by-op computations. It might make sense for such a mechanism to be dynamically scoped (i.e., a context manager), since we can't use data dependence for scoping the way we do for jitted computations.
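For concreteness, a dynamically scoped device mechanism of this kind could look as follows. This is only a sketch: it assumes `jax.default_device`, the context manager that later JAX releases added for this purpose, along with the newer `Array.devices()` method.

```python
import jax
import jax.numpy as jnp

# Sketch only (assumes a recent JAX release that provides jax.default_device):
# the context manager dynamically scopes the device used for op-by-op
# computations, without relying on data dependence.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    y = jnp.arange(6) * 2.0   # computed op-by-op, placed on the CPU device
print(y.devices())            # expected: {CpuDevice(id=0)}
```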
If for some reason you really need to use this idiom, you can force data dependence on constants (sometimes handy to prevent them from being folded out) by using lax.tie_in:

```python
from jax.api import jit
import jax.random as random
from jax import lax

dummy = 0

def x(_placeholder):
    return lax.tie_in(_placeholder, random.normal(random.PRNGKey(1), (2, 3)))

print(x(dummy).device_buffer.device())                      # --> gpu
print(jit(x)(dummy).device_buffer.device())                 # --> gpu
print(jit(x, backend='cpu')(dummy).device_buffer.device())  # --> cpu
```
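Here lax.tie_in makes the random.normal result data-dependent on the traced placeholder argument, so instead of being constant-folded op-by-op at trace time on the default device, it is staged into the jitted computation and therefore lands on whichever backend jit was given.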
Before this commit, this computation would avoid materializing the iota array at trace time:

```python
@jit
def f(x):
    m, n = x.shape
    return x + np.arange(n)
```

But this one would materialize the iota array at trace time and stage it into the computation as a potentially large array constant:

```python
@jit
def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]
```

The difference is that previously operations like broadcasts, transposes, and reshapes that add singleton dimensions (as above) would force otherwise lazy values to be materialized, while after this commit broadcasts, transposes, and reshapes are all lazy operations that only update metadata on their input rather than compiling and executing XLA computations and producing new buffers. Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See #1668 for more.
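With constants made lazy, a zero-argument jitted function no longer materializes its result at trace time, so the requested backend is respected. A minimal sketch of the post-fix behavior (illustrative only; it assumes a JAX version that includes the commit above and uses the same old-style device_buffer API as the rest of the thread):

```python
from jax import jit
import jax.numpy as np  # the thread's alias for jax.numpy

def f():
    # A lazy constant: no buffer is created at trace time.
    return np.arange(6).reshape(2, 3)

print(jit(f, backend='cpu')().device_buffer.device())  # --> cpu, as requested
```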
Example (from the original report):
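(The exact snippet is assumed here: it is the workaround code above minus the placeholder argument, with random.normal and PRNGKey(1) as representative constants.)

```python
from jax.api import jit
import jax.random as random

def x():
    # No arguments, so nothing gets staged into the jitted program;
    # the computation runs op-by-op at trace time on the default (GPU) device.
    return random.normal(random.PRNGKey(1), (2, 3))

print(x().device_buffer.device())                      # GpuDevice(id=0)
print(jit(x)().device_buffer.device())                 # GpuDevice(id=0)
print(jit(x, backend='cpu')().device_buffer.device())  # GpuDevice(id=0) -- should be cpu
```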
All of these output GpuDevice(id=0), while the last one should be on the CPU.