# Unnecessary re-compiles for jax.lax.scan #14743
## Comments
An identical hash is not sufficient for two objects to be considered equivalent; you also have to override `__eq__`:

```python
import jax
import jax.numpy as jnp

class Identity:
    def __call__(self, *x):
        return x

    def __hash__(self):
        return 0

    def __eq__(self, other):
        return isinstance(other, Identity)

jax.config.update("jax_log_compiles", True)

s1 = Identity()
s2 = Identity()
assert hash(s1) == hash(s2)
assert s1 == s2
_ = jax.lax.scan(s1, None, jnp.arange(3))
_ = jax.lax.scan(s2, None, jnp.arange(3))  # no re-compile logged
```
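For contrast, a minimal sketch of the failure mode (the `IdentityNoEq` class is a hypothetical addition, not from the original comment): with an identical `__hash__` but the default identity-based `__eq__`, the two instances are not considered equivalent, so the second scan is expected to re-compile.

```python
class IdentityNoEq:
    def __call__(self, *x):
        return x

    def __hash__(self):
        return 0  # identical hash for every instance

t1, t2 = IdentityNoEq(), IdentityNoEq()
assert hash(t1) == hash(t2)
assert t1 != t2  # default __eq__ compares object identity
_ = jax.lax.scan(t1, None, jnp.arange(3))
_ = jax.lax.scan(t2, None, jnp.arange(3))  # expected to log a second compilation
```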
---

I see. Thanks for the quick reply!
---

After some more thought, I am again drawn to the conclusion that the re-compiles in the following example are unnecessary:

```python
import jax
import jax.numpy as jnp

def add(_, a1, a2):
    return _, jnp.add(a1, a2)

jax.config.update("jax_log_compiles", True)

s1 = jax.tree_util.Partial(add, a2=jnp.arange(3, dtype=float))
s2 = jax.tree_util.Partial(add, a2=3. + jnp.arange(3, dtype=float))
_ = jax.lax.scan(s1, None, jnp.arange(3, dtype=float))
_ = jax.lax.scan(s2, None, jnp.arange(3, dtype=float))  # logs a second compilation
```

P.S. As a quick fix I am using something along the lines of `scan = jax.jit(jax.lax.scan, static_argnames=("length", "reverse", "unroll"))` to work around this.
---

#13071 is related but not the same.
---

Maybe related to the fact that `jit` can't tell that the two wrapped callables share the same underlying function:

```python
import jax
from jax.tree_util import Partial

def f(x, y):
    print("Tracing f")
    return x + y

f1 = jax.jit(Partial(f, 1.))
f2 = jax.jit(Partial(f, 2.))
print(f1._cache_size())  # 0
print(f2._cache_size())  # 0

f1(0.)  # this will print "Tracing f"
print(f1._cache_size())  # 1
print(f2._cache_size())  # 0

f1(0.)  # this won't print
print(f1._cache_size())  # 1
print(f2._cache_size())  # 0

f2(0.)  # this will print "Tracing f" again
print(f1._cache_size())  # 1
print(f2._cache_size())  # 1
```

This is also why we had this issue: patrick-kidger/equinox#268
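A hedged illustration of the same point, reusing `f` from the snippet above: if the shared function is jitted once and the bound value is passed as an ordinary traced argument, both call patterns hit a single cache.

```python
f_shared = jax.jit(f)

f_shared(1., 0.)  # prints "Tracing f"
f_shared(2., 0.)  # no print: same argument shapes/dtypes reuse the cached trace
```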
---

I actually wrote a more general `Partial` that can also treat the wrapped callable as dynamic (see the example further down).
---

I think this is at the core of what is leading to the re-compiles here, and it also explains why

```python
scan = jax.jit(jax.lax.scan, static_argnames=("length", "reverse", "unroll"))
_ = scan(s1, None, jnp.arange(3, dtype=float))
_ = scan(s2, None, jnp.arange(3, dtype=float))
```

triggers no re-compiles. In my humble opinion this behavior is weird. I think

```python
from functools import partial

def any_function(f, *x):
    return f(*x)

def better_jit(f, *args, **kwargs):
    return partial(jax.jit(any_function, *args, **kwargs), f)

f1 = better_jit(Partial(f, 1.))
f2 = better_jit(Partial(f, 2.))
```

behaves more sanely. It traces `f` only once.
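A hedged check of that claim, continuing from the snippet above: `Partial`'s bound arguments are pytree leaves rather than part of the cache key, so per the comment's claim the second call should reuse the first trace.

```python
f1(0.)  # prints "Tracing f"
f2(0.)  # no print expected: same pytree structure and leaf dtypes as f1's call
```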
---

I think you can do something very similar using a mixture of …
---

Only for the arguments (so it doesn't solve this problem), and it's harder to read IMO.
---

@NeilGirdhar I'm trying to understand what a dynamic callable would look like. Would you happen to have an example where you would call your `Partial` with a non-static callable?
---

@KeAWang Something like this:

```python
from collections.abc import Callable

import jax.numpy as jnp
from jax import Array, jit
from tjax import Partial
from tjax.dataclasses import dataclass

@dataclass
class X:
    x: Array

    def __call__(self, y: Array, z: Array) -> Array:
        return self.x + y + z

x = X(jnp.ones(2))
f = Partial(x, jnp.zeros(2), callable_is_static=False)

@jit
def g(f_: Callable[[Array], Array]) -> Array:
    return f_(jnp.zeros(2))

print(g(f))
```
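A hedged follow-up to the snippet above (an extension, not part of the original comment): because the callable is marked dynamic, its `Array` field is a traced pytree leaf, so a second `X` with different values but the same structure should reuse `g`'s cached trace.

```python
x2 = X(2. * jnp.ones(2))
f2 = Partial(x2, jnp.zeros(2), callable_is_static=False)
print(g(f2))  # same shapes/dtypes: expected to hit the existing cache entry
```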
---

Great example, thank you :)
---

## Description

For a memory-constrained model, I was forced to use `jax.lax.scan` instead of `jax.vmap`, and I noticed unnecessary re-compiles when passing callables to `jax.lax.scan`, even when the callables hash to the same value. `jax.vmap` does not trigger any re-compiles.

**What jax/jaxlib version are you using?**

jax==0.4.2, jaxlib==0.4.1

**Which accelerator(s) are you using?**

CPU

**Additional system info**

python==3.10.9, Linux

**NVIDIA GPU info**

No response