Skip to content

Commit

Permalink
Improve ENABLE_CONTRACT_CHECKING handling in contract-checking custom…
Browse files Browse the repository at this point in the history
… setattr (#932)
  • Loading branch information
Bruce-8 authored Jul 28, 2023
1 parent 2f6966a commit 97746f7
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ and adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
outer function, efficiency of code was improved, and the attribute value is now restored to the original value if the
`_check_invariants` call raises an error.
- Added new function `validate_invariants` which takes in an object and checks that the representation invariants of the object are satisfied.
- The check for `ENABLE_CONTRACT_CHECKING` is now moved to the top of the body of the `new_setattr` function.
- Added the file `conftest.py` to store `pytest` fixtures.

### Bug Fixes

Expand Down
5 changes: 4 additions & 1 deletion python_ta/contracts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def new_setattr(self: klass, name: str, value: Any) -> None:
Check representation invariants for this class when not within an instance method of the class.
"""
if not ENABLE_CONTRACT_CHECKING:
super(klass, self).__setattr__(name, value)
return
if name in cls_annotations:
try:
_debug(f"Checking type of attribute {attr} for {klass.__qualname__} instance")
Expand All @@ -165,7 +168,7 @@ def new_setattr(self: klass, name: str, value: Any) -> None:
frame_locals = inspect.currentframe().f_back.f_locals
if self is not frame_locals.get("self"):
# Only validating if the attribute is not being set in a instance/class method
if klass_mod is not None and ENABLE_CONTRACT_CHECKING:
if klass_mod is not None:
try:
_check_invariants(self, klass, klass_mod.__dict__)
except PyTAContractError as e:
Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest

import python_ta.contracts


@pytest.fixture()
def disable_contract_checking():
"""Fixture for setting python_ta.contracts.ENABLE_CONTRACT_CHECKING = False."""
python_ta.contracts.ENABLE_CONTRACT_CHECKING = False
yield
python_ta.contracts.ENABLE_CONTRACT_CHECKING = True
4 changes: 1 addition & 3 deletions tests/test_class_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,15 +240,13 @@ def test_pizza_invalid() -> None:
)


def test_pizza_invalid_disable_contract_checking() -> None:
def test_pizza_invalid_disable_contract_checking(disable_contract_checking) -> None:
"""
Test the Pizza representation invariant on an invalid instance but with ENABLE_CONTRACT_CHECKING = False so
no error is raised.
"""
python_ta.contracts.ENABLE_CONTRACT_CHECKING = False
pizza = Pizza(radius=10, ingredients=[])
assert pizza.radius == 10 and pizza.ingredients == []
python_ta.contracts.ENABLE_CONTRACT_CHECKING = True # Reset default value to True


def test_set_wrapper_valid() -> None:
Expand Down
41 changes: 19 additions & 22 deletions tests/test_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,15 @@ def nullary() -> bool:
nullary()


def test_nullary_int_bool_disable_contract_checking() -> None:
def test_nullary_int_bool_disable_contract_checking(disable_contract_checking) -> None:
"""Calling a nullary function with incorrect return type and with ENABLE_CONTRACT_CHECKING disabled so no error
is raised."""

@check_contracts
def nullary() -> int:
return True

python_ta.contracts.ENABLE_CONTRACT_CHECKING = False
nullary()
python_ta.contracts.ENABLE_CONTRACT_CHECKING = True # Reset default value to True


def test_nullary_no_return_type() -> None:
Expand Down Expand Up @@ -233,14 +231,12 @@ def parameter_bool(result: bool) -> None:
parameter_bool(1)


def test_parameter_int_bool_disable_contract_checking() -> None:
def test_parameter_int_bool_disable_contract_checking(disable_contract_checking) -> None:
@check_contracts
def parameter_int(num: int) -> None:
return None

python_ta.contracts.ENABLE_CONTRACT_CHECKING = False
parameter_int(True)
python_ta.contracts.ENABLE_CONTRACT_CHECKING = True # Reset default value to True


@check_contracts
Expand Down Expand Up @@ -274,12 +270,10 @@ def test_my_sum_one_pre_violation() -> None:
assert "len(numbers) > 2" in msg


def test_my_sum_one_disable_contract_checking() -> None:
def test_my_sum_one_disable_contract_checking(disable_contract_checking) -> None:
"""Calling _my_sum_one_precondition with a value that violates the precondition but with ENABLE_CONTRACT_CHECKING
= False so no error is raised"""
python_ta.contracts.ENABLE_CONTRACT_CHECKING = False
_my_sum_one_precondition([1])
python_ta.contracts.ENABLE_CONTRACT_CHECKING = True # Reset default value to True


# Checking to see if functions we defined are in-scope for preconditions
Expand Down Expand Up @@ -449,13 +443,11 @@ def test_get_double_invalid() -> None:
assert "$return_value == num * 2" in msg


def test_get_double_disabled_contract_checking() -> None:
def test_get_double_disabled_contract_checking(disable_contract_checking) -> None:
"""Test that calling the invalid implementation of _get_double does NOT raise an AssertionError when
ENABLE_CONTRACT_CHECKING is False.
"""
python_ta.contracts.ENABLE_CONTRACT_CHECKING = False
assert _get_double_invalid(5) == 11
python_ta.contracts.ENABLE_CONTRACT_CHECKING = True # Reset default value to True


# Test that postcondition checks involving function parameters pass and fail as expected
Expand Down Expand Up @@ -686,16 +678,6 @@ def test_check_all_contracts_module_names_argument() -> None:
run()


@pytest.fixture()
def disable_contract_checking():
"""Fixture for setting python_ta.contracts.ENABLE_CONTRACT_CHECKING = False."""
import python_ta.contracts

python_ta.contracts.ENABLE_CONTRACT_CHECKING = False
yield
python_ta.contracts.ENABLE_CONTRACT_CHECKING = True


def test_enable_contract_checking_false(disable_contract_checking) -> None:
"""Test that check_contracts does nothing when ENABLE_CONTRACT_CHECKING is False."""

Expand All @@ -705,3 +687,18 @@ def unary2(arg: int) -> int:

# No error should be raised even though the argument is the wrong type
assert unary2("wrong type!") == "wrong type!"


def test_invalid_attr_type_disable_contract_checking(disable_contract_checking) -> None:
"""
Test that a Person object is created with an attribute value that doesn't match the specified type annotation but
with ENABLE_CONTRACT_CHECKING = False so no error is raised.
"""

@check_contracts
class Person:
age: int

my_person = Person()
my_person.age = "John"
assert my_person.age == "John"

0 comments on commit 97746f7

Please sign in to comment.