Wrap __post_init__ to check or force that Module.__post_init__ is called #2733

Open
levskaya opened this issue Dec 17, 2022 · 2 comments
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.

Comments

@levskaya
Collaborator

We added a runtime check in #2535 to raise an "IncorrectPostInitOverrideError" when users forget to call super().__post_init__() from a custom __post_init__ method. However, that implementation only runs the check at the top level, so it won't catch internal Modules that override __post_init__.

e.g.

import jax
import jax.numpy as jnp
from flax import linen as nn

class Foo(nn.Module):
  # Overrides __post_init__ without calling super().__post_init__().
  def __post_init__(self):
    pass

  @nn.compact
  def __call__(self, x):
    return nn.Dense(12)(x)

class Bar(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Foo is constructed inside Bar, so the top-level check does not cover it.
    return Foo()(x)

b = Bar()
b.init(jax.random.PRNGKey(0), x=jnp.ones((10, 3)))

What I think we should do instead is wrap __post_init__ on subclass instantiation to check (and possibly even force) that Module.__post_init__ is called.
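
A minimal sketch of the checking variant, to make the idea concrete. It uses a toy stand-in for Module rather than the real flax class, and the names `_checked`, `_base_post_init_ran`, and `PostInitNotCalledError` are hypothetical, introduced only for illustration. It also ignores the fact that real Modules are dataclasses with their own attribute handling.

```python
import functools


class PostInitNotCalledError(Exception):
  """Hypothetical error, standing in for IncorrectPostInitOverrideError."""


class Module:
  """Toy stand-in for flax.linen.Module, not the real class."""

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    # Only wrap a __post_init__ defined directly on this subclass.
    if '__post_init__' in cls.__dict__:
      cls.__post_init__ = _checked(cls.__dict__['__post_init__'])

  def __init__(self):
    # A dataclass-generated __init__ would call __post_init__ like this.
    self.__post_init__()

  def __post_init__(self):
    # Record that the base setup ran for this instance.
    self._base_post_init_ran = True


def _checked(user_post_init):
  """Wraps a user __post_init__ and verifies the base one was reached."""
  @functools.wraps(user_post_init)
  def wrapper(self, *args, **kwargs):
    result = user_post_init(self, *args, **kwargs)
    if not getattr(self, '_base_post_init_ran', False):
      raise PostInitNotCalledError(
          f'{type(self).__name__}.__post_init__ must call '
          'super().__post_init__()')
    return result
  return wrapper


class Good(Module):
  def __post_init__(self):
    super().__post_init__()


class Bad(Module):
  def __post_init__(self):
    pass  # Forgets super().__post_init__().


Good()  # Passes the check.
try:
  Bad()
except PostInitNotCalledError as e:
  print('caught:', e)
```

Because the check lives on every subclass's own __post_init__, it fires for any instance, whether or not it is constructed at the top level. One limitation of this simple form: the check runs at the end of each wrapped method, so an inner parent-class __post_init__ that is called before the base one would raise prematurely.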

For forcing, the only complex case is when users use inheritance and call their own parent-class __post_init__ in some random order inside their child __post_init__ (I've seen this), but I think there's a way for a wrapper to detect this and to only run in the outermost child-defined __post_init__.
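
One way the "outermost only" detection could work is a per-instance nesting counter. The sketch below is an assumption about how that might look, again with a toy Module rather than the flax one; `_forcing`, `_depth`, and `base_calls` are hypothetical names.

```python
import functools


class Module:
  """Toy stand-in for flax.linen.Module, not the real class."""

  def __init_subclass__(cls, **kwargs):
    super().__init_subclass__(**kwargs)
    if '__post_init__' in cls.__dict__:
      cls.__post_init__ = _forcing(cls.__dict__['__post_init__'])

  def __init__(self):
    self._depth = 0      # Nesting level of wrapped __post_init__ calls.
    self.base_calls = 0  # How often Module.__post_init__ ran.
    self.__post_init__()

  def __post_init__(self):
    self.base_calls += 1


def _forcing(user_post_init):
  """Wraps a user __post_init__; the outermost call forces the base one."""
  @functools.wraps(user_post_init)
  def wrapper(self):
    self._depth += 1
    try:
      user_post_init(self)
    finally:
      self._depth -= 1
    # Only the outermost wrapped call forces the base setup, and only if
    # user code never reached Module.__post_init__ itself.
    if self._depth == 0 and self.base_calls == 0:
      Module.__post_init__(self)
  return wrapper


class Parent(Module):
  def __post_init__(self):
    self.parent_ready = True  # No super() call.


class Child(Parent):
  def __post_init__(self):
    self.child_ready = True
    Parent.__post_init__(self)  # Parent hook called explicitly.


c = Child()
assert c.base_calls == 1  # Forced exactly once, by the outermost wrapper.
```

Note that in this sketch the forced call runs after the user code, so a user __post_init__ that relies on state set up by the base method would still misbehave. Forcing the base call before the user code is another option with its own trade-offs.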

@levskaya levskaya self-assigned this Dec 17, 2022
@cgarciae
Collaborator

I think there's a way for a wrapper to detect this and to only run in the outermost child-defined __post_init__.

Maybe a context manager that stores a reference to the object calling __post_init__, so we don't call it again for the same object in case there are multiple wrapped methods?

@chiamp chiamp added the Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment. label Dec 22, 2022
@chiamp
Collaborator

chiamp commented Dec 30, 2022

For forcing, the only complex case is when users use inheritance and call their own parent-class __post_init__ in some random order inside their child __post_init__ (I've seen this), but I think there's a way for a wrapper to detect this and to only run in the outermost child-defined __post_init__.

Is there a short code example of this, @levskaya?


3 participants