Unhelpful error message and a limitation of __post_init__ #2479

Closed
ppwwyyxx opened this issue Sep 23, 2022 · 2 comments
Labels
Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required)

ppwwyyxx (Contributor) commented Sep 23, 2022

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): colab
  • Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): latest head
  • Python version: 3.9
  • GPU/TPU model and memory: TPU
  • CUDA version (if applicable):

Problem 1

My module has an attribute computed from another attribute. The code below works:

```python
from flax import linen as nn
import jax.numpy as jnp
import jax

class A(nn.Module):
  x: float

  def __post_init__(self):
    self.x_square = self.x ** 2
    super().__post_init__()

  @nn.compact
  def __call__(self, input):
    return input + 3

r = A(x=3)
r.init(jax.random.PRNGKey(2), jnp.ones(3))
```

However, if I remove `super().__post_init__()`, I get the unhelpful error `flax.errors.CallCompactUnboundModuleError: Can't call compact methods on unbound modules`, and it took me a while to figure out why the module is considered unbound.

I hope the error message can be improved.
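As far as this thread shows, `Module.__post_init__` in Flax is where a module gets attached to its parent scope (and later frozen), so skipping it leaves the module unbound and the failure only surfaces later, when a compact method is called. The plain-Python toy below is not Flax code; `ToyModule`, `UnboundError` and the `_bound` flag are hypothetical names. It sketches how a base class could detect a skipped `super().__post_init__()` and report that cause directly.

```python
from dataclasses import dataclass


class UnboundError(RuntimeError):
    """Hypothetical error type used only in this toy model."""


@dataclass
class ToyModule:
    """Toy stand-in for a framework base class; not Flax's implementation."""

    def __post_init__(self):
        # Base-class setup. In Flax this is roughly where the module would be
        # attached to its parent scope; here we only record that it happened.
        self._bound = True

    def __call__(self, inp):
        # If a subclass overrode __post_init__ without calling super(), the flag
        # was never set, so we can name the actual cause instead of "unbound".
        if not getattr(self, "_bound", False):
            raise UnboundError(
                f"{type(self).__name__}.__post_init__ did not call "
                "super().__post_init__(), so the module was never bound."
            )
        return self.forward(inp)

    def forward(self, inp):
        raise NotImplementedError


@dataclass
class Good(ToyModule):
    x: float

    def __post_init__(self):
        self.x_square = self.x ** 2
        super().__post_init__()

    def forward(self, inp):
        return inp + 3


@dataclass
class Bad(ToyModule):
    x: float

    def __post_init__(self):
        self.x_square = self.x ** 2  # forgot super().__post_init__()

    def forward(self, inp):
        return inp + 3


print(Good(x=3)(1))  # prints 4
try:
    Bad(x=3)(1)
except UnboundError as err:
    print(err)  # Bad.__post_init__ did not call super().__post_init__(), so the module was never bound.
```

The check costs one attribute lookup per call and turns a generic "unbound" failure into a message that points at the missing `super()` call.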

Problem 2

There is a limitation: `super().__post_init__()` has to be called after setting `self.x_square`. Otherwise it fails with `flax.errors.SetAttributeFrozenModuleError: Can't set x_square=9 for Module of type A: Module instance is frozen outside of setup method.`

This limitation makes it impossible to do anything interesting in a subclass's `__post_init__`. I'd like to compute another attribute based on `x_square`, but neither of the attempts below works:

```python
# Attempt 1: compute before calling super().__post_init__()
class B(A):
  def __post_init__(self):
    self.x_square_square = self.x_square ** 2   # "B" object has no attribute "x_square"
    super().__post_init__()

r = B(x=3)
r.init(jax.random.PRNGKey(2), jnp.ones(3))
```

```python
# Attempt 2: compute after calling super().__post_init__()
class B(A):
  def __post_init__(self):
    super().__post_init__()
    self.x_square_square = self.x_square ** 2  # Can't set x_square_square=81 for Module of type B: Module instance is frozen outside of setup method.
```
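Both failures above follow from one ordering rule: the parent's `__post_init__` both computes `x_square` and freezes the instance, so a subclass has no point at which `x_square` exists and the instance is still writable. A minimal plain-Python sketch of that mechanism, assuming a base class that freezes itself at the end of `__post_init__` (`FrozenAfterInit`, `FrozenError`, `ToyA`, `ToyB1` and `ToyB2` are hypothetical; this is not Flax's implementation):

```python
from dataclasses import dataclass


class FrozenError(AttributeError):
    """Hypothetical error type used only in this toy model."""


@dataclass
class FrozenAfterInit:
    """Toy base class that freezes itself at the end of __post_init__ (not Flax code)."""

    def __post_init__(self):
        # object.__setattr__ bypasses our own __setattr__ so the flag can be set.
        object.__setattr__(self, "_frozen", True)

    def __setattr__(self, name, value):
        # Until the base __post_init__ runs, _frozen is absent and writes succeed.
        if getattr(self, "_frozen", False):
            raise FrozenError(
                f"Can't set {name}={value!r} for {type(self).__name__}: instance is frozen"
            )
        object.__setattr__(self, name, value)


@dataclass
class ToyA(FrozenAfterInit):
    x: float

    def __post_init__(self):
        self.x_square = self.x ** 2  # allowed: not frozen yet
        super().__post_init__()      # freezes the instance


@dataclass
class ToyB1(ToyA):
    def __post_init__(self):
        # Like attempt 1: ToyA.__post_init__ has not run, so x_square does not exist yet.
        self.x_square_square = self.x_square ** 2
        super().__post_init__()


@dataclass
class ToyB2(ToyA):
    def __post_init__(self):
        # Like attempt 2: ToyA.__post_init__ computes x_square and then freezes.
        super().__post_init__()
        self.x_square_square = self.x_square ** 2  # raises FrozenError


print(ToyA(x=3).x_square)  # prints 9
for cls in (ToyB1, ToyB2):
    try:
        cls(x=3)
    except AttributeError as err:  # FrozenError subclasses AttributeError
        print(f"{cls.__name__}: {type(err).__name__}: {err}")
```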

Assuming squaring is an expensive operation (so recomputing `x_square` in the subclass is undesirable), I hope there is a reasonable way to subclass `A` and do what I want.

The following works, though it is a bit annoying:

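```python
from flax import linen as nn
import jax.numpy as jnp
import jax

class A(nn.Module):
  x: float

  def __post_init__(self):
    # Compute derived attributes while the module is still mutable,
    # then let Flax finish initialization (which freezes the module).
    self._do_post_init()
    super().__post_init__()

  def _do_post_init(self):
    self.x_square = self.x ** 2

  @nn.compact
  def __call__(self, input):
    return input + 3

class B(A):
  # Single leading underscore: a double-underscore name (__do_post_init) is
  # name-mangled to _B__do_post_init and would NOT override A._do_post_init.
  def _do_post_init(self):
    super()._do_post_init()
    self.x_square_square = self.x_square ** 2

r = B(x=3)
r.init(jax.random.PRNGKey(2), jnp.ones(3))
```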
cgarciae self-assigned this Sep 27, 2022
cgarciae added the "Priority: P1 - soon" label Sep 27, 2022
levskaya (Collaborator) commented:

Re "Problem 1": we've added clearer error messages in #2535.
"Problem 2" is considerably trickier to solve in general, unless we decide to create a new "pre-post-init" hook function that we'd direct people to override, much as you've done here.
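A rough sketch of what such a hook could look like, in plain Python with dataclasses (no Flax): the base class owns `__post_init__`, calls an overridable hook while attributes are still writable, and only then freezes the instance. `HookedBase`, `HookedA`, `HookedB` and the hook name `compute_attributes` are hypothetical; the thread does not settle on an API.

```python
from dataclasses import dataclass


@dataclass
class HookedBase:
    """Toy base class with a 'pre-post-init' hook (hypothetical design, not Flax)."""

    def __post_init__(self):
        # 1) Let subclasses compute derived attributes while writes are allowed...
        self.compute_attributes()
        # 2) ...then freeze. Subclasses never override __post_init__ itself.
        object.__setattr__(self, "_frozen", True)

    def compute_attributes(self):
        """Hook for subclasses; override it and chain with super()."""

    def __setattr__(self, name, value):
        if getattr(self, "_frozen", False):
            raise AttributeError(f"Can't set {name}: instance is frozen")
        object.__setattr__(self, name, value)


@dataclass
class HookedA(HookedBase):
    x: float

    def compute_attributes(self):
        super().compute_attributes()
        self.x_square = self.x ** 2  # the "expensive" step, done once


@dataclass
class HookedB(HookedA):
    def compute_attributes(self):
        super().compute_attributes()               # reuses HookedA's x_square
        self.x_square_square = self.x_square ** 2


b = HookedB(x=3)
print(b.x_square, b.x_square_square)  # prints 9 81
```

This is the same idea as the `_do_post_init` workaround in the issue, moved into the base class: subclasses extend the hook and chain with `super()`, never override `__post_init__`, and therefore cannot forget to call `super().__post_init__()`.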

ppwwyyxx (Contributor, Author) commented:

Thanks for addressing it!

Yeah, I don't have any concrete suggestions for Problem 2 either. I guess it's just an inflexibility of inheritance in general: subclasses can customize an overridden function, but cannot easily control when or where it is called.

Closing this, since a `_do_post_init` hook seems like the best we can do.
