Skip to content

Commit

Permalink
Added IncorrectPostInitOverrideError to capture incorrect post init o…
Browse files Browse the repository at this point in the history
…verrides.

Additionally, re-factored runtime tests into helper method.

PiperOrigin-RevId: 481261709
  • Loading branch information
Flax Team committed Oct 17, 2022
1 parent e0de630 commit ded3d78
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 7 deletions.
25 changes: 24 additions & 1 deletion flax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def __init__(self):

class InvalidInstanceModuleError(FlaxError):
"""
This error occurs when you are trying to call `.init()`, `.init_with_output()` or `.apply()
This error occurs when you are trying to call `.init()`, `.init_with_output()`, `.apply() or `.bind()`
on the Module class itself, instead of an instance of the Module class.
For example, the error will be raised when trying to run this code::
Expand All @@ -588,6 +588,29 @@ def __call__(self, x):
def __init__(self):
super().__init__('Can only call init, init_with_output or apply methods on an instance of the Module class, not the Module class itself')

class IncorrectPostInitOverrideError(FlaxError):
"""
This error occurs when you overrode `.__post_init__()` without calling `super().__post_init__()`.
For example, the error will be raised when trying to run this code::
from flax import linen as nn
import jax.numpy as jnp
import jax
class A(nn.Module):
x: float
def __post_init__(self):
self.x_square = self.x ** 2
# super().__post_init__() <-- forgot to add this line
@nn.compact
def __call__(self, input):
return input + 3
r = A(x=3)
r.init(jax.random.PRNGKey(2), jnp.ones(3))
"""
def __init__(self):
super().__init__('Overrode `.__post_init__()` without calling `super().__post_init__()`')


class InvalidCheckpointError(FlaxError):
"""
Expand Down
21 changes: 15 additions & 6 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,16 @@ def is_initializing(self) -> bool:
raise ValueError("Can't check if running under init() on unbound modules")
return self.scope.get_flag('initializing', False)

def _module_checks(self):
"""Run standard runtime checks."""

if not isinstance(self, Module):
raise errors.InvalidInstanceModuleError()

overridden_post_init = self.__post_init__ != Module.__post_init__
if overridden_post_init and not hasattr(self, "_id"):
raise errors.IncorrectPostInitOverrideError()

@traceback_util.api_boundary
def bind(self,
variables: VariableDict,
Expand Down Expand Up @@ -1191,6 +1201,8 @@ def __call__(self, x):
Returns:
A copy of this instance with bound variables and RNGs.
"""
Module._module_checks(self)

del args
scope = core.bind(variables, rngs=rngs, mutable=mutable)
return self.clone(parent=scope)
Expand Down Expand Up @@ -1255,8 +1267,7 @@ def other_fn(instance, ...):
mutable, returns ``(output, vars)``, where ``vars`` are is a dict
of the modified collections.
"""
if not isinstance(self, Module):
raise errors.InvalidInstanceModuleError()
Module._module_checks(self)

if method is None:
method = self.__call__
Expand Down Expand Up @@ -1298,8 +1309,7 @@ def init_with_output(self,
`(output, vars)``, where ``vars`` are is a dict of the modified
collections.
"""
if not isinstance(self, Module):
raise errors.InvalidInstanceModuleError()
Module._module_checks(self)

if not isinstance(rngs, dict):
if not core.scope._is_valid_rng(rngs):
Expand Down Expand Up @@ -1354,8 +1364,7 @@ def init(self,
Returns:
The initialized variable dict.
"""
if not isinstance(self, Module):
raise errors.InvalidInstanceModuleError()
Module._module_checks(self)

_, v_out = self.init_with_output(
rngs,
Expand Down
17 changes: 17 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1772,6 +1772,23 @@ def __call__(self, x):
B.init_with_output(k, x)
with self.assertRaises(errors.InvalidInstanceModuleError):
B.apply({}, x) # similar issue w. apply called on class instead of instance.
with self.assertRaises(errors.InvalidInstanceModuleError):
B.bind({}, x) # similar issue w. apply called on class instead of instance.

def test_throws_incorrect_post_init_override_error(self):

class A(nn.Module):
x: float
def __post_init__(self):
self.x_square = self.x ** 2
@nn.compact
def __call__(self, input):
return input + 3

r = A(x=3)

with self.assertRaises(errors.IncorrectPostInitOverrideError):
r.init(jax.random.PRNGKey(2), jnp.ones(3))

class LeakTests(absltest.TestCase):

Expand Down

0 comments on commit ded3d78

Please sign in to comment.