Double `jtu.Partial` results in cache miss #13071
Comments
Hey @patrick-kidger, because you are passing these as static and `hash(h1) != hash(h2)`, I'd say this behaviour is "expected".
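To illustrate the point about static arguments, here is a toy model using only the standard library. `functools.lru_cache` stands in for jit's compilation cache; this is an analogy, not how JAX implements its cache. Like the objects discussed here, `functools.partial` instances do not define `__eq__` or `__hash__`, so two partials wrapping the same function are distinct cache keys. `fake_compile` and `compile_count` are hypothetical names.

```python
import functools

compile_count = 0


@functools.lru_cache(maxsize=None)
def fake_compile(static_arg):
    # Hypothetical stand-in for a jit compile step: the body runs once per
    # distinct cache key, where keys are compared with __hash__/__eq__.
    global compile_count
    compile_count += 1
    return "compiled"


def f(x):
    return x


h1 = functools.partial(f)
h2 = functools.partial(f)

fake_compile(h1)  # miss: first time this key is seen
fake_compile(h2)  # miss: partial objects compare by identity, so h2 != h1
fake_compile(h1)  # hit: the very same object is reused

print(h1 == h2, compile_count)  # False 2
```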
I don't think so. The single-level partial works just fine, despite the fact that each `Partial` object has a different object id. I think it is the special handling here that's misbehaving:
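A toy model of the distinction being drawn here, again using only the standard library. It assumes, as with `jtu.Partial`, that the wrapped function is stored as static structure while the bound arguments are children; `ToyPartial` and `structure` are hypothetical names, not JAX APIs. With a single level, the static part is the plain function `f`, which is the same object on both sides. With two levels, the static part is the *inner* partial, which compares by identity.

```python
def f(x):
    return x


class ToyPartial:
    """Hypothetical stand-in for jtu.Partial; defines no __eq__/__hash__."""

    def __init__(self, func, *args):
        self.func = func
        self.args = args


def structure(obj):
    """Toy 'treedef': static func plus the structure of the bound args."""
    if isinstance(obj, ToyPartial):
        # func is treated as static metadata and is NOT recursed into.
        return ("ToyPartial", obj.func, tuple(structure(a) for a in obj.args))
    return "*"  # any other object is a leaf


g1, g2 = ToyPartial(f), ToyPartial(f)
h1, h2 = ToyPartial(ToyPartial(f)), ToyPartial(ToyPartial(f))

print(structure(g1) == structure(g2))  # True: both hold the same f
print(structure(h1) == structure(h2))  # False: the inner ToyPartials differ
```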
I agree with @cgarciae actually. This doesn't seem like a bug? I tried your single-level partial case and it seems to behave as expected:

```python
import jax
import jax.tree_util as jtu

def f(x):
    return

g1 = jtu.Partial(f)
g2 = jtu.Partial(f)
assert hash(g1) != hash(g2)  # distinct objects, so identity-based hashes differ

def check(x):
    return

with jax.log_compiles(True):
    check1 = jax.jit(check, static_argnums=0)
    check1(g1)  # compiles since it's the first call
    check1(g2)  # compiles since hash(g1) != hash(g2)

with jax.log_compiles(True):
    check2 = jax.jit(check)
    check2(g1)  # tracing!
    check2(g2)  # no tracing as expected since now g1 and g2 are treated as dynamic
```
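An alternative way to watch for retracing, without reading compile logs, is a Python side effect inside the jitted function: it only runs while JAX traces, so a counter records cache misses. This sketch assumes single-level `jtu.Partial` arguments passed dynamically, as in `check2` above; `apply` and `trace_count` are illustrative names.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

trace_count = 0


@jax.jit
def apply(g, x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return g(x)


def f(x):
    return x + 1


g1 = jtu.Partial(f)
g2 = jtu.Partial(f)

apply(g1, jnp.float32(1.0))  # traces: first call
apply(g2, jnp.float32(1.0))  # cache hit: same treedef (same f), same shape/dtype
print(trace_count)  # 1
```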
Right; I've since recalled that the function-valued argument to … I think there should simply be a check to guard against this error, e.g. raise an error if …

And FWIW, for this reason I no longer use … in my own code, and instead use:

```python
from typing import Callable

import equinox as eqx


class Partial(eqx.Module):
    func: Callable
    args: tuple
    kwargs: dict

    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        # Bound positional args come first; bound and call-time kwargs are merged.
        return self.func(*self.args, *args, **kwargs, **self.kwargs)
```
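Why a version like this avoids the problem, sketched as a toy model (stdlib only, hypothetical names): when `func` is itself a pytree child, a nested partial is flattened recursively, so the only static data left at the bottom is the plain function `f`. This mirrors the idea behind the `eqx.Module` version, under the assumption that the outer partial's `func` field is treated as a child rather than as static metadata; it does not reproduce Equinox's actual flattening or filtering rules.

```python
def f(x):
    return x


class ToyPartial:
    """Hypothetical partial with no __eq__/__hash__, like the class above."""

    def __init__(self, func, *args):
        self.func = func
        self.args = args


def structure(obj):
    """Toy 'treedef' that recurses into func instead of storing it as static."""
    if isinstance(obj, ToyPartial):
        return ("ToyPartial", structure(obj.func),
                tuple(structure(a) for a in obj.args))
    # Non-partial callables (like f) are kept as static data, compared by ==.
    if callable(obj):
        return ("static", obj)
    return "*"  # anything else is treated as a leaf


h1 = ToyPartial(ToyPartial(f))
h2 = ToyPartial(ToyPartial(f))
print(structure(h1) == structure(h2))  # True: both bottom out at the same f
```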
That's a good check. And maybe add some documentation on its assumptions too.
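A minimal sketch of such a check, written as a wrapper rather than as a change to JAX itself. The exact condition proposed above is not recoverable here; this sketch assumes the guard rejects a `func` that is itself a `jtu.Partial`, which is the nested case this issue is about. `checked_partial` is a hypothetical name.

```python
import jax.tree_util as jtu


def checked_partial(func, *args, **kwargs):
    """Hypothetical guard around jtu.Partial that rejects nested partials.

    Assumes, per this thread, that jtu.Partial treats `func` as static pytree
    metadata, so a Partial passed as `func` compares by object identity and
    can cause avoidable recompilation.
    """
    if isinstance(func, jtu.Partial):
        raise TypeError(
            "func is itself a jtu.Partial; pass a plain function instead, "
            "or bind all arguments in a single Partial."
        )
    return jtu.Partial(func, *args, **kwargs)


def add(x, y):
    return x + y


add_one = checked_partial(add, 1)
print(add_one(2))  # 3

try:
    checked_partial(add_one, 2)
except TypeError as e:
    print("rejected:", e)
```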
Description

Both `h1` and `h2` have the same PyTree structure, so the expected behaviour is for the second call not to print anything.

What jax/jaxlib version are you using?

0.3.23