diff --git a/equinox/_module.py b/equinox/_module.py index 69b26d77..d657015c 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -201,6 +201,13 @@ def __call__(cls, *args, **kwargs): else: setattr(self, field.name, converter(getattr(self, field.name))) object.__setattr__(self, "__class__", cls) + for kls in cls.__mro__: + try: + check = kls.__dict__["__check_init__"] + except KeyError: + pass + else: + check(self) return self diff --git a/tests/test_module.py b/tests/test_module.py index 35c0e725..fc654c1d 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -285,3 +285,93 @@ class B(A, foo=True): pass assert called + + +def test_check_init(): + class FooException(Exception): + pass + + called_a = False + called_b = False + + class A(eqx.Module): + a: int + + def __check_init__(self): + nonlocal called_a + called_a = True + if self.a >= 0: + raise FooException + + class B(A): + def __check_init__(self): + nonlocal called_b + called_b = True + + class C(A): + pass + + assert not called_a + assert not called_b + A(-1) + assert called_a + assert not called_b + + called_a = False + with pytest.raises(FooException): + A(1) + assert called_a + assert not called_b + + called_a = False + B(-1) + assert called_a + assert called_b + + called_a = False + called_b = False + with pytest.raises(FooException): + B(1) + assert called_a + assert called_b # B.__check_init__ is called before A.__check_init__ + + called_a = False + called_b = False + C(-1) + assert called_a + assert not called_b + + called_a = False + with pytest.raises(FooException): + C(1) + assert called_a + assert not called_b + + +def test_check_init_order(): + called_a = False + called_b = False + called_c = False + + class A(eqx.Module): + def __check_init__(self): + nonlocal called_a + called_a = True + + class B(A): + def __check_init__(self): + nonlocal called_b + called_b = True + raise ValueError + + class C(B): + def __check_init__(self): # pyright: ignore + nonlocal called_c + called_c = True + + with pytest.raises(ValueError): + C() + + assert called_c + assert called_b + assert not called_a