However, if I remove super().__post_init__(), I get the unhelpful message flax.errors.CallCompactUnboundModuleError: Can't call compact methods on unbound modules, and it took me a while to figure out why the module was considered unbound.
I hope the error message can be improved.
Problem 2
There is a limitation: super().__post_init__() has to be called after setting self.x_square. Otherwise it raises flax.errors.SetAttributeFrozenModuleError: Can't set x_square=9 for Module of type A: Module instance is frozen outside of setup method.
This limitation makes it impossible to do anything interesting in a subclass's __post_init__. I'd like to compute another attribute based on x_square, but neither of the attempts below works:
    class B(A):
        def __post_init__(self):
            self.x_square_square = self.x_square ** 2  # "B" object has no attribute "x_square"
            super().__post_init__()

    r = B(x=3)
    r.init(jax.random.PRNGKey(2), jnp.ones(3))

    class B(A):
        def __post_init__(self):
            super().__post_init__()
            self.x_square_square = self.x_square ** 2  # Can't set x_square_square=81 for Module of type B: Module instance is frozen outside of setup method.
Assuming square is an expensive operation, I hope there is a reasonable way to subclass and do what I want.
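The ordering constraint can be reproduced without Flax. The sketch below is a toy stand-in, not flax.linen.Module: FrozenAfterInit, BEarly, BLate and try_build are hypothetical names, and the only behaviour it assumes is the one described above, namely that the base __post_init__ locks the instance against further attribute assignment.

```python
class FrozenAfterInit:
    """Toy stand-in (NOT flax.linen.Module): locks attributes once the base hook runs."""

    _frozen = False  # class-level default so assignments work before freezing

    def __init__(self, x):
        self.x = x
        self.__post_init__()

    def __post_init__(self):
        # The base hook is what freezes the instance.
        object.__setattr__(self, "_frozen", True)

    def __setattr__(self, name, value):
        if self._frozen:
            raise AttributeError(f"Can't set {name}={value!r}: instance is frozen")
        object.__setattr__(self, name, value)


class A(FrozenAfterInit):
    def __post_init__(self):
        self.x_square = self.x ** 2  # must be set before the base call
        super().__post_init__()


class BEarly(A):
    def __post_init__(self):
        # A.__post_init__ has not run yet, so x_square does not exist.
        self.x_square_square = self.x_square ** 2
        super().__post_init__()


class BLate(A):
    def __post_init__(self):
        super().__post_init__()  # runs A's hook, which freezes the instance
        self.x_square_square = self.x_square ** 2  # too late: already frozen


def try_build(cls, x):
    """Return the instance, or the AttributeError raised while building it."""
    try:
        return cls(x)
    except AttributeError as err:
        return err


a = A(3)
early = try_build(BEarly, 3)
late = try_build(BLate, 3)
```

Both subclasses fail for the reasons given in the two attempts: one reads x_square before it exists, the other writes after the freeze.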
Re "Problem 1": we've added some clearer errors in #2535.
"Problem 2" is considerably trickier to solve in general, unless we decide to create a new "pre-post-init" hook function that we'd direct people to override, much as you've done here.
Yeah, I don't have any concrete suggestions about Problem 2 either. I guess it's just an inflexibility of inheritance in general: subclasses can customize the overridden function, but cannot easily control when or where it is called.
Closing this since __do_post_init seems like the best thing we can do.
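For readers looking for the shape of such a hook, here is a minimal sketch under stated assumptions. It uses a toy base class rather than flax.linen.Module, and compute_derived is a hypothetical hook name chosen for illustration; it is not the helper referred to in this thread. The idea is that the parent calls the hook before freezing, and subclasses override the hook instead of __post_init__.

```python
class FrozenAfterInit:
    """Toy stand-in (NOT flax.linen.Module): locks attributes once the base hook runs."""

    _frozen = False  # class-level default so assignments work before freezing

    def __init__(self, x):
        self.x = x
        self.__post_init__()

    def __post_init__(self):
        object.__setattr__(self, "_frozen", True)

    def __setattr__(self, name, value):
        if self._frozen:
            raise AttributeError(f"Can't set {name}={value!r}: instance is frozen")
        object.__setattr__(self, name, value)


class A(FrozenAfterInit):
    def __post_init__(self):
        self.compute_derived()   # hypothetical hook, runs while still mutable
        super().__post_init__()  # freezing happens last

    def compute_derived(self):
        self.x_square = self.x ** 2


class B(A):
    # Override the hook, not __post_init__, so ordering stays under A's control.
    def compute_derived(self):
        super().compute_derived()
        self.x_square_square = self.x_square ** 2


b = B(3)
```

The trade-off is the one noted above: every class in the hierarchy has to agree to extend the hook rather than __post_init__ itself.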
System information
pip show flax jax jaxlib: latest head
Problem 1
My module has an attribute computed from another attribute. The below works:
The below works - though a bit annoying: