Bound methods of the same Module have same hash but gets recompiled when jitted #268
Comments
On second thought, I think this is a case where we're jitting two different closures. So it's expected that we'll trace twice and compile twice, since On the other hand, I'm not sure why Closing for now, since I don't think there's actually an issue! Thanks!
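To illustrate the "two different closures" point, here is a minimal sketch in plain JAX. The helper `make_scaler` and the `trace_log` list are hypothetical names introduced for illustration; the list is appended to only while JAX traces the Python body, so its length counts traces.

```python
import jax
import jax.numpy as jnp

trace_log = []  # appended to only while JAX traces the Python body


def make_scaler(scale):
    # Hypothetical helper: every call returns a brand-new closure object.
    def scaler(x):
        trace_log.append(scale)
        return x * scale
    return scaler


x = jnp.arange(3.0)

jitted_a = jax.jit(make_scaler(2.0))
jitted_b = jax.jit(make_scaler(2.0))

jitted_a(x)
jitted_a(x)  # same jitted closure, same input shape: no new trace
jitted_b(x)  # a different closure object: traced again

print(len(trace_log))
```

Even though both closures compute the same thing, they are distinct Python objects, so each gets its own trace.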
Right, your explanation is correct. I think the cache entries are keyed by object identity ( I think that's arguably a bug with As a fix: you should find that
If you have the time, would you mind expanding a bit on the caching differences? Which entries are cached by Also how did you manage to get
So traced arguments are keyed by I'm actually not certain what the wrapped function is keyed as. I'm speculating that it's As for

    import jax

    def call(fn, args, kwargs):
        # fn is static (static_argnums=0), so it must be hashable and becomes
        # part of the jit cache key; args and kwargs are traced as pytrees.
        return fn(*args, **kwargs)

    call = jax.jit(call, static_argnums=0)

    def keawang_jit(fn):
        def wrapped_fn(*args, **kwargs):
            return call(fn, args, kwargs)
        return wrapped_fn

And this is something that
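As a rough usage sketch of the wrapper above: since `fn` is passed as a static argument, two callables that are distinct objects but equal by `__hash__` and `__eq__` should share one trace. The `Scale` class and `trace_log` list are hypothetical, introduced only for this illustration; `call` and `keawang_jit` are repeated so the block runs on its own.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp

trace_log = []  # appended to only while JAX traces the Python body


@dataclass(frozen=True)  # frozen with the default eq=True derives __hash__ and __eq__ from the fields
class Scale:
    factor: float

    def __call__(self, x):
        trace_log.append(self.factor)
        return x * self.factor


def call(fn, args, kwargs):
    return fn(*args, **kwargs)


call = jax.jit(call, static_argnums=0)


def keawang_jit(fn):
    def wrapped_fn(*args, **kwargs):
        return call(fn, args, kwargs)
    return wrapped_fn


x = jnp.arange(3.0)

first = keawang_jit(Scale(2.0))
second = keawang_jit(Scale(2.0))  # a distinct object that compares equal to the first
third = keawang_jit(Scale(3.0))   # compares unequal

first(x)
second(x)  # equal static argument: reuses the trace made for `first`
third(x)   # unequal static argument: traced again
```

One caveat with this design: anything that influences the result but is left out of `__hash__` and `__eq__` would be silently ignored on a cache hit, so the equality definition needs to cover everything the function depends on.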
Got it, thank you!
jax.jit treats two functions with the same hash as the same function, allowing them to share the compilation cache. For example,

With equinox modules, bound functions share the same hash if they belong to the same Module type:

However, when we jit the bound methods, they will recompile despite having the same hash.
Is there any way to match the behavior of the two setups? This way we don't need to recompile Module.foo if only the PyTree contents get updated.
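One general way to avoid recompiling when only the PyTree contents change is to pass the PyTree as an ordinary traced argument instead of closing over it. The sketch below uses plain JAX with a dict standing in for a Module; the names `apply_foo`, `params`, and `trace_log` are hypothetical, and this is not necessarily the fix suggested in the thread.

```python
import jax
import jax.numpy as jnp

trace_log = []  # appended to only while JAX traces the Python body


@jax.jit
def apply_foo(params, x):
    # `params` is a regular pytree argument, so its leaves are traced
    # rather than baked into the compiled function.
    trace_log.append("traced")
    return params["weight"] * x + params["bias"]


x = jnp.ones(3)

params = {"weight": jnp.asarray(2.0), "bias": jnp.asarray(1.0)}
out1 = apply_foo(params, x)

# Same tree structure, shapes, and dtypes; only the leaf values differ.
new_params = {"weight": jnp.asarray(5.0), "bias": jnp.asarray(-1.0)}
out2 = apply_foo(new_params, x)

print(len(trace_log))
```

Because the cache key depends on the tree structure and the shapes and dtypes of the leaves, not on their values, the second call reuses the first trace.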