Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added IncorrectPostInitOverrideError to capture incorrect post init overrides. #2535

Merged
merged 1 commit into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion flax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def __init__(self):

class InvalidInstanceModuleError(FlaxError):
"""
This error occurs when you are trying to call `.init()`, `.init_with_output()` or `.apply()
This error occurs when you are trying to call `.init()`, `.init_with_output()`, `.apply() or `.bind()`
on the Module class itself, instead of an instance of the Module class.
For example, the error will be raised when trying to run this code::

Expand All @@ -588,6 +588,29 @@ def __call__(self, x):
def __init__(self):
super().__init__('Can only call init, init_with_output or apply methods on an instance of the Module class, not the Module class itself')

class IncorrectPostInitOverrideError(FlaxError):
"""
This error occurs when you overrode `.__post_init__()` without calling `super().__post_init__()`.
For example, the error will be raised when trying to run this code::

from flax import linen as nn
import jax.numpy as jnp
import jax
class A(nn.Module):
x: float
def __post_init__(self):
self.x_square = self.x ** 2
# super().__post_init__() <-- forgot to add this line
@nn.compact
def __call__(self, input):
return input + 3

r = A(x=3)
r.init(jax.random.PRNGKey(2), jnp.ones(3))
"""
def __init__(self):
super().__init__('Overrode `.__post_init__()` without calling `super().__post_init__()`')


class InvalidCheckpointError(FlaxError):
"""
Expand Down
21 changes: 15 additions & 6 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,16 @@ def is_initializing(self) -> bool:
raise ValueError("Can't check if running under init() on unbound modules")
return self.scope.get_flag('initializing', False)

def _module_checks(self):
"""Run standard runtime checks."""

if not isinstance(self, Module):
raise errors.InvalidInstanceModuleError()

overridden_post_init = self.__post_init__ != Module.__post_init__
if overridden_post_init and not hasattr(self, "_id"):
raise errors.IncorrectPostInitOverrideError()

@traceback_util.api_boundary
def bind(self,
variables: VariableDict,
Expand Down Expand Up @@ -1191,6 +1201,8 @@ def __call__(self, x):
Returns:
A copy of this instance with bound variables and RNGs.
"""
Module._module_checks(self)

del args
scope = core.bind(variables, rngs=rngs, mutable=mutable)
return self.clone(parent=scope)
Expand Down Expand Up @@ -1255,8 +1267,7 @@ def other_fn(instance, ...):
mutable, returns ``(output, vars)``, where ``vars`` are is a dict
of the modified collections.
"""
if not isinstance(self, Module):
raise errors.InvalidInstanceModuleError()
Module._module_checks(self)

if method is None:
method = self.__call__
Expand Down Expand Up @@ -1298,8 +1309,7 @@ def init_with_output(self,
`(output, vars)``, where ``vars`` are is a dict of the modified
collections.
"""
if not isinstance(self, Module):
raise errors.InvalidInstanceModuleError()
Module._module_checks(self)

if not isinstance(rngs, dict):
if not core.scope._is_valid_rng(rngs):
Expand Down Expand Up @@ -1354,8 +1364,7 @@ def init(self,
Returns:
The initialized variable dict.
"""
if not isinstance(self, Module):
raise errors.InvalidInstanceModuleError()
Module._module_checks(self)

_, v_out = self.init_with_output(
rngs,
Expand Down
17 changes: 17 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1772,6 +1772,23 @@ def __call__(self, x):
B.init_with_output(k, x)
with self.assertRaises(errors.InvalidInstanceModuleError):
B.apply({}, x) # similar issue w. apply called on class instead of instance.
with self.assertRaises(errors.InvalidInstanceModuleError):
B.bind({}, x) # similar issue w. apply called on class instead of instance.

def test_throws_incorrect_post_init_override_error(self):

class A(nn.Module):
x: float
def __post_init__(self):
self.x_square = self.x ** 2
@nn.compact
def __call__(self, input):
return input + 3

r = A(x=3)

with self.assertRaises(errors.IncorrectPostInitOverrideError):
r.init(jax.random.PRNGKey(2), jnp.ones(3))

class LeakTests(absltest.TestCase):

Expand Down