From 63ffcfb97f93aff8364292e73bceb96eddf656d6 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 18 Oct 2022 20:32:03 +0000 Subject: [PATCH] fix errors --- flax/errors.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/flax/errors.py b/flax/errors.py index 611c6015b8..891b6f9cb2 100644 --- a/flax/errors.py +++ b/flax/errors.py @@ -588,6 +588,46 @@ def __call__(self, x): def __init__(self): super().__init__('Can only call init, init_with_output or apply methods on an instance of the Module class, not the Module class itself') +class IncorrectPostInitOverrideError(FlaxError): + """ + This error occurs when you overrode `.__post_init__()` without calling `super().__post_init__()`. + For example, the error will be raised when trying to run this code:: + + from flax import linen as nn + import jax.numpy as jnp + import jax + class A(nn.Module): + x: float + def __post_init__(self): + self.x_square = self.x ** 2 + # super().__post_init__() <-- forgot to add this line + @nn.compact + def __call__(self, input): + return input + 3 + + r = A(x=3) + r.init(jax.random.PRNGKey(2), jnp.ones(3)) + """ + def __init__(self): + super().__init__('Overrode `.__post_init__()` without calling `super().__post_init__()`') + +class DescriptorAttributeError(FlaxError): + """ + This error occurs when you are trying to access a property that is accessing a non-existent attribute. + For example, the error will be raised when trying to run this code:: + + class Foo(nn.Module): + @property + def prop(self): + return self.non_existent_field # ERROR! + def __call__(self, x): + return self.prop + + foo = Foo() + variables = foo.init(jax.random.PRNGKey(0), jnp.ones(shape=(1, 8))) + """ + def __init__(self): + super().__init__('Trying to access a property that is accessing a non-existent attribute.') class InvalidCheckpointError(FlaxError): """