System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04
Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.8.5, jax 0.4.28, jaxlib 0.4.28
Python version: 3.11
Problem you have encountered:
It appears that the initialization of model parameters in flax.linen depends on the name of the submodule in the parent model, rather than solely on the provided PRNGKey. When two models use the same module but with different submodule names, the parameters are initialized differently.
What you expected to happen:
I expected the parameter initialization to depend only on the provided PRNGKey (jax.random.PRNGKey(0)) and not on the names of the submodules in the parent model.
Logs, error messages, etc:
None, but the observed output when running the example code is:
The key passed to model.init is [0, 0] in all three cases, but the keys actually used to initialize the parameters differ between the three cases, as seen below:
Steps to reproduce: