diff --git a/CHANGELOG.md b/CHANGELOG.md index 5df25114a..fbbd5388f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,9 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - When running `check_contracts` on a class with type aliases as type annotations for its attributes, the `NameError` that appears (which indicates that the type alias is undefined) is now resolved. - The default value of `pyta-number-of-messages` is now 0. This automatically displays all occurrences of the same error. +- For the contract checking `new_setattr` function, any variables that depend only on `klass` are now defined in the + outer function, efficiency of code was improved, and the attribute value is now restored to the original value if the + `_check_invariants` call raises an error. ### Bug Fixes diff --git a/python_ta/contracts/__init__.py b/python_ta/contracts/__init__.py index 3cd64c2e9..cc1286470 100644 --- a/python_ta/contracts/__init__.py +++ b/python_ta/contracts/__init__.py @@ -159,14 +159,14 @@ def add_class_invariants(klass: type) -> None: setattr(klass, "__representation_invariants__", rep_invariants) + klass_mod = _get_module(klass) + cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__) + def new_setattr(self: klass, name: str, value: Any) -> None: """Set the value of the given attribute on self to the given value. Check representation invariants for this class when not within an instance method of the class. """ - klass_mod = _get_module(klass) - cls_annotations = typing.get_type_hints(klass, localns=klass_mod.__dict__) - if name in cls_annotations: try: _debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance") @@ -176,18 +176,23 @@ def new_setattr(self: klass, name: str, value: Any) -> None: f"Value {_display_value(value)} did not match type annotation for attribute " f"{name}: {_display_annotation(cls_annotations[name])}" ) from None - + original_attr_value_exists = False + original_attr_value = None + if hasattr(super(klass, self), name): + original_attr_value_exists = True + original_attr_value = super(klass, self).__getattribute__(name) super(klass, self).__setattr__(name, value) - curframe = inspect.currentframe() - callframe = inspect.getouterframes(curframe, 2) - frame_locals = callframe[1].frame.f_locals + frame_locals = inspect.currentframe().f_back.f_locals if self is not frame_locals.get("self"): # Only validating if the attribute is not being set in a instance/class method - klass_mod = _get_module(klass) if klass_mod is not None and ENABLE_CONTRACT_CHECKING: try: _check_invariants(self, klass, klass_mod.__dict__) except PyTAContractError as e: + if original_attr_value_exists: + super(klass, self).__setattr__(name, original_attr_value) + else: + super(klass, self).__delattr__(name) raise AssertionError(str(e)) from None for attr, value in klass.__dict__.items(): diff --git a/tests/test_contracts_attr_value_restoration.py b/tests/test_contracts_attr_value_restoration.py new file mode 100644 index 000000000..d296ee2b9 --- /dev/null +++ b/tests/test_contracts_attr_value_restoration.py @@ -0,0 +1,43 @@ +from python_ta.contracts import check_contracts + + +def test_class_attr_value_restores_to_original_if_violates_rep_inv() -> None: + """Test to check whether the class attribute value is restored to the original value if a representation invariant + is violated.""" + + @check_contracts + class Person: + """ + Representation Invariants: + - self.age >= 0 + """ + + age: int = 0 + + my_person = Person() + + try: + my_person.age = -1 + except AssertionError: + assert my_person.age == 0 + + +def test_class_attr_value_does_not_exist_if_violates_rep_inv() -> None: + """Test to check whether the class attribute value does not exist if a representation invariant + is violated.""" + + @check_contracts + class Person: + """ + Representation Invariants: + - self.age >= 0 + """ + + age: int + + my_person = Person() + + try: + my_person.age = -1 + except AssertionError: + assert not hasattr(my_person, "age")