diff --git a/src/nagini_contracts/contracts.py b/src/nagini_contracts/contracts.py index 38cc3cb6..9e5fcf2b 100644 --- a/src/nagini_contracts/contracts.py +++ b/src/nagini_contracts/contracts.py @@ -31,7 +31,7 @@ 'Acc', 'Rd', 'Wildcard', 'Fold', 'Unfold', 'Unfolding', 'Previous', 'RaisedException', 'PSeq', 'PSet', 'ToSeq', 'ToMS', 'MaySet', 'MayCreate', 'getMethod', 'getArg', 'getOld', 'arg', 'Joinable', 'MayStart', 'Let', - 'PMultiset', 'LowExit', 'Refute'] + 'PMultiset', 'LowExit', 'Refute', 'isNaN'] T = TypeVar('T') V = TypeVar('V') @@ -506,6 +506,8 @@ def dict_pred(d: object) -> bool: be folded or unfolded. """ +def isNaN(f: float) -> bool: + pass __all__ = [ 'Requires', @@ -560,4 +562,5 @@ def dict_pred(d: object) -> bool: 'ToMS', 'MaySet', 'MayCreate', + 'isNaN' ] diff --git a/src/nagini_translation/lib/resolver.py b/src/nagini_translation/lib/resolver.py index c0063822..d9d26b1d 100644 --- a/src/nagini_translation/lib/resolver.py +++ b/src/nagini_translation/lib/resolver.py @@ -461,7 +461,7 @@ def _get_call_type(node: ast.Call, module: PythonModule, return ctx.current_contract_exception elif node.func.id in ('Acc', 'Rd', 'Read', 'Implies', 'Forall', 'IOForall', 'Exists', 'Forall2', 'Forall3', 'Forall4', 'Forall5', 'Forall6', - 'MayCreate', 'MaySet', 'Low', 'LowVal', 'LowEvent', 'LowExit'): + 'MayCreate', 'MaySet', 'Low', 'LowVal', 'LowEvent', 'LowExit', 'isNaN'): return module.global_module.classes[BOOL_TYPE] elif node.func.id == 'Declassify': return None diff --git a/src/nagini_translation/resources/float.sil b/src/nagini_translation/resources/float.sil index 1b00c130..03091d69 100644 --- a/src/nagini_translation/resources/float.sil +++ b/src/nagini_translation/resources/float.sil @@ -9,6 +9,11 @@ function float___bool__(self: Ref): Bool ensures self == null ==> !result ensures issubtype(typeof(self), int()) ==> (result == int___bool__(self)) +function float___isNaN(f: Ref): Bool + decreases _ + requires issubtype(typeof(f), float()) + ensures issubtype(typeof(f), int()) ==> result == false + function float___ge__(self: Ref, other: Ref): Bool decreases _ requires issubtype(typeof(self), float()) diff --git a/src/nagini_translation/resources/float_ieee32.sil b/src/nagini_translation/resources/float_ieee32.sil index 68327fe6..cb9cf479 100644 --- a/src/nagini_translation/resources/float_ieee32.sil +++ b/src/nagini_translation/resources/float_ieee32.sil @@ -18,6 +18,11 @@ function float___bool__(self: Ref): Bool ensures self == null ==> !result ensures result == (float___unbox__(self) != ___float32_zero()) +function float___isNaN(f: Ref): Bool + decreases _ + requires issubtype(typeof(f), float()) + ensures result == ___float32_isNaN(float___unbox__(f)) + function float___ge__(self: Ref, other: Ref): Bool decreases _ requires issubtype(typeof(self), float()) @@ -144,6 +149,7 @@ domain ___float32 interpretation (Boogie: "float24e8", SMTLIB: "(_ FloatingPoint function real____to_int(p: Perm): Int interpretation "to_int" function ___float32_to_real(p: ___float32): Perm interpretation "fp.to_real" function ___float32_NaN(): ___float32 interpretation "(_ NaN 8 24)" + function ___float32_isNaN(___float32): Bool interpretation "fp.isNaN" } function ___float32_zero(): ___float32 diff --git a/src/nagini_translation/resources/float_real.sil b/src/nagini_translation/resources/float_real.sil index 4613770d..596be318 100644 --- a/src/nagini_translation/resources/float_real.sil +++ b/src/nagini_translation/resources/float_real.sil @@ -42,6 +42,11 @@ function float___is_inf__(r: Ref, negative: Bool): Bool requires issubtype(typeof(r), float()) ensures issubtype(typeof(r), int()) ==> result == false +function float___isNaN(f: Ref): Bool + decreases _ + requires issubtype(typeof(f), float()) + ensures result == float___is_nan__(f) + function float___bool__(self: Ref): Bool decreases _ requires issubtype(typeof(self), float()) diff --git a/src/nagini_translation/resources/preamble.index b/src/nagini_translation/resources/preamble.index index 04a4aa2a..6c6e1ffd 100644 --- a/src/nagini_translation/resources/preamble.index +++ b/src/nagini_translation/resources/preamble.index @@ -410,6 +410,11 @@ "args": ["float"], "type": "int", "requires": ["unbox", "__prim__int___box__", "float___is_nan__", "float___is_inf__"] + }, + "__isNaN": { + "args": ["float"], + "type": "__prim__bool", + "requires": ["float___is_nan__", "float___unbox__"] } }, "extends": "object" diff --git a/src/nagini_translation/translators/contract.py b/src/nagini_translation/translators/contract.py index 455f00d3..33a15e69 100644 --- a/src/nagini_translation/translators/contract.py +++ b/src/nagini_translation/translators/contract.py @@ -16,6 +16,7 @@ GET_OLD_FUNC, GLOBAL_VAR_FIELD, INT_TYPE, + FLOAT_TYPE, JOINABLE_FUNC, METHOD_ID_DOMAIN, PMSET_TYPE, @@ -985,6 +986,13 @@ def translate_exists(self, node: ast.Call, ctx: Context, self.to_position(node, ctx), self.no_info(ctx)) return dom_stmt, exists + + def translate_isNaN(self, node: ast.Call, ctx: Context) -> StmtsAndExpr: + assert len(node.args) == 1 + stmt, expr = self.translate_expr(node.args[0], ctx, self.viper.Perm, True) + float_class = ctx.module.global_module.classes[FLOAT_TYPE] + call = self.get_function_call(float_class, '__isNaN', [expr], [None], node, ctx, self.to_position(node, ctx)) + return stmt, call def translate_contractfunc_call(self, node: ast.Call, ctx: Context, impure=False, statement=False) -> StmtsAndExpr: @@ -1104,6 +1112,8 @@ def translate_contractfunc_call(self, node: ast.Call, ctx: Context, return self.translate_get_arg(node, ctx) elif func_name == 'getOld': return self.translate_get_old(node, ctx) + elif func_name == 'isNaN': + return self.translate_isNaN(node, ctx) elif func_name == 'getMethod': raise InvalidProgramException(node, 'invalid.get.method.use') elif func_name == 'arg': diff --git a/tests/functional/verification/float_real/test_float.py b/tests/functional/verification/float_real/test_float.py index 9ffcd36a..559c41e4 100644 --- a/tests/functional/verification/float_real/test_float.py +++ b/tests/functional/verification/float_real/test_float.py @@ -49,6 +49,7 @@ def sqr3(num : float) -> float: return num * num def arith(num: float) -> float: + Requires(not isNaN(num)) Ensures(Result() == num + 3) return num + 1.0 + 2.0