Skip to content

Commit

Permalink
Added Module.__check_init__ for after-initialisation checking of inva…
Browse files Browse the repository at this point in the history
…riants. See #472.
  • Loading branch information
patrick-kidger committed Sep 11, 2023
1 parent cc2df94 commit 1903e09
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 0 deletions.
7 changes: 7 additions & 0 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ def __call__(cls, *args, **kwargs):
else:
setattr(self, field.name, converter(getattr(self, field.name)))
object.__setattr__(self, "__class__", cls)
for kls in cls.__mro__:
try:
check = kls.__dict__["__check_init__"]
except KeyError:
pass
else:
check(self)
return self


Expand Down
90 changes: 90 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,93 @@ class B(A, foo=True):
pass

assert called


def test_check_init():
class FooException(Exception):
pass

called_a = False
called_b = False

class A(eqx.Module):
a: int

def __check_init__(self):
nonlocal called_a
called_a = True
if self.a >= 0:
raise FooException

class B(A):
def __check_init__(self):
nonlocal called_b
called_b = True

class C(A):
pass

assert not called_a
assert not called_b
A(-1)
assert called_a
assert not called_b

called_a = False
with pytest.raises(FooException):
A(1)
assert called_a
assert not called_b

called_a = False
B(-1)
assert called_a
assert called_b

called_a = False
called_b = False
with pytest.raises(FooException):
B(1)
assert called_a
assert called_b # B.__check_init__ is called before A.__check_init__

called_a = False
called_b = False
C(-1)
assert called_a
assert not called_b

called_a = False
with pytest.raises(FooException):
C(1)
assert called_a
assert not called_b


def test_check_init_order():
called_a = False
called_b = False
called_c = False

class A(eqx.Module):
def __check_init__(self):
nonlocal called_a
called_a = True

class B(A):
def __check_init__(self):
nonlocal called_b
called_b = True
raise ValueError

class C(B):
def __check_init__(self): # pyright: ignore
nonlocal called_c
called_c = True

with pytest.raises(ValueError):
C()

assert called_c
assert called_b
assert not called_a

0 comments on commit 1903e09

Please sign in to comment.