Skip to content

Commit

Permalink
fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Oct 18, 2022
1 parent de065fc commit 63ffcfb
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions flax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,46 @@ def __call__(self, x):
def __init__(self):
super().__init__('Can only call init, init_with_output or apply methods on an instance of the Module class, not the Module class itself')

class IncorrectPostInitOverrideError(FlaxError):
"""
This error occurs when you overrode `.__post_init__()` without calling `super().__post_init__()`.
For example, the error will be raised when trying to run this code::
from flax import linen as nn
import jax.numpy as jnp
import jax
class A(nn.Module):
x: float
def __post_init__(self):
self.x_square = self.x ** 2
# super().__post_init__() <-- forgot to add this line
@nn.compact
def __call__(self, input):
return input + 3
r = A(x=3)
r.init(jax.random.PRNGKey(2), jnp.ones(3))
"""
def __init__(self):
super().__init__('Overrode `.__post_init__()` without calling `super().__post_init__()`')

class DescriptorAttributeError(FlaxError):
"""
This error occurs when you are trying to access a property that is accessing a non-existent attribute.
For example, the error will be raised when trying to run this code::
class Foo(nn.Module):
@property
def prop(self):
return self.non_existent_field # ERROR!
def __call__(self, x):
return self.prop
foo = Foo()
variables = foo.init(jax.random.PRNGKey(0), jnp.ones(shape=(1, 8)))
"""
def __init__(self):
super().__init__('Trying to access a property that is accessing a non-existent attribute.')

class InvalidCheckpointError(FlaxError):
"""
Expand Down

0 comments on commit 63ffcfb

Please sign in to comment.