Double jtu.Partial results in cache miss #13071

Open
patrick-kidger opened this issue Nov 2, 2022 · 5 comments · May be fixed by #15999
Labels
bug Something isn't working

Comments

@patrick-kidger
Collaborator

patrick-kidger commented Nov 2, 2022

Description

import jax
import jax.tree_util as jtu

def f(x): pass
g = jtu.Partial(f)
h1 = jtu.Partial(g)
h2 = jtu.Partial(g)

def check(x):
    print("tracing!")

check = jax.jit(check, static_argnums=0)
check(h1)  # tracing!
check(h2)  # tracing!

Both h1 and h2 have the same PyTree structure, so the expected behaviour is for the second call not to print anything.
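This is a quick way to see the two notions of "same" at play here, using only public jax.tree_util calls. A non-static jit argument is cached on its flattened structure (plus the shapes and dtypes of its leaves). A static_argnums argument is cached on the object's own __hash__/__eq__. The sketch below just computes both for h1 and h2. It doesn't claim what the structure comparison returns on any particular JAX version.

```python
import jax.tree_util as jtu

def f(x):
    return None

g = jtu.Partial(f)
h1 = jtu.Partial(g)
h2 = jtu.Partial(g)

# Roughly what a non-static jit argument is keyed on: its flattened structure.
same_structure = jtu.tree_structure(h1) == jtu.tree_structure(h2)

# What a static_argnums argument is keyed on: the object's own hash/equality.
# jtu.Partial inherits identity-based hashing, so two distinct objects differ.
same_hash = hash(h1) == hash(h2)

print(same_structure, same_hash)
```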

What jax/jaxlib version are you using?

0.3.23

@patrick-kidger patrick-kidger added the bug Something isn't working label Nov 2, 2022
@cgarciae
Collaborator

cgarciae commented Nov 8, 2022

Hey @patrick-kidger, this is because you are passing these as static arguments, and Partial seems to implement hash by object id (or something similar):

hash(h1) != hash(h2)

I'd say this behaviour is "expected".
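The hashing point can be made concrete. functools.partial, which jtu.Partial subclasses, defines neither __eq__ nor __hash__, so both fall back to object identity. The toy_jit_static cache below is a hypothetical, heavily simplified stand-in for how a static argument is looked up (by hash, then equality). It is not JAX's actual implementation.

```python
import functools

def f(x):
    return x

p1 = functools.partial(f)
p2 = functools.partial(f)

# Same function, same (empty) bound arguments, but distinct objects:
print(p1 == p2)              # False: identity comparison
print(hash(p1) == hash(p2))  # False: identity-based hash

def toy_jit_static(fn):
    """Hypothetical toy cache keyed on the static argument itself."""
    cache = {}

    def wrapped(static_arg):
        # Dict lookup uses __hash__ and then __eq__ of static_arg.
        if static_arg not in cache:
            wrapped.traces += 1  # stands in for "tracing!"
            cache[static_arg] = fn(static_arg)
        return cache[static_arg]

    wrapped.traces = 0
    return wrapped

check = toy_jit_static(lambda p: None)
check(p1)  # miss: first call
check(p1)  # hit: the very same object
check(p2)  # miss: equivalent but distinct object
print(check.traces)  # 2
```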

@patrick-kidger
Collaborator Author

patrick-kidger commented Nov 8, 2022

I don't think so. The single-level partial works just fine, despite the fact that each Partial object will have different object ids.

I think it is the special handling here that's misbehaving:

https://github.com/google/jax/blob/96f6c1c9d414e6ebc54ff7f08115a9a9a6d6a8f8/jax/_src/tree_util.py#L366

@KeAWang

KeAWang commented Mar 16, 2023

Actually, I agree with @cgarciae. This doesn't seem like a bug? I tried your single-level partial case and it seems to behave as expected:

import jax
import jax.tree_util as jtu

def f(x):
    return 

g1 = jtu.Partial(f)
g2 = jtu.Partial(f)

assert hash(g1) != hash(g2)

def check(x):
    return

with jax.log_compiles(True):
    check1 = jax.jit(check, static_argnums=0)
    check1(g1)  # compiles since it's the first call
    check1(g2)  # compiles since hash(g1) != hash(g2)

with jax.log_compiles(True):
    check2 = jax.jit(check)
    check2(g1)  # compiles since it's the first call
    check2(g2)  # no compile, as expected, since g1 and g2 are now treated as dynamic
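Here is a variant of the experiment above that counts traces with a Python side effect instead of reading the compile logs. The side effect only runs while JAX traces the function. make_counted_check is a hypothetical helper for this sketch. The expected counts mirror what the log_compiles run describes.

```python
import jax
import jax.tree_util as jtu

def f(x):
    return x

def make_counted_check():
    """Hypothetical helper: returns (check, counter); counter[0] counts traces."""
    counter = [0]

    def check(p):
        counter[0] += 1  # Python side effect: only executes during tracing
        return None

    return check, counter

g1 = jtu.Partial(f)
g2 = jtu.Partial(f)

# Static: g1 and g2 are cache keys compared by identity-based hash/eq.
check_static, static_traces = make_counted_check()
jitted_static = jax.jit(check_static, static_argnums=0)
jitted_static(g1)
jitted_static(g2)

# Dynamic: both flatten to no leaves, with f stored in the (equal) treedefs.
check_dynamic, dynamic_traces = make_counted_check()
jitted_dynamic = jax.jit(check_dynamic)
jitted_dynamic(g1)
jitted_dynamic(g2)

print(static_traces[0], dynamic_traces[0])
```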

@patrick-kidger
Collaborator Author

patrick-kidger commented Mar 16, 2023

Right; I've since recalled that the function-valued argument to jtu.Partial is assumed to be static. Thus Partial(Partial(...), ...) is actually a user error, since the inner Partial(...) isn't really a static thing.

I think there should probably simply be a check to guard against this error, e.g. raising an error if len(jtu.tree_leaves(func)) > 1.
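Here is a minimal sketch of such a guard as a hypothetical wrapper. checked_partial is not a JAX API. It uses a slightly stricter condition than > 1. A plain function flattens to exactly [func], whereas jtu.Partial(f) with no bound arguments flattens to zero leaves. A > 1 threshold would therefore still let the Partial(Partial(f)) case from this issue through.

```python
import jax.tree_util as jtu

def checked_partial(func, *args, **kwargs):
    """Hypothetical guard (not part of JAX) around jtu.Partial."""
    leaves = jtu.tree_leaves(func)
    # An unregistered callable, such as a plain function, is a single leaf: [func].
    # Anything else (e.g. another jtu.Partial) has internal pytree structure,
    # which jtu.Partial would silently treat as static aux data.
    if not (len(leaves) == 1 and leaves[0] is func):
        raise ValueError(
            f"func {func!r} has internal pytree structure; jtu.Partial treats "
            "func as static, so bind its contents as args instead"
        )
    return jtu.Partial(func, *args, **kwargs)

def f(x, y):
    return x + y

ok = checked_partial(f, 1)
print(ok(2))  # 3

try:
    checked_partial(jtu.Partial(f))
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```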

And FWIW, for this reason I no longer use jtu.Partial in my own code; I use my own implementation instead:

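from typing import Callable

import equinox as eqx

class Partial(eqx.Module):
    # Ordinary pytree fields (not static aux data), so the contents of a
    # nested func are flattened too.
    func: Callable
    args: tuple
    kwargs: dict

    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        # Bound positional args come first; a keyword passed both here and at
        # construction raises TypeError.
        return self.func(*self.args, *args, **kwargs, **self.kwargs)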
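For readers without Equinox, roughly the same idea can be sketched with jtu.register_pytree_node_class. TreePartial is a hypothetical name, and this is not how Equinox implements eqx.Module. Here func, args and kwargs are all pytree children, so a nested partial's bound arrays show up as leaves instead of hiding inside static aux data.

```python
import jax.numpy as jnp
import jax.tree_util as jtu

@jtu.register_pytree_node_class
class TreePartial:
    """Hypothetical Equinox-free analogue: every field is a pytree child."""

    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args, **kwargs):
        return self.func(*self.args, *args, **kwargs, **self.kwargs)

    def tree_flatten(self):
        # No static aux data: func, args and kwargs are all flattened.
        return (self.func, self.args, self.kwargs), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = object.__new__(cls)  # bypass __init__ when rebuilding
        obj.func, obj.args, obj.kwargs = children
        return obj

def add(x, y):
    return x + y

inner = TreePartial(add, jnp.ones(2))
outer = TreePartial(inner)

leaves = jtu.tree_leaves(outer)  # add, then the jnp.ones(2) array
result = outer(jnp.zeros(2))     # add(ones, zeros)
print(result)
```

The function itself ends up as a non-array leaf, so plain jax.jit would reject outer as an argument. This design assumes something like eqx.filter_jit, which traces array leaves and holds the other leaves static.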

@KeAWang

KeAWang commented Mar 16, 2023

That's a good check. And maybe add some documentation about its assumptions too.
