nnx static fields not part of static tree structure #3863

Closed
NeilGirdhar opened this issue Apr 17, 2024 · 1 comment
@NeilGirdhar
Contributor

from flax.experimental import nnx
from jax import jit
from jax.tree import flatten


class C(nnx.Module, experimental_pytree=True):
    def __init__(self, x):
        self.x = x  # plain int attribute, stored as a static field


c = C(1)
d = C(2)

values, tree = flatten(c)
valuesd, treed = flatten(d)


@jit
def f(x):
    # print runs only while jit traces f, so each printed value marks a (re)trace.
    print(x.x)


print(hash(tree), hash(treed), tree, treed, values, valuesd)
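f(c)  # first call: traces f and prints 1
f(d)  # x differs, so this should retrace and print 2
f(c)  # same structure as the first call: cache hit, nothing printed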

Despite the static fields being different:

  static_fields={
    'x': 1
  }
# vs
  static_fields={
    'x': 2
  }

the hashes are the same, and the jitted function is traced only once, so the print inside f runs only once (it prints 1, but never 2).
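For comparison, here is a minimal sketch of the behavior the report expects, written with plain JAX pytree registration rather than nnx. The class `Static` and the list `traced_with` are hypothetical names, not part of Flax. Because `x` is returned as aux_data, it becomes part of the treedef, and `jax.jit` only reuses a trace when the argument treedefs compare equal, so changing `x` should force a retrace.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Static:
    """Hypothetical stand-in for C: x is static metadata, not an array leaf."""

    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        # No children (leaves); x goes into aux_data, which is stored in the treedef.
        return (), self.x

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


traced_with = []  # hypothetical log: the static value seen each time f is traced


@jax.jit
def f(obj):
    # Python side effects run only during tracing, not on cache hits.
    traced_with.append(obj.x)
    return jnp.zeros(())


f(Static(1))  # first call: traces
f(Static(2))  # different aux_data, so different treedef: retraces
f(Static(1))  # treedef equal to the first call: cache hit, no trace
print(traced_with)  # [1, 2]
```

Appending to a list instead of printing makes the number of traces easy to inspect after the fact.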

Have I misunderstood how the pytree flattener is supposed to work?
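On the flattener question, a sketch of the flatten/unflatten contract using `jax.tree_util.register_pytree_node` directly, again with hypothetical names (`Static`, `flatten_static`, `unflatten_static`). Whatever the flatten function returns as aux_data is stored in the treedef and compared with `==` when two treedefs are compared. Equal hashes on their own do not imply equal treedefs, so comparing the treedefs directly is the more telling check.

```python
from jax import tree_util


class Static:
    """Hypothetical stand-in for C, holding a single static field x."""

    def __init__(self, x):
        self.x = x


def flatten_static(obj):
    # Returns (children, aux_data): children become leaves, while aux_data is
    # kept in the treedef and takes part in treedef equality.
    return (), obj.x


def unflatten_static(aux_data, children):
    # Rebuilds the object from the static data; there are no children here.
    return Static(aux_data)


tree_util.register_pytree_node(Static, flatten_static, unflatten_static)

leaves_1, treedef_1 = tree_util.tree_flatten(Static(1))
leaves_2, treedef_2 = tree_util.tree_flatten(Static(2))
_, treedef_1_again = tree_util.tree_flatten(Static(1))

print(leaves_1, leaves_2)                        # [] []
print(treedef_1 == treedef_2)                    # False: aux_data differs
print(treedef_1 == treedef_1_again)              # True: aux_data is equal
print(hash(treedef_1) == hash(treedef_1_again))  # True: equal treedefs hash equal
```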

@NeilGirdhar
Contributor Author

@cgarciae Thanks! This works on master (prints "1 2") but not on Flax 0.8.2 (prints "1").
