Skip to content

Commit

Permalink
Improved the behavior of the contract checking new_setattr function (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Bruce-8 committed Jul 17, 2023
1 parent 2109a69 commit 5da03c0
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- When running `check_contracts` on a class with type aliases as type annotations for its attributes, the `NameError`
that appears (which indicates that the type alias is undefined) is now resolved.
- The default value of `pyta-number-of-messages` is now 0. This automatically displays all occurrences of the same error.
- For the contract checking `new_setattr` function, any variables that depend only on `klass` are now defined in the
outer function, efficiency of code was improved, and the attribute value is now restored to the original value if the
`_check_invariants` call raises an error.

### Bug Fixes

Expand Down
21 changes: 13 additions & 8 deletions python_ta/contracts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,14 @@ def add_class_invariants(klass: type) -> None:

setattr(klass, "__representation_invariants__", rep_invariants)

klass_mod = _get_module(klass)
cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

def new_setattr(self: klass, name: str, value: Any) -> None:
"""Set the value of the given attribute on self to the given value.
Check representation invariants for this class when not within an instance method of the class.
"""
klass_mod = _get_module(klass)
cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__)

if name in cls_annotations:
try:
_debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
Expand All @@ -176,18 +176,23 @@ def new_setattr(self: klass, name: str, value: Any) -> None:
f"Value {_display_value(value)} did not match type annotation for attribute "
f"{name}: {_display_annotation(cls_annotations[name])}"
) from None

original_attr_value_exists = False
original_attr_value = None
if hasattr(super(klass, self), name):
original_attr_value_exists = True
original_attr_value = super(klass, self).__getattribute__(name)
super(klass, self).__setattr__(name, value)
curframe = inspect.currentframe()
callframe = inspect.getouterframes(curframe, 2)
frame_locals = callframe[1].frame.f_locals
frame_locals = inspect.currentframe().f_back.f_locals
if self is not frame_locals.get("self"):
# Only validating if the attribute is not being set in a instance/class method
klass_mod = _get_module(klass)
if klass_mod is not None and ENABLE_CONTRACT_CHECKING:
try:
_check_invariants(self, klass, klass_mod.__dict__)
except PyTAContractError as e:
if original_attr_value_exists:
super(klass, self).__setattr__(name, original_attr_value)
else:
super(klass, self).__delattr__(name)
raise AssertionError(str(e)) from None

for attr, value in klass.__dict__.items():
Expand Down
43 changes: 43 additions & 0 deletions tests/test_contracts_attr_value_restoration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from python_ta.contracts import check_contracts


def test_class_attr_value_restores_to_original_if_violates_rep_inv() -> None:
"""Test to check whether the class attribute value is restored to the original value if a representation invariant
is violated."""

@check_contracts
class Person:
"""
Representation Invariants:
- self.age >= 0
"""

age: int = 0

my_person = Person()

try:
my_person.age = -1
except AssertionError:
assert my_person.age == 0


def test_class_attr_value_does_not_exist_if_violates_rep_inv() -> None:
"""Test to check whether the class attribute value does not exist if a representation invariant
is violated."""

@check_contracts
class Person:
"""
Representation Invariants:
- self.age >= 0
"""

age: int

my_person = Person()

try:
my_person.age = -1
except AssertionError:
assert not hasattr(my_person, "age")

0 comments on commit 5da03c0

Please sign in to comment.