
jit does not respect backend type when the function has no arguments #1431

Closed
romanngg opened this issue Oct 2, 2019 · 2 comments · Fixed by #1668
romanngg (Contributor) commented Oct 2, 2019

Example:

from jax.api import jit
import jax.random as random

def x():
  return random.normal(random.PRNGKey(1), (2, 3))

print(x().device_buffer.device())
print(jit(x)().device_buffer.device())
print(jit(x, backend='cpu')().device_buffer.device())

All of these print GpuDevice(id=0), while the last one should be a CPU device.

jekbradbury (Contributor) commented:
Functions without arguments can't meaningfully be jitted, since jit relies on data dependence to trace primitives inside the function (it injects tracer values for each function argument, and follows them through the function). That is, the return value of x() is actually a constant that was computed op-by-op (on the default device) when jit attempted to trace x(), and not the result of a jitted XLA computation at all.
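As an illustration of the data-dependence point (a hypothetical variant of the example above, not taken from this thread): if the PRNG key is passed as an argument, jit has a traced input to follow, and the backend argument should then be honored.

from jax.api import jit
import jax.random as random

def x_with_key(key):
  # The output now depends on a traced argument, so jit stages the whole
  # computation rather than seeing only a trace-time constant.
  return random.normal(key, (2, 3))

print(jit(x_with_key, backend='cpu')(random.PRNGKey(1)).device_buffer.device())  # expected: a CPU device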

We don't yet have a way to specify the device used for op-by-op computations; it might make sense for such a mechanism to be dynamically scoped (i.e., a context manager) since we can't use data dependence for scoping the way we do for jit.
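Later JAX releases did add a dynamically scoped mechanism of this kind, jax.default_device, which can be used as a context manager to choose the device for op-by-op computations. A minimal sketch, assuming a JAX version recent enough to have it (it did not exist when this issue was filed):

import jax
import jax.numpy as jnp

cpu = jax.devices('cpu')[0]
with jax.default_device(cpu):
  # Op-by-op (non-jitted) computations in this scope are placed on the CPU device.
  y = jnp.ones((2, 3)) * 2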

levskaya (Collaborator) commented Oct 5, 2019

If for some reason you really need to use this idiom, you can force data dependence on constants (sometimes handy to prevent them from being folded out) by using lax.tie_in:

from jax.api import jit
import jax.random as random
from jax import lax

dummy = 0

def x(_placeholder):
  return lax.tie_in(_placeholder, random.normal(random.PRNGKey(1), (2, 3))) 

print(x(dummy).device_buffer.device())  # --> gpu
print(jit(x)(dummy).device_buffer.device())  # --> gpu
print(jit(x, backend='cpu')(dummy).device_buffer.device())  # --> cpu

mattjj self-assigned this Jan 7, 2020
mattjj added a commit that referenced this issue Jan 8, 2020
Before this commit, this computation would avoid materializing the iota
array at trace time:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(n)

But this one would materialize the iota array at trace time and stage it
into the computation as a potentially large array constant:

  @jit
  def f(x):
    m, n = x.shape
    return x + np.arange(m)[:, None]

The difference is that previously operations like broadcasts,
transposes, and reshapes that add singleton dimensions (as above) would
force otherwise lazy values to be materialized, while after this commit
broadcasts, transposes, and reshapes are all lazy operations that only
update metadata on their input rather than compiling and executing XLA
computations and producing new buffers.

Also, np.eye and np.tri become lazy (in addition to np.zeros, np.ones, np.full).

This commit replaces the ad-hoc "lazy device constant" system, which was
used to get the simpler behavior in the first example above.

Incidentally fixes #1431

See #1668 for more.
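As a sketch of what "Incidentally fixes #1431" means for the original report (assuming the lazy-constant behavior described above), the zero-argument example should now respect the requested backend:

from jax.api import jit
import jax.random as random

def x():
  return random.normal(random.PRNGKey(1), (2, 3))

# Per the "Incidentally fixes #1431" note above, this should now return an
# array backed by a CPU device rather than the default GPU device.
print(jit(x, backend='cpu')().device_buffer.device())  # expected: a CPU device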
mattjj mentioned this issue Jan 8, 2020