Skip to content

Commit

Permalink
Add validate_invariants function for manual contract checking (#928)
Browse files Browse the repository at this point in the history
Co-authored-by: Sarah Wang <sarahxp.wang@mail.utoronto.ca>
  • Loading branch information
sarahsonder and sarahsonder committed Jul 27, 2023
1 parent 5da03c0 commit b89f42c
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- For the contract checking `new_setattr` function, any variables that depend only on `klass` are now defined in the
outer function, efficiency of code was improved, and the attribute value is now restored to the original value if the
`_check_invariants` call raises an error.
- Added new function `validate_invariants` which takes in an object and checks that the representation invariants of the object are satisfied.

### Bug Fixes

Expand Down
9 changes: 7 additions & 2 deletions docs/contracts/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ AssertionError: divide argument '2' did not match type annotation for parameter

The `python_ta.contracts` module offers two functions for enabling contract checking.
The first, `check_all_contracts`, enables contract checking for all functions and classes defined within a module or set of modules.
The second, `check_contracts`, is a decorator allowing more fine-grained control over which
functions/classes have contract checking enabled.
The second, `check_contracts`, is a decorator allowing more fine-grained control over which functions/classes have contract checking enabled.

```{eval-rst}
.. autofunction:: python_ta.contracts.check_all_contracts
Expand All @@ -61,6 +60,12 @@ functions/classes have contract checking enabled.
.. autofunction:: python_ta.contracts.check_contracts(func_or_class)
```

You can pass an object into the function `validate_invariants` to manually check the representation invariants of the object.

```{eval-rst}
.. autofunction:: python_ta.contracts.validate_invariants(object)
```

You can set the `ENABLE_CONTRACT_CHECKING` constant to `True` to enable all contract checking.

```{eval-rst}
Expand Down
58 changes: 37 additions & 21 deletions python_ta/contracts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,7 @@ def add_class_invariants(klass: type) -> None:
# This means the class has already been decorated
return

# Update representation invariants from this class' docstring and those of its superclasses.
rep_invariants: List[Tuple[str, CodeType]] = []

# Iterate over all inherited classes except builtins
for cls in reversed(klass.__mro__):
if "__representation_invariants__" in cls.__dict__:
rep_invariants.extend(cls.__representation_invariants__)
elif cls.__module__ != "builtins":
assertions = parse_assertions(cls, parse_token="Representation Invariant")
# Try compiling assertions
for assertion in assertions:
try:
compiled = compile(assertion, "<string>", "eval")
except:
_debug(
f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
)
continue
rep_invariants.append((assertion, compiled))

setattr(klass, "__representation_invariants__", rep_invariants)
_set_invariants(klass)

klass_mod = _get_module(klass)
cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)
Expand Down Expand Up @@ -603,3 +583,39 @@ def _debug(msg: str) -> None:
return

print("[PyTA]", msg, file=sys.stderr)


def _set_invariants(klass: type) -> None:
"""Retrieve and set the representation invariants of this class"""
# Update representation invariants from this class' docstring and those of its superclasses.
rep_invariants: List[Tuple[str, CodeType]] = []

# Iterate over all inherited classes except builtins
for cls in reversed(klass.__mro__):
if "__representation_invariants__" in cls.__dict__:
rep_invariants.extend(cls.__representation_invariants__)
elif cls.__module__ != "builtins":
assertions = parse_assertions(cls, parse_token="Representation Invariant")
# Try compiling assertions
for assertion in assertions:
try:
compiled = compile(assertion, "<string>", "eval")
except:
_debug(
f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
)
continue
rep_invariants.append((assertion, compiled))

setattr(klass, "__representation_invariants__", rep_invariants)


def validate_invariants(obj: object) -> None:
"""Check that the representation invariants of obj are satisfied."""
klass = obj.__class__
klass_mod = _get_module(klass)

try:
_check_invariants(obj, klass, klass_mod.__dict__)
except PyTAContractError as e:
raise AssertionError(str(e)) from None
48 changes: 48 additions & 0 deletions tests/test_validate_invariants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Test suite for checking the functionality of validate_invariants.
"""

from typing import List

import pytest

from python_ta.contracts import check_contracts, validate_invariants


@check_contracts
class Person:
"""A custom data type that represents data for a person.
Representation Invariants:
- self.age >= 0
- len(self.friends) > 1
"""

given_name: str
age: int
friends: List[str]

def __init__(self, given_name: str, age: int, friends: List[str]) -> None:
"""Initialize a new Person object."""
self.given_name = given_name
self.age = age
self.friends = friends


def test_no_errors() -> None:
"""Checks that validate_invariants does not raise an error when representation invariants are satisfied."""
person = Person("Jim", 50, ["Pam", "Dwight"])

try:
validate_invariants(person)
except AssertionError:
pytest.fail("validate_invariants has incorrectly raised an AssertionError")


def test_raise_error() -> None:
"""Checks that validate_invariants raises an error when representation invariants are violated."""
person = Person("Jim", 50, ["Pam", "Dwight"])
person.friends.pop()

with pytest.raises(AssertionError):
validate_invariants(person)

0 comments on commit b89f42c

Please sign in to comment.