
Initialization of flax.linen model parameters depends on member variable names #4367

Open
bluesunb opened this issue Nov 10, 2024 · 1 comment


@bluesunb


System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04
  • Flax, jax, jaxlib versions (obtained with pip show flax jax jaxlib): flax 0.8.5, jax 0.4.28, jaxlib 0.4.28
  • Python version: 3.11

Problem you have encountered:

It appears that the initialization of model parameters in flax.linen depends on the attribute name of the submodule in the parent module, rather than solely on the provided PRNGKey. When two parent modules wrap the same submodule under different attribute names, the parameters are initialized differently.

What you expected to happen:

I expected the parameter initialization to depend only on the provided PRNGKey (jax.random.PRNGKey(0)) and not on the attribute names of the submodules in the parent module.

Logs, error messages, etc:

None, but the observed output when running the example code below is shown here. Note that models 1 and 3 (both using the attribute name net) produce identical values, while model 2 (net2) differs:

{'1': {'params': {'net': {'Dense_0': {'bias': Array(0., dtype=float32),
                                      'kernel': Array(2.1496024, dtype=float32)}}}},
 '2': {'params': {'net2': {'Dense_0': {'bias': Array(0., dtype=float32),
                                       'kernel': Array(1.4864768, dtype=float32)}}}},
 '3': {'params': {'net': {'Dense_0': {'bias': Array(0., dtype=float32),
                                      'kernel': Array(2.1496024, dtype=float32)}}}}}

Steps to reproduce:

import jax, jax.numpy as jp
import flax.linen as nn
import optax
from pprint import pp

class Tmp(nn.Module):
    # A single Dense layer, reused as a shared submodule below.
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(x)

class Network(nn.Module):
    net: Tmp  # attribute name "net"
    def __call__(self, x):
        return self.net(x)

class Network2(nn.Module):
    net2: Tmp  # same submodule, different attribute name "net2"
    def __call__(self, x):
        return self.net2(x)

class Network3(nn.Module):
    net: Tmp  # same attribute name as Network
    def __call__(self, x):
        return self.net(x)

net = Tmp()
model1 = Network(net)
model2 = Network2(net)
model3 = Network3(net)

# All three models are initialized with the exact same key.
rng = jax.random.PRNGKey(0)
x = jp.zeros((1, 3))
params = {
    "1": model1.init(rng, x),
    "2": model2.init(rng, x),
    "3": model3.init(rng, x),
}

# Print the global norm of each parameter array.
pp(jax.tree_map(optax.global_norm, params))
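
A possible workaround (a sketch, not an officially documented pattern): initialize the shared Tmp instance directly, so its parameters depend only on the key, and then nest the result under whichever attribute name the parent module uses. The params1/params2 names here are my own:

# Sketch of a workaround: initialize the shared module on its own, so the
# result is independent of any parent attribute name, then re-nest it.
tmp_params = net.init(rng, x)  # {'params': {'Dense_0': {...}}}
params1 = {"params": {"net": tmp_params["params"]}}   # matches model1's tree
params2 = {"params": {"net2": tmp_params["params"]}}  # matches model2's tree

# model1.apply(params1, x) and model2.apply(params2, x) should now use
# identical Dense_0 weights.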

stergiosba commented Nov 22, 2024

The key passed to model.init is [0, 0] in all three cases, but the keys actually used for the initialization differ, as seen below:

model1.init(rng, x) -> rng = [4189436379 3584055865]
model2.init(rng, x) -> rng = [3798156617 2715557378]
model3.init(rng, x) -> rng = [4189436379 3584055865]

I got these values by printing the key inside the variance_scaling function in jax.nn.initializers, reached via the lecun_normal initializer.
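
For completeness, the same observation can be confirmed without touching jax's source by comparing the initialized subtrees directly (a sketch; it reuses the params dict from the reproduction script above, and the eq_13/eq_12 names are mine):

import jax

# Models 1 and 3 ('net') get identical parameters; model 2 ('net2') does not.
eq_13 = jax.tree_util.tree_map(
    lambda a, b: bool((a == b).all()),
    params["1"]["params"]["net"],
    params["3"]["params"]["net"],
)
eq_12 = jax.tree_util.tree_map(
    lambda a, b: bool((a == b).all()),
    params["1"]["params"]["net"],
    params["2"]["params"]["net2"],
)
print(eq_13)  # kernel and bias both True
print(eq_12)  # kernel False; bias still True because it is zero-initialized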

This happens because the suffix stored in the LazyRng differs:

return LazyRng.create(self.rngs[name], self.rng_counters[name]).as_jax_rng()

Specifically, in this case, we have:

LazyRng(rng=Array((), dtype=key<fry>) overlaying: [0 0], suffix=('net', 'Dense_0'))
LazyRng(rng=Array((), dtype=key<fry>) overlaying: [0 0], suffix=('net', 'Dense_0'))
LazyRng(rng=Array((), dtype=key<fry>) overlaying: [0 0], suffix=('net2', 'Dense_0'))
LazyRng(rng=Array((), dtype=key<fry>) overlaying: [0 0], suffix=('net2', 'Dense_0'))
LazyRng(rng=Array((), dtype=key<fry>) overlaying: [0 0], suffix=('net', 'Dense_0'))
LazyRng(rng=Array((), dtype=key<fry>) overlaying: [0 0], suffix=('net', 'Dense_0'))

It's hit six times in total because each of the three models requests an RNG twice during init: once for the kernel and once for the bias. Hope this helps you understand why.
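
To make the suffix dependence concrete, here is a conceptual sketch of the idea (not Flax's actual hashing scheme; path_to_key and the use of zlib.crc32 are illustrative assumptions): the root key is folded together with each name in the module path, so 'net' and 'net2' yield different derived keys even though the root key is identical.

import zlib
import jax

def path_to_key(root_key, path):
    # Illustrative only: Flax's LazyRng.as_jax_rng uses its own string
    # hashing, but the principle is the same -- fold the path into the key.
    for name in path:
        data = zlib.crc32(name.encode()) & 0x7FFFFFFF  # keep within int32
        root_key = jax.random.fold_in(root_key, data)
    return root_key

root = jax.random.PRNGKey(0)
print(path_to_key(root, ("net", "Dense_0")))   # one derived key
print(path_to_key(root, ("net2", "Dense_0")))  # a different derived key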
