diff --git a/src/nagini_translation/resources/bool.sil b/src/nagini_translation/resources/bool.sil index 1261d6e8..b3e4e0d9 100644 --- a/src/nagini_translation/resources/bool.sil +++ b/src/nagini_translation/resources/bool.sil @@ -56,17 +56,17 @@ function int___eq__(self: Ref, other: Ref): Bool decreases _ requires issubtype(typeof(self), int()) requires issubtype(typeof(other), int()) -{ - int___unbox__(self) == int___unbox__(other) -} + ensures result == int___unbox__(self) == int___unbox__(other) + ensures result == object___eq__(self, other) + function bool___eq__(self: Ref, other: Ref): Bool decreases _ requires issubtype(typeof(self), bool()) requires issubtype(typeof(other), bool()) -{ - bool___unbox__(self) == bool___unbox__(other) -} + ensures result == bool___unbox__(self) == bool___unbox__(other) + ensures result == object___eq__(self, other) + function int___ge__(self: Int, other: Int): Bool decreases _ @@ -135,10 +135,25 @@ function int___int__(self: Ref): Ref requires issubtype(typeof(self), int()) ensures result == self -function object___eq__(self: Ref, other: Ref): Bool - decreases _ - ensures self == other ==> result - ensures ((self == null) != (other == null)) ==> !result +domain __ObjectEquality { + function object___eq__(Ref, Ref): Bool + + axiom { + forall o1: Ref, o2: Ref, o3: Ref :: + { object___eq__(o1, o2), object___eq__(o2, o3) } + { object___eq__(o1, o2), object___eq__(o1, o3) } + { object___eq__(o2, o3), object___eq__(o1, o3) } + object___eq__(o1, o2) && object___eq__(o2, o3) ==> object___eq__(o1, o3) + } + + axiom { + forall o1: Ref, o2: Ref :: { object___eq__(o1, o2) } + (object___eq__(o1, o2) == object___eq__(o2, o1)) && + (o1 == o2 ==> object___eq__(o1, o2)) && + (((o1 == null) != (o2 == null)) ==> !object___eq__(o1, o2)) + } + +} function Place___eq__(self: Ref, other: Ref): Bool decreases _ diff --git a/src/nagini_translation/resources/bytes.sil b/src/nagini_translation/resources/bytes.sil index a5fb515b..89ee78de 100644 --- a/src/nagini_translation/resources/bytes.sil +++ b/src/nagini_translation/resources/bytes.sil @@ -33,6 +33,7 @@ function bytes___eq__(self: Ref, other: Ref): Bool requires issubtype(typeof(self), bytes()) ensures (bytes___val__(self) == bytes___val__(other)) == result ensures result ==> (issubtype(typeof(other), bytes()) && bytes___len__(self) == bytes___len__(other)) + ensures result == object___eq__(self, other) function bytes___sil_seq__(self: Ref) : Seq[Ref] decreases _ diff --git a/src/nagini_translation/resources/preamble.index b/src/nagini_translation/resources/preamble.index index be52406f..0b549d3a 100644 --- a/src/nagini_translation/resources/preamble.index +++ b/src/nagini_translation/resources/preamble.index @@ -559,7 +559,7 @@ "__eq__": { "args": ["tuple", "object"], "type": "__prim__bool", - "requires": ["__len__", "__getitem__"] + "requires": ["__len__", "__getitem__", "object___eq__"] }, "__sil_seq__": { "args": ["tuple"], diff --git a/src/nagini_translation/resources/pset.sil b/src/nagini_translation/resources/pset.sil index 334f8c29..6649b463 100644 --- a/src/nagini_translation/resources/pset.sil +++ b/src/nagini_translation/resources/pset.sil @@ -52,6 +52,7 @@ function PSet___eq__(self: Ref, other: Ref): Bool requires PSet_arg(typeof(self), 0) == PSet_arg(typeof(other), 0) ensures result == (PSet___unbox__(self) == PSet___unbox__(other)) ensures result ==> self == other + ensures result == object___eq__(self, other) @@ -103,3 +104,4 @@ function PMultiset___eq__(self: Ref, other: Ref): Bool requires PMultiset_arg(typeof(self), 0) == PMultiset_arg(typeof(other), 0) ensures result == (PMultiset___unbox__(self) == PMultiset___unbox__(other)) ensures result ==> self == other // extensionality + ensures result == object___eq__(self, other) diff --git a/src/nagini_translation/resources/range.sil b/src/nagini_translation/resources/range.sil index 8983ffb9..97d02389 100644 --- a/src/nagini_translation/resources/range.sil +++ b/src/nagini_translation/resources/range.sil @@ -63,6 +63,7 @@ function range___eq__(self: Ref, other: Ref): Bool requires issubtype(typeof(self), range_0()) ensures (range___val__(self) == range___val__(other)) == result ensures result ==> (issubtype(typeof(other), range_0()) && range___len__(self) == range___len__(other)) + ensures result == object___eq__(self, other) function range___contains__(self: Ref, item: Ref): Bool diff --git a/src/nagini_translation/resources/seq.sil b/src/nagini_translation/resources/seq.sil index 5e35d1e3..d430c764 100644 --- a/src/nagini_translation/resources/seq.sil +++ b/src/nagini_translation/resources/seq.sil @@ -64,6 +64,7 @@ function PSeq___eq__(self: Ref, other: Ref): Bool requires PSeq_arg(typeof(self), 0) == PSeq_arg(typeof(other), 0) ensures result == (PSeq___sil_seq__(self) == PSeq___sil_seq__(other)) ensures result ==> self == other // extensionality + ensures result == object___eq__(self, other) domain __MSHelper[T$] { diff --git a/src/nagini_translation/resources/str.sil b/src/nagini_translation/resources/str.sil index 6540b2d6..3144072b 100644 --- a/src/nagini_translation/resources/str.sil +++ b/src/nagini_translation/resources/str.sil @@ -29,6 +29,7 @@ function str___eq__(self: Ref, other: Ref): Bool requires issubtype(typeof(self), str()) ensures (str___val__(self) == str___val__(other)) == result ensures result ==> (str___len__(self) == str___len__(other)) + ensures result == object___eq__(self, other) function str___add__(self: Ref, other: Ref): Ref decreases _ diff --git a/src/nagini_translation/resources/tuple.sil b/src/nagini_translation/resources/tuple.sil index cc2c56dd..f3f30b7c 100644 --- a/src/nagini_translation/resources/tuple.sil +++ b/src/nagini_translation/resources/tuple.sil @@ -119,7 +119,9 @@ function tuple___contains__(self: Ref, item: Ref): Bool function tuple___eq__(self: Ref, other: Ref): Bool decreases _ - ensures (tuple___len__(self) == tuple___len__(other) && - (forall i: Int :: {tuple___getitem__(self, i), tuple___getitem__(other, i)} i >= 0 && i < tuple___len__(self) - ==> tuple___getitem__(self, i) == tuple___getitem__(other, i))) - ==> result + ensures result <==> + (tuple___len__(self) == tuple___len__(other) && + (forall i: Int :: { tuple___getitem__(self, i) } + { tuple___getitem__(other, i)} + i >= 0 && i < tuple___len__(self) + ==> object___eq__(tuple___getitem__(self, i), tuple___getitem__(other, i)))) diff --git a/src/nagini_translation/translators/common.py b/src/nagini_translation/translators/common.py index 215e717a..9660d35d 100644 --- a/src/nagini_translation/translators/common.py +++ b/src/nagini_translation/translators/common.py @@ -20,6 +20,7 @@ MAIN_METHOD_NAME, MAY_SET_PRED, NAME_DOMAIN, + OBJECT_TYPE, PRIMITIVE_BOOL_TYPE, PRIMITIVE_INT_TYPE, RANGE_TYPE, @@ -699,9 +700,9 @@ def get_function_call(self, receiver: PythonType, guard = self.type_check(args[0], cls, position, ctx) # Translate the function call on this particular receiver's class - function = self._get_function_call(cls, func_name, args, - arg_types, node, ctx, - position) + function = self.get_function_call(cls, func_name, args, + arg_types, node, ctx, + position) # Stores guard and translated function call as tuple in a list guarded_functions.append((guard, function)) @@ -711,6 +712,14 @@ def get_function_call(self, receiver: PythonType, return chain_cond_exp(guarded_functions, self.viper, position, self.no_info(ctx), ctx) else: + if func_name == '__eq__': + func_cls = receiver.get_function(func_name).cls + if func_cls.name == OBJECT_TYPE: + assert len(args) == 2 + arg1 = self.to_ref(args[0], ctx) + arg2 = self.to_ref(args[1], ctx) + return self.viper.DomainFuncApp('object___eq__', [arg1, arg2], self.viper.Bool, position, + self.no_info(ctx), '__ObjectEquality') if receiver.python_class.name == FLOAT_TYPE: if ctx.float_encoding is None: import logging diff --git a/src/nagini_translation/translators/expression.py b/src/nagini_translation/translators/expression.py index 1660adc4..09d4f7d9 100644 --- a/src/nagini_translation/translators/expression.py +++ b/src/nagini_translation/translators/expression.py @@ -1130,7 +1130,7 @@ def translate_Compare(self, node: ast.Compare, comparison = self.get_function_call(left_type, compare_func, [left, right], [left_type, right_type], - node, ctx) + node, ctx, position) elif compare_func == '__ne__' and left_type.get_function('__eq__'): # The default behavior if __ne__ is not explicitly defined # is to invert the result of __eq__. diff --git a/tests/functional/verification/issues/00164.py b/tests/functional/verification/issues/00164.py new file mode 100644 index 00000000..b1f40426 --- /dev/null +++ b/tests/functional/verification/issues/00164.py @@ -0,0 +1,46 @@ +from typing import * + +from nagini_contracts.contracts import * + +Shape = Tuple[int, ...] + +class ndarray: + @property + def shape(self) -> Shape: + ... + + @shape.setter + def shape(self, new_shape: Shape) -> None: + ... + + +@Predicate +@ContractOnly +def array_pred(array: ndarray) -> bool: + return True + + +@Pure +@ContractOnly +def array_shape(array: ndarray) -> Shape: #type: ignore[return] + Requires(Acc(array_pred(array), 1/2)) + ... + +@ContractOnly +def ones(shape: Shape) -> ndarray: #type: ignore[return] + Requires(len(shape) > 0) + Requires(Forall(shape, lambda l: l > 0)) + + Ensures(array_pred(Result())) + Ensures(array_shape(Result()) == shape) + ... + +shape = (2,) +array1 = ones(shape) +array2 = ones(shape) + + + +assert array_shape(array1) == shape +assert array_shape(array2) == shape +assert array_shape(array1) == array_shape(array2) \ No newline at end of file diff --git a/tests/sif-true/verification/test_lowval.py b/tests/sif-true/verification/test_lowval.py index e3f36c7b..3b4b00af 100644 --- a/tests/sif-true/verification/test_lowval.py +++ b/tests/sif-true/verification/test_lowval.py @@ -60,7 +60,6 @@ def example_tuple_low(secret: bool) -> Example: def example_tuple_lowval(secret: bool) -> Example: Ensures(Acc(Result().f) and Acc(Result().g)) - #:: ExpectedOutput(postcondition.violated:assertion.false) Ensures(LowVal((Result().f, Result().g))) a = Example() b = Example()