Skip to content

Commit

Permalink
Added Module.__check_init__ for after-initialisation checking of inva…
Browse files Browse the repository at this point in the history
…riants. See #472.
  • Loading branch information
patrick-kidger committed Sep 12, 2023
1 parent 50c3ef2 commit 82c33ae
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 0 deletions.
63 changes: 63 additions & 0 deletions docs/api/module/advanced_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,69 @@ Equinox modules can be used as [abstract base classes](https://docs.python.org/3
selection:
members: false

## Checking invariants

Equinox extends dataclasses with a `__check_init__` method, which is automatically ran after initialisation. This can be used to check invariants like so:

```python
class Positive(eqx.Module):
x: int

def __check_init__(self):
if self.x <= 0:
raise ValueError("Oh no!")
```

This method has three key differences compared to the `__post_init__` provided by dataclasses:

- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a silent bug:

```python
class Parent(eqx.Module):
x: int

def __check_init__(self):
if self.x <= 0:
raise ValueError("Oh no!")

class Child(Parent):
x_as_str: str

def __init__(self, x):
self.x = x
self.x_as_str = str(x)

Child(-1) # No error!
```

- It is automatically called for parent classes; `super().__check_init__()` is not required:

```python
class Parent(eqx.Module):
def __check_init__(self):
print("Parent")

class Child(Parent):
def __check_init__(self):
print("Child")

Child() # prints out both Child and Parent
```

As with the previous bullet point, this is to prevent child classes accidentally failing to check that the invariants of their parent hold.

- Assignment is not allowed:

```python
class MyModule(eqx.Module):
foo: int

def __check_init__(self):
self.foo = 1 # will raise an error
```

This is to prevent `__check_init__` from doing anything too surprising: as the name suggests, it's meant to be used for checking invariants.

## Creating wrapper modules

::: equinox.module_update_wrapper
7 changes: 7 additions & 0 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ def __call__(cls, *args, **kwargs):
else:
setattr(self, field.name, converter(getattr(self, field.name)))
object.__setattr__(self, "__class__", cls)
for kls in cls.__mro__:
try:
check = kls.__dict__["__check_init__"]
except KeyError:
pass
else:
check(self)
return self


Expand Down
102 changes: 102 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import functools as ft
from typing import Any

Expand Down Expand Up @@ -285,3 +286,104 @@ class B(A, foo=True):
pass

assert called


def test_check_init():
class FooException(Exception):
pass

called_a = False
called_b = False

class A(eqx.Module):
a: int

def __check_init__(self):
nonlocal called_a
called_a = True
if self.a >= 0:
raise FooException

class B(A):
def __check_init__(self):
nonlocal called_b
called_b = True

class C(A):
pass

assert not called_a
assert not called_b
A(-1)
assert called_a
assert not called_b

called_a = False
with pytest.raises(FooException):
A(1)
assert called_a
assert not called_b

called_a = False
B(-1)
assert called_a
assert called_b

called_a = False
called_b = False
with pytest.raises(FooException):
B(1)
assert called_a
assert called_b # B.__check_init__ is called before A.__check_init__

called_a = False
called_b = False
C(-1)
assert called_a
assert not called_b

called_a = False
with pytest.raises(FooException):
C(1)
assert called_a
assert not called_b


def test_check_init_order():
called_a = False
called_b = False
called_c = False

class A(eqx.Module):
def __check_init__(self):
nonlocal called_a
called_a = True

class B(A):
def __check_init__(self):
nonlocal called_b
called_b = True
raise ValueError

class C(B):
def __check_init__(self): # pyright: ignore
nonlocal called_c
called_c = True

with pytest.raises(ValueError):
C()

assert called_c
assert called_b
assert not called_a


def test_check_init_no_assignment():
class A(eqx.Module):
x: int

def __check_init__(self):
self.x = 4

with pytest.raises(dataclasses.FrozenInstanceError):
A(1)

0 comments on commit 82c33ae

Please sign in to comment.