Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New function validate-invariants #928

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- For the contract checking `new_setattr` function, any variables that depend only on `klass` are now defined in the
outer function, efficiency of code was improved, and the attribute value is now restored to the original value if the
`_check_invariants` call raises an error.
- Can now check explicitly for whether the representation invariants of an object are satisfied.
sarahsonder marked this conversation as resolved.
Show resolved Hide resolved

### Bug Fixes

Expand Down
56 changes: 35 additions & 21 deletions python_ta/contracts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,27 +137,7 @@ def add_class_invariants(klass: type) -> None:
# This means the class has already been decorated
return

# Update representation invariants from this class' docstring and those of its superclasses.
rep_invariants: List[Tuple[str, CodeType]] = []

# Iterate over all inherited classes except builtins
for cls in reversed(klass.__mro__):
if "__representation_invariants__" in cls.__dict__:
rep_invariants.extend(cls.__representation_invariants__)
elif cls.__module__ != "builtins":
assertions = parse_assertions(cls, parse_token="Representation Invariant")
# Try compiling assertions
for assertion in assertions:
try:
compiled = compile(assertion, "<string>", "eval")
except:
_debug(
f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
)
continue
rep_invariants.append((assertion, compiled))

setattr(klass, "__representation_invariants__", rep_invariants)
_set_invariants(klass)

klass_mod = _get_module(klass)
cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)
Expand Down Expand Up @@ -603,3 +583,37 @@ def _debug(msg: str) -> None:
return

print("[PyTA]", msg, file=sys.stderr)


def _set_invariants(klass: type) -> None:
"""Retrieve and set the representation invariants of this class"""
# Update representation invariants from this class' docstring and those of its superclasses.
rep_invariants: List[Tuple[str, CodeType]] = []

# Iterate over all inherited classes except builtins
for cls in reversed(klass.__mro__):
if "__representation_invariants__" in cls.__dict__:
rep_invariants.extend(cls.__representation_invariants__)
elif cls.__module__ != "builtins":
assertions = parse_assertions(cls, parse_token="Representation Invariant")
# Try compiling assertions
for assertion in assertions:
try:
compiled = compile(assertion, "<string>", "eval")
except:
_debug(
f"Warning: representation invariant {assertion} could not be parsed as a valid Python expression"
)
continue
rep_invariants.append((assertion, compiled))

setattr(klass, "__representation_invariants__", rep_invariants)


def check_invariants(obj: object) -> None:
sarahsonder marked this conversation as resolved.
Show resolved Hide resolved
"""Check that the representation invariants of obj are satisfied."""
klass = obj.__class__
klass_mod = _get_module(klass)

_set_invariants(klass)
sarahsonder marked this conversation as resolved.
Show resolved Hide resolved
_check_invariants(obj, klass, klass_mod.__dict__)
41 changes: 41 additions & 0 deletions tests/test_check_invariants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
Test suite for checking the functionality of check_invariants.
"""

import pytest

from python_ta.contracts import PyTAContractError, check_invariants


class Person:
"""A custom data type that represents data for a person.

Representation Invariants:
- self.age >= 0
"""

given_name: str
age: int

def __init__(self, given_name: str, age: int) -> None:
"""Initialize a new Person object."""
self.given_name = given_name
self.age = age


def test_no_errors() -> None:
"""Checks that check_invariants does not raise an error when representation invariants are satisfied."""
person_obj = Person("Jim", 50)

try:
check_invariants(person_obj)
except Exception:
sarahsonder marked this conversation as resolved.
Show resolved Hide resolved
assert False
sarahsonder marked this conversation as resolved.
Show resolved Hide resolved


def test_raise_error() -> None:
"""Checks that check_invariants raises an error when representation invariants are violated."""
person_obj = Person("Jim", -50)

with pytest.raises(PyTAContractError):
check_invariants(person_obj)