
Bound methods of the same Module have same hash but gets recompiled when jitted #268

Closed
KeAWang opened this issue Feb 27, 2023 · 5 comments


@KeAWang

KeAWang commented Feb 27, 2023

jax.jit treats two functions with the same hash as the same function, allowing them to share the compilation cache. For example,

import jax

# Let's define a closure
y = 1.
def foo(x):
    print("Tracing!")
    return y * x

# Let's jit the same function, which of course will have the same hash, i.e. `hash(foo) == hash(foo)`
jitf1 = jax.jit(foo)
jitf2 = jax.jit(foo)
jitf1(0.)  # this will print "Tracing!"
jitf1(0.)  # this won't
jitf2(0.)  # this won't either

With equinox modules, bound methods share the same hash if their instances belong to the same Module type:

import equinox as eqx
import jax
import jax.numpy as jnp

class MyModule(eqx.Module):
    y: jnp.ndarray
    def foo(self, x):
        print("Tracing!")
        return self.y * x

m1 = MyModule(y=1.)
m2 = MyModule(y=2.)
assert hash(m1) != hash(m2)   # `eqx.Module` implements its own `__hash__`, so two instances hash equally when their PyTree leaves are equal.
assert hash(m1.foo) == hash(m2.foo)  # however, methods bound to different instances have the same hash; not sure how, maybe from the `Partial` shimming?

However, when we jit the bound methods, they are traced and compiled separately despite having the same hash.

eqx_jitf1 = jax.jit(m1.foo)
eqx_jitf2 = jax.jit(m2.foo)
eqx_jitf1(0.)  # this will print "Tracing!"
eqx_jitf1(0.)  # this won't print
eqx_jitf2(0.)  # this will print "Tracing!"
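The same recompilation can be reproduced without Equinox. The sketch below assumes that a bound method behaves like `functools.partial(MyModule.foo, m)` (the equivalence suggested later in this thread); `Holder` and `trace_log` are hypothetical names, and a trace log replaces the `print` so the number of traces can be checked.

```python
import functools
from dataclasses import dataclass

import jax

trace_log = []  # one entry per trace; Python side effects only run while tracing


@dataclass(frozen=True)
class Holder:
    # Hypothetical stand-in for MyModule; this is NOT an Equinox class.
    y: float

    def foo(self, x):
        trace_log.append(self.y)
        return self.y * x


h1 = Holder(y=1.0)
h2 = Holder(y=2.0)

# Each partial is a distinct callable closing over a different instance.
jit1 = jax.jit(functools.partial(Holder.foo, h1))
jit2 = jax.jit(functools.partial(Holder.foo, h2))

jit1(0.0)  # traces
jit1(0.0)  # cache hit: same jitted callable, same shape/dtype
jit2(0.0)  # traces again: a different callable
```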

Is there any way to match the behavior of the two setups? That way, we wouldn't need to recompile MyModule.foo when only the PyTree contents change.

@KeAWang
Author

KeAWang commented Feb 27, 2023

On second thought, I think this is a case where we're jitting two different closures. So it's expected that we'll trace twice and compile twice, since m1.foo is equivalent to a closure like partial(MyModule.foo, m1).
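Following that reasoning, one way to avoid the retrace is to stop closing over the instance and pass it as a traced argument instead, so that only the shapes and dtypes of its leaves enter the cache key. A minimal sketch, assuming a hypothetical `Params` NamedTuple (which JAX already treats as a PyTree) in place of an Equinox module:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

trace_log = []  # one entry per trace of `foo`


class Params(NamedTuple):
    # Hypothetical container; NamedTuple fields become traced PyTree leaves.
    y: jax.Array


@jax.jit
def foo(params, x):
    trace_log.append("traced")
    return params.y * x


p1 = Params(y=jnp.asarray(1.0))
p2 = Params(y=jnp.asarray(2.0))

foo(p1, 0.0)  # traces once
foo(p2, 0.0)  # same leaf shapes/dtypes, so the compiled function is reused
```

Here the instance is data rather than part of the function, so updating its contents does not change what JAX keys the cache on.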

On the other hand, I'm not sure why m1.foo and m2.foo have the same hash.

Closing for now, since I don't think there's actually an issue! Thanks!

@KeAWang KeAWang closed this as completed Feb 27, 2023
@patrick-kidger
Owner

Right, your explanation is correct. I think the cache entries are keyed by object identity (id(...)) of the wrapped function, rather than by __hash__ and __eq__.

I think that's arguably a bug with jax.jit, given that the cache entries are keyed by __hash__ and __eq__ for the (static) arguments that are passed. You might like to raise an issue on the JAX issue tracker.

As a fix: you should find that eqx.filter_jit is smart enough to handle this case, and that it doesn't recompile.

@KeAWang
Author

KeAWang commented Feb 28, 2023

If you have the time, would you mind expanding a bit on the caching differences? Which entries are cached by id and which entries are cached by __hash__ and __eq__? That way I can understand the situation better and file a better report upstream.

Also how did you manage to get filter_jit to handle this case when jax.jit doesn't?

@patrick-kidger
Owner

So traced arguments are keyed by .shape and .dtype, whilst static arguments are keyed by __hash__ and __eq__.
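Both keying rules can be observed with a trace log. A minimal sketch; `scale` and `trace_log` are hypothetical names, and `factor` is marked static via `static_argnums`:

```python
import jax
import jax.numpy as jnp

trace_log = []  # records (shape, dtype, factor) each time `scale` is traced


def scale(x, factor):
    trace_log.append((x.shape, x.dtype, factor))
    return x * factor


# `factor` is static: it must be hashable and is compared via __hash__/__eq__.
scale_jit = jax.jit(scale, static_argnums=1)

scale_jit(jnp.ones(3), 2)   # trace: first call
scale_jit(jnp.zeros(3), 2)  # cache hit: same shape/dtype, equal static value
scale_jit(jnp.ones(4), 2)   # trace: traced argument has a new shape
scale_jit(jnp.ones(3), 3)   # trace: static argument differs
```

Note that the values of the traced array (`ones` vs `zeros`) never trigger a retrace; only its shape and dtype do.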

I'm actually not certain what the wrapped function is keyed as. I'm speculating that it's id(...). I do think that it would be reasonable for this to be keyed by __hash__ and __eq__, though.

As for filter_jit: it's possible to make the function keyed by __hash__ and __eq__ by doing something like:

import jax

def call(fn, args, kwargs):
    return fn(*args, **kwargs)

# `fn` is a static argument, so the cache is keyed by its __hash__ and __eq__.
call = jax.jit(call, static_argnums=0)

def keawang_jit(fn):
    def wrapped_fn(*args, **kwargs):
        return call(fn, args, kwargs)
    return wrapped_fn

And this is something that filter_jit does under the hood (alongside its other improvements to jax.jit).
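To see the effect of the helper above, here is a self-contained variant (renamed to the hypothetical `hash_keyed_jit`) applied to a hypothetical hashable callable `Scale`. Two distinct but equal `Scale` objects share one compiled entry, while an unequal one triggers a new trace. This is a sketch of the static-argument trick only, not of everything `filter_jit` does; in particular, a callable only benefits if it also compares equal, not merely hashes equal.

```python
from dataclasses import dataclass

import jax

trace_log = []  # one entry per trace of Scale.__call__


def _call(fn, args, kwargs):
    return fn(*args, **kwargs)


# `fn` is static, so cache lookups use fn.__hash__ and fn.__eq__.
_call_jit = jax.jit(_call, static_argnums=0)


def hash_keyed_jit(fn):
    def wrapped_fn(*args, **kwargs):
        return _call_jit(fn, args, kwargs)
    return wrapped_fn


@dataclass(frozen=True)
class Scale:
    # Hypothetical hashable callable: equal fields give equal hashes and ==.
    factor: float

    def __call__(self, x):
        trace_log.append(self.factor)
        return self.factor * x


f1 = hash_keyed_jit(Scale(2.0))
f2 = hash_keyed_jit(Scale(2.0))  # different object, but equal to the one above
f3 = hash_keyed_jit(Scale(3.0))

f1(1.0)  # traces
f2(1.0)  # cache hit: Scale(2.0) == Scale(2.0)
f3(1.0)  # traces: a different static function
```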

@KeAWang
Author

KeAWang commented Feb 28, 2023

Got it thank you!
