Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Module.__check_init__ for after-initialisation checking of invariants. See #472. #492

Merged
merged 1 commit into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions docs/api/module/advanced_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,69 @@ Equinox modules can be used as [abstract base classes](https://docs.python.org/3
selection:
members: false

## Checking invariants

Equinox extends dataclasses with a `__check_init__` method, which is automatically ran after initialisation. This can be used to check invariants like so:

```python
class Positive(eqx.Module):
x: int

def __check_init__(self):
if self.x <= 0:
raise ValueError("Oh no!")
```

This method has three key differences compared to the `__post_init__` provided by dataclasses:

- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a silent bug:

```python
class Parent(eqx.Module):
x: int

def __check_init__(self):
if self.x <= 0:
raise ValueError("Oh no!")

class Child(Parent):
x_as_str: str

def __init__(self, x):
self.x = x
self.x_as_str = str(x)

Child(-1) # No error!
```

- It is automatically called for parent classes; `super().__check_init__()` is not required:

```python
class Parent(eqx.Module):
def __check_init__(self):
print("Parent")

class Child(Parent):
def __check_init__(self):
print("Child")

Child() # prints out both Child and Parent
```

As with the previous bullet point, this is to prevent child classes accidentally failing to check that the invariants of their parent hold.

- Assignment is not allowed:

```python
class MyModule(eqx.Module):
foo: int

def __check_init__(self):
self.foo = 1 # will raise an error
```

This is to prevent `__check_init__` from doing anything too surprising: as the name suggests, it's meant to be used for checking invariants.

## Creating wrapper modules

::: equinox.module_update_wrapper
7 changes: 7 additions & 0 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ def __call__(cls, *args, **kwargs):
else:
setattr(self, field.name, converter(getattr(self, field.name)))
object.__setattr__(self, "__class__", cls)
for kls in cls.__mro__:
try:
check = kls.__dict__["__check_init__"]
except KeyError:
pass
else:
check(self)
return self


Expand Down
102 changes: 102 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import functools as ft
from typing import Any

Expand Down Expand Up @@ -285,3 +286,104 @@ class B(A, foo=True):
pass

assert called


def test_check_init():
class FooException(Exception):
pass

called_a = False
called_b = False

class A(eqx.Module):
a: int

def __check_init__(self):
nonlocal called_a
called_a = True
if self.a >= 0:
raise FooException

class B(A):
def __check_init__(self):
nonlocal called_b
called_b = True

class C(A):
pass

assert not called_a
assert not called_b
A(-1)
assert called_a
assert not called_b

called_a = False
with pytest.raises(FooException):
A(1)
assert called_a
assert not called_b

called_a = False
B(-1)
assert called_a
assert called_b

called_a = False
called_b = False
with pytest.raises(FooException):
B(1)
assert called_a
assert called_b # B.__check_init__ is called before A.__check_init__

called_a = False
called_b = False
C(-1)
assert called_a
assert not called_b

called_a = False
with pytest.raises(FooException):
C(1)
assert called_a
assert not called_b


def test_check_init_order():
called_a = False
called_b = False
called_c = False

class A(eqx.Module):
def __check_init__(self):
nonlocal called_a
called_a = True

class B(A):
def __check_init__(self):
nonlocal called_b
called_b = True
raise ValueError

class C(B):
def __check_init__(self): # pyright: ignore
nonlocal called_c
called_c = True

with pytest.raises(ValueError):
C()

assert called_c
assert called_b
assert not called_a


def test_check_init_no_assignment():
class A(eqx.Module):
x: int

def __check_init__(self):
self.x = 4

with pytest.raises(dataclasses.FrozenInstanceError):
A(1)
Loading