Unnecessary re-compiles for jax.lax.scan #14743

Open
Edenhofer opened this issue Mar 1, 2023 · 14 comments
Open

Unnecessary re-compiles for jax.lax.scan #14743

Edenhofer opened this issue Mar 1, 2023 · 14 comments
Labels
bug Something isn't working

Comments

@Edenhofer (Contributor)

Description

For a memory-constrained model I was forced to use jax.lax.scan instead of jax.vmap, and I noticed unnecessary re-compiles when passing callables to jax.lax.scan, even when the callables hash to the same value. jax.vmap does not trigger any re-compiles.

import jax
import jax.numpy as jnp


class Identity:
    def __call__(self, *x):
        return x

    def __hash__(self):
        return 0


jax.config.update("jax_log_compiles", True)

s1 = Identity()
s2 = Identity()
assert hash(s1) == hash(s2)

_ = jax.lax.scan(s1, None, jnp.arange(3))
_ = jax.lax.scan(s2, None, jnp.arange(3))

What jax/jaxlib version are you using?

jax==0.4.2, jaxlib==0.4.1

Which accelerator(s) are you using?

CPU

Additional system info

python==3.10.9, Linux

NVIDIA GPU info

No response

Edenhofer added the bug label Mar 1, 2023
@jakevdp (Collaborator)

jakevdp commented Mar 2, 2023

An identical hash is not sufficient for two objects to be considered equivalent; you also have to override __eq__:

import jax
import jax.numpy as jnp

class Identity:
    def __call__(self, *x):
        return x

    def __hash__(self):
        return 0

    def __eq__(self, other):
        return isinstance(other, Identity)


jax.config.update("jax_log_compiles", True)

s1 = Identity()
s2 = Identity()
assert hash(s1) == hash(s2)
assert s1 == s2

_ = jax.lax.scan(s1, None, jnp.arange(3))
_ = jax.lax.scan(s2, None, jnp.arange(3))

@Edenhofer (Contributor, Author)

I see. Thanks for the quick reply!

@Edenhofer (Contributor, Author)

After some more thought, I am again drawn to the conclusion that the re-compiles in jax.lax.scan are odd. Take the following code block as an example:

import jax
import jax.numpy as jnp


def add(_, a1, a2):
    return _, jnp.add(a1, a2)


jax.config.update("jax_log_compiles", True)

s1 = jax.tree_util.Partial(add, a2=jnp.arange(3, dtype=float))
s2 = jax.tree_util.Partial(add, a2=3. + jnp.arange(3, dtype=float))

_ = jax.lax.scan(s1, None, jnp.arange(3, dtype=float))
_ = jax.lax.scan(s2, None, jnp.arange(3, dtype=float))
Finished tracing + transforming jit(iota) in 0.0002722740173339844 sec
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Compiling prim_fun (139638274830992) for with global shapes and types (). Argument mapping: ().
Finished XLA compilation of jit(iota) in 0.010893583297729492 sec
Finished tracing + transforming jit(scan) in 0.0002956390380859375 sec
Compiling prim_fun (139638275152592) for with global shapes and types (ShapedArray(float32[3]), ShapedArray(float32[3])). Argument mapping: (OpShardingSharding({replicated}), OpShardingSharding({replicated})).
Finished XLA compilation of jit(scan) in 0.015352725982666016 sec
Finished tracing + transforming jit(scan) in 0.0002682209014892578 sec
Compiling prim_fun (139638275158032) for with global shapes and types (ShapedArray(float32[3]), ShapedArray(float32[3])). Argument mapping: (OpShardingSharding({replicated}), OpShardingSharding({replicated})).
Finished XLA compilation of jit(scan) in 0.015791654586791992 sec

jax.lax.scan recompiles for every new input (as long as the inputs aren't identical), while jax.jit (e.g. jax.jit(s2)(None, jnp.arange(3, dtype=float))) compiles add only once. I think the latter makes more sense in general, but I do see the problem: if the function to scan were made a static argument, Partial would stop working because Partial(s1) != Partial(s2).

P.S. As a quick fix, I am using something along the lines of

scan = jax.jit(jax.lax.scan, static_argnames=("length", "reverse", "unroll"))

to work around this.

Edenhofer reopened this Mar 6, 2023
@patrick-kidger (Collaborator)

Just had a quick skim. Not sure if they're related, but maybe cf. also #13554 and #13071.

@Edenhofer (Contributor, Author)

#13071 is related but not the same. jax.lax.scan does not hit the cache even for a single Partial.
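
For instance, here is a minimal sketch reusing add from my earlier comment; per the behavior reported above, both calls log a fresh compilation even though s1 is the very same Partial instance:

import jax
import jax.numpy as jnp
from jax.tree_util import Partial

def add(_, a1, a2):
    return _, jnp.add(a1, a2)

jax.config.update("jax_log_compiles", True)

s1 = Partial(add, a2=jnp.arange(3, dtype=float))

# Same Partial object both times, yet each call re-traces and
# re-compiles the scan body instead of hitting the cache.
_ = jax.lax.scan(s1, None, jnp.arange(3, dtype=float))
_ = jax.lax.scan(s1, None, jnp.arange(3, dtype=float))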

@KeAWang

KeAWang commented Mar 10, 2023

Maybe this is related to the fact that jit can't tell that the Partials are wrapping the same function and closing over the same types of variables. For example:

import jax
from jax.tree_util import Partial

def f(x, y):
    print("Tracing f")
    return x + y

f1 = jax.jit(Partial(f, 1.))
f2 = jax.jit(Partial(f, 2.))

print(f1._cache_size())  # 0
print(f2._cache_size())  # 0
f1(0.)  # this will print "Tracing f"
print(f1._cache_size())  # 1
print(f2._cache_size())  # 0
f1(0.)  # this won't print
print(f1._cache_size())  # 1
print(f2._cache_size())  # 0
f2(0.)  # this will print "Tracing f"
print(f1._cache_size())  # 1
print(f2._cache_size())  # 1

This is also why we had this issue: patrick-kidger/equinox#268

@NeilGirdhar (Contributor)

I actually wrote a more general Partial to solve this here. It allows you to have the callable be dynamic along with any arguments you choose to be static or dynamic.

@Edenhofer (Contributor, Author)

Edenhofer commented Mar 10, 2023

Maybe this is related to the fact that jit can't tell that the Partials are wrapping the same function and closing over the same types of variables.

I think this is at the core of what is leading to the recompiles here, and it also explains why

scan = jax.jit(jax.lax.scan, static_argnames=("length", "reverse", "unroll"))
_ = scan(s1, None, jnp.arange(3, dtype=float))
_ = scan(s2, None, jnp.arange(3, dtype=float))

triggers no re-compiles. In my humble opinion this behavior is weird. I think

import jax
from functools import partial
from jax.tree_util import Partial

def f(x, y):  # same f as in the comment above
    print("Tracing f")
    return x + y

def any_function(f, *x):
    return f(*x)

def better_jit(f, *args, **kwargs):
    return partial(jax.jit(any_function, *args, **kwargs), f)

f1 = better_jit(Partial(f, 1.))
f2 = better_jit(Partial(f, 2.))

behaves more sanely: it traces f exactly once.
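
As a quick check of that claim (reusing the definitions above; the expected prints follow from f being traced exactly once):

f1(0.)  # prints "Tracing f": any_function is traced, with the bound value as a tracer
f2(0.)  # prints nothing: the cached trace is reused, only the bound pytree leaf differs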

@Edenhofer (Contributor, Author)

I actually wrote a more general Partial to solve this here. It allows you to have the callable be dynamic along with any arguments you choose to be static or dynamic.

I think you can do something very similar using a mixture of functools.partial and jax.tree_util.Partial.
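
Roughly, I mean something like the following sketch; the function f and the static/dynamic argument split are invented purely for illustration:

import functools

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def f(mode, x, y):
    return x + y if mode == "add" else x * y

# functools.partial bakes the static value into the callable, which
# tree_util.Partial stores as (hashable) treedef metadata; the array
# argument remains a pytree leaf and hence stays dynamic under jit.
g = Partial(functools.partial(f, "add"), jnp.arange(3.0))

@jax.jit
def apply(fn, x):
    return fn(x)

print(apply(g, jnp.ones(3)))  # [1. 2. 3.]

One caveat: functools.partial compares by identity, so the inner partial must be constructed once and reused for the jit cache to hit.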

@NeilGirdhar (Contributor)

I think you can do something very similar using a mixture of functools.partial and jax.tree_util.Partial.

Only for the arguments (so it doesn't solve this problem), and it's harder to read IMO.

@KeAWang

KeAWang commented Mar 15, 2023

@NeilGirdhar I'm trying to understand what a dynamic callable would look like. Would you happen to have an example where you would call your Partial with callable_is_static=False?

@NeilGirdhar (Contributor)

NeilGirdhar commented Mar 15, 2023

@KeAWang Something like this:

from collections.abc import Callable

from jax import Array, jit
import jax.numpy as jnp
from tjax import Partial
from tjax.dataclasses import dataclass


@dataclass
class X:
  x: Array
  def __call__(self, y: Array, z: Array) -> Array:
    return self.x + y + z

x = X(jnp.ones(2))
f = Partial(x, jnp.zeros(2), callable_is_static=False)

@jit
def g(f_: Callable[[Array], Array]) -> Array:
  return f_(jnp.zeros(2))

print(g(f))

@KeAWang

KeAWang commented Mar 15, 2023

Great example, thank you :)

@NeilGirdhar (Contributor)

NeilGirdhar commented Mar 15, 2023

@KeAWang YVW. By the way, it occurred to me that a simpler and more common example would be the backward pass returned by a jax.vjp. This callable is a pytree thanks to #3705. Another example is an object of type Partial itself.
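
For instance, a minimal sketch of the jax.vjp case:

import jax
import jax.numpy as jnp

def loss(w):
  return jnp.sum(w ** 2)

_, pullback = jax.vjp(loss, jnp.arange(3.0))

# The pullback returned by jax.vjp is itself a pytree (see #3705),
# so it can be passed into jit as an ordinary, dynamic argument.
@jax.jit
def apply_pullback(vjp_fn, cotangent):
  return vjp_fn(cotangent)

print(apply_pullback(pullback, 1.0))  # (Array([0., 2., 4.], dtype=float32),)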
