From 4ecdc834b3a961944c8ea5e0c2e589fa85d300ca Mon Sep 17 00:00:00 2001 From: Matthew Yacavone Date: Mon, 28 Mar 2022 13:47:16 -0400 Subject: [PATCH] [MRSolver] Changes to Mr. Solver to get zero_array working (#1624) * add exp_explosion_mr_solver.saw, add is_elem_noErrorsSpec * progress on mr_solver zero_array |= zero_array_spec * fix mrFunOutType, fix lifting and use asApplyAll in askMRSolverH * implement vecMapM, (ec)atM, (ec)updateM without Nat__rec * add maybe elim for IsLe(/Lt)Nat * make `bvNat w (bvToNat w' n)` reduce to `n` in the simulator * add cases for vecMapM, atM, updateM to normComp/normBind * remove maybe elim for IsLe(/Lt)Nat, always unfold is_bvule(t) in maybe * add macro for precondHint in Monadify.hs * do beta reds + look past asserts in mrGetPrecond, get loop spec working * added specification primitives for cryptol * add precondHint to specPrims.saw, lookup macros in set_monadification * rename precondHint to invariantHint * add assertingM, assumingM, and their monadification macros * add assertingM, assumingM to Mr. Solver * add bvVecMapInvarM, get zero_array_spec refinement working * update Prelude.v, clean up comments * whoops remove SAWCorePrelude.v * attempt to fix CI build failure on GHC 8.8.4 Co-authored-by: Eddy Westbrook --- cryptol-saw-core/cryptol-saw-core.cabal | 1 + cryptol-saw-core/saw/CryptolM.sawcore | 134 +++++-- .../src/Verifier/SAW/Cryptol/Monadify.hs | 40 +- heapster-saw/examples/SpecPrims.cry | 37 ++ heapster-saw/examples/arrays.cry | 15 + heapster-saw/examples/arrays.sawcore | 2 +- heapster-saw/examples/arrays_mr_solver.saw | 8 + heapster-saw/examples/exp_explosion.cry | 23 ++ heapster-saw/examples/exp_explosion.saw | 2 +- .../examples/exp_explosion_mr_solver.saw | 23 ++ .../examples/linked_list_mr_solver.saw | 18 +- heapster-saw/examples/specPrims.saw | 10 + .../coq/handwritten/CryptolToCoq/CompM.v | 3 + .../SAW/Translation/Coq/SpecialTreatment.hs | 4 + saw-core/prelude/Prelude.sawcore | 30 +- saw-core/src/Verifier/SAW/Simulator/Prims.hs | 4 +- src/SAWScript/Builtins.hs | 21 +- src/SAWScript/Prover/MRSolver/Monad.hs | 138 ++++--- src/SAWScript/Prover/MRSolver/SMT.hs | 19 +- src/SAWScript/Prover/MRSolver/Solver.hs | 363 +++++++++++++----- src/SAWScript/Prover/MRSolver/Term.hs | 19 + 21 files changed, 719 insertions(+), 195 deletions(-) create mode 100644 heapster-saw/examples/SpecPrims.cry create mode 100644 heapster-saw/examples/arrays.cry create mode 100644 heapster-saw/examples/exp_explosion.cry create mode 100644 heapster-saw/examples/exp_explosion_mr_solver.saw create mode 100644 heapster-saw/examples/specPrims.saw diff --git a/cryptol-saw-core/cryptol-saw-core.cabal b/cryptol-saw-core/cryptol-saw-core.cabal index 6385b341ee..9a38024a2f 100644 --- a/cryptol-saw-core/cryptol-saw-core.cabal +++ b/cryptol-saw-core/cryptol-saw-core.cabal @@ -15,6 +15,7 @@ Description: extra-source-files: saw/Cryptol.sawcore + saw/CryptolM.sawcore library build-depends: diff --git a/cryptol-saw-core/saw/CryptolM.sawcore b/cryptol-saw-core/saw/CryptolM.sawcore index 61df5156f2..01eed7bd27 100644 --- a/cryptol-saw-core/saw/CryptolM.sawcore +++ b/cryptol-saw-core/saw/CryptolM.sawcore @@ -57,16 +57,67 @@ mseq : Num -> sort 0 -> sort 0; mseq num a = Num_rec (\ (_:Num) -> sort 0) (\ (n:Nat) -> Vec n a) (Stream (CompM a)) num; +bvVecMapInvarBindM : (a b c : isort 0) -> (n : Nat) -> (len : Vec n Bool) -> + (a -> CompM b) -> BVVec n len a -> + Bool -> (BVVec n len b -> CompM c) -> CompM c; +bvVecMapInvarBindM a b c n len f xs invar cont = + existsM (BVVec n len b) c (\ (ys0:BVVec n len b) -> + multiArgFixM + (LRT_Fun (Vec n Bool) (\ (_:Vec n Bool) -> + LRT_Fun (BVVec n len b) (\ (_:BVVec n len b) -> + LRT_Ret c))) + (\ (rec : Vec n Bool -> BVVec n len b -> CompM c) + (i:Vec n Bool) (ys:BVVec n len b) -> + invariantHint (CompM c) (and (bvule n i len) invar) + (maybe (is_bvult n i len) (CompM c) + (cont ys) + (\ (pf:is_bvult n i len) -> + bindM b c + (f (atBVVec n len a xs i pf)) + (\ (y:b) -> rec (bvAdd n i (bvNat n 1)) + (updBVVec n len b ys i y))) + (bvultWithProof n i len))) + (bvNat n 0) ys0); + +bvVecMapInvarM : (a b : isort 0) -> (n : Nat) -> (len : Vec n Bool) -> + (a -> CompM b) -> BVVec n len a -> + Bool -> CompM (BVVec n len b); +bvVecMapInvarM a b n len f xs invar = + bvVecMapInvarBindM a b (BVVec n len b) n len f xs invar + (returnM (BVVec n len b)); + +bvVecMapM : (a b : isort 0) -> (n : Nat) -> (len : Vec n Bool) -> + (a -> CompM b) -> BVVec n len a -> CompM (BVVec n len b); +bvVecMapM a b n len f xs = bvVecMapInvarM a b n len f xs True; + +vecMapInvarBindM : (a b c : isort 0) -> (n : Nat) -> (a -> CompM b) -> + Vec n a -> Bool -> (Vec n b -> CompM c) -> CompM c; +vecMapInvarBindM a b c n f xs invar cont = + existsM (Vec n b) c (\ (ys0:Vec n b) -> + multiArgFixM + (LRT_Fun Nat (\ (_:Nat) -> + LRT_Fun (Vec n b) (\ (_:Vec n b) -> + LRT_Ret c))) + (\ (rec : Nat -> Vec n b -> CompM c) (i:Nat) (ys:Vec n b) -> + invariantHint (CompM c) (and (ltNat i (Succ n)) invar) + (maybe (IsLtNat i n) (CompM c) + (cont ys) + (\ (pf:IsLtNat i n) -> + bindM b c + (f (atWithProof n a xs i pf)) + (\ (y:b) -> rec (Succ i) + (updWithProof n b ys i y pf))) + (proveLtNat i n))) + 0 ys0); + +vecMapInvarM : (a b : isort 0) -> (n : Nat) -> (a -> CompM b) -> Vec n a -> + Bool -> CompM (Vec n b); +vecMapInvarM a b n f xs invar = + vecMapInvarBindM a b (Vec n b) n f xs invar (returnM (Vec n b)); + vecMapM : (a b : isort 0) -> (n : Nat) -> (a -> CompM b) -> Vec n a -> CompM (Vec n b); -vecMapM a b n_top f = - Nat__rec (\ (n:Nat) -> Vec n a -> CompM (Vec n b)) - (\ (_:Vec 0 a) -> returnM (Vec 0 b) (EmptyVec b)) - (\ (n:Nat) (rec:Vec n a -> CompM (Vec n b)) (v:Vec (Succ n) a) -> - fmapM2 b (Vec n b) (Vec (Succ n) b) - (\ (x:b) (xs:Vec n b) -> ConsVec b x n xs) - (f (head n a v)) (rec (tail n a v))) - n_top; +vecMapM a b n f xs = vecMapInvarM a b n f xs True; -- Computational version of seqMap seqMapM : (a b : sort 0) -> (n : Num) -> (a -> CompM b) -> mseq n a -> @@ -97,17 +148,36 @@ seqToMseq n_top a = -------------------------------------------------------------------------------- -- Auxiliary functions -atM : (n : Nat) -> (a : sort 0) -> Vec n a -> Nat -> CompM a; -atM n_top a = - Nat__rec - (\ (n:Nat) -> Vec n a -> Nat -> CompM a) - (\ (_:Vec 0 a) (_:Nat) -> errorM a "atM: index out of bounds") - (\ (n:Nat) (rec_f: Vec n a -> Nat -> CompM a) (v:Vec (Succ n) a) (i:Nat) -> - Nat_cases (CompM a) - (returnM a (head n a v)) - (\ (i_prev:Nat) (_:CompM a) -> rec_f (tail n a v) i_prev) i) - n_top; - +bvVecAtM : (n : Nat) -> (len : Vec n Bool) -> (a : isort 0) -> + BVVec n len a -> Vec n Bool -> CompM a; +bvVecAtM n len a xs i = + maybe (is_bvult n i len) (CompM a) + (errorM a "bvVecAtM: invalid sequence index") + (\ (pf:is_bvult n i len) -> returnM a (atBVVec n len a xs i pf)) + (bvultWithProof n i len); + +atM : (n : Nat) -> (a : isort 0) -> Vec n a -> Nat -> CompM a; +atM n a xs i = + maybe (IsLtNat i n) (CompM a) + (errorM a "atM: invalid sequence index") + (\ (pf:IsLtNat i n) -> returnM a (atWithProof n a xs i pf)) + (proveLtNat i n); + +bvVecUpdateM : (n : Nat) -> (len : Vec n Bool) -> (a : isort 0) -> + BVVec n len a -> Vec n Bool -> a -> CompM (BVVec n len a); +bvVecUpdateM n len a xs i x = + maybe (is_bvult n i len) (CompM (BVVec n len a)) + (errorM (BVVec n len a) "bvVecUpdateM: invalid sequence index") + (\ (_:is_bvult n i len) -> returnM (BVVec n len a) + (updBVVec n len a xs i x)) + (bvultWithProof n i len); + +updateM : (n : Nat) -> (a : isort 0) -> Vec n a -> Nat -> a -> CompM (Vec n a); +updateM n a xs i x = + maybe (IsLtNat i n) (CompM (Vec n a)) + (errorM (Vec n a) "updateM: invalid sequence index") + (\ (pf:IsLtNat i n) -> returnM (Vec n a) (updWithProof n a xs i x pf)) + (proveLtNat i n); eListSelM : (a : isort 0) -> (n : Num) -> mseq n a -> Nat -> CompM a; eListSelM a = @@ -347,15 +417,35 @@ primitive ecTransposeM : (m n : Num) -> (a : sort 0) -> mseq m (mseq n a) -> mseq n (mseq m a); -ecAtM : (n : Num) -> (a ix: sort 0) -> PIntegral ix -> mseq n a -> ix -> CompM a; +ecAtM : (n : Num) -> (a : isort 0) -> (ix : sort 0) -> PIntegral ix -> + mseq n a -> ix -> CompM a; ecAtM n_top a ix pix = Num_rec (\ (n:Num) -> mseq n a -> ix -> CompM a) (\ (n:Nat) (v:Vec n a) -> - pix.posNegCases (CompM a) (atM n a v) (\ (_:Nat) -> atM n a v 0)) + pix.posNegCases (CompM a) (atM n a v) + (\ (_:Nat) -> errorM a "ecAtM: invalid sequence index")) (\ (s:Stream (CompM a)) -> pix.posNegCases (CompM a) (streamGet (CompM a) s) - (\ (_:Nat) -> (streamGet (CompM a) s) 0)) + (\ (_:Nat) -> errorM a "ecAtM: invalid sequence index")) + n_top; + +ecUpdateM : (n : Num) -> (a : isort 0) -> (ix : sort 0) -> PIntegral ix -> + mseq n a -> ix -> a -> CompM (mseq n a); +ecUpdateM n_top a ix pix = + Num_rec + (\ (n:Num) -> mseq n a -> ix -> a -> CompM (mseq n a)) + (\ (n:Nat) (v:Vec n a) (i:ix) (x:a) -> + pix.posNegCases (CompM (Vec n a)) + (\ (i:Nat) -> updateM n a v i x) + (\ (_:Nat) -> errorM (Vec n a) + "ecUpdateM: invalid sequence index") i) + (\ (s:Stream (CompM a)) (i:ix) (x:a) -> + pix.posNegCases (CompM (Stream (CompM a))) + (\ (i:Nat) -> returnM (Stream (CompM a)) + (streamUpd (CompM a) s i (returnM a x))) + (\ (_:Nat) -> errorM (Stream (CompM a)) + "ecUpdateM: invalid sequence index") i) n_top; -- FIXME diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs index b296dabb4d..be8cf2a241 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs @@ -1015,6 +1015,41 @@ eitherMacro = MonMacro 3 $ \_ args -> (MTyArrow (MTyArrow mtp_b mtp_c) (MTyArrow (mkMonType0 tp_eith) mtp_c))) eith_app +-- | The macro for invariantHint, which converts @invariantHint a cond m@ +-- to @invariantHint (CompM a) cond m@ and which contains any binds in the body +-- to the body +invariantHintMacro :: MonMacro +invariantHintMacro = MonMacro 3 $ \_ args -> + do let (tp, cond, m) = + case args of + [t1, t2, t3] -> (t1, t2, t3) + _ -> error "invariantHintMacro: wrong number of arguments!" + atrm_cond <- monadifyArg (Just boolMonType) cond + mtp <- monadifyTypeM tp + mtrm <- resetMonadifyM (toArgType mtp) $ monadifyTerm (Just mtp) m + return $ fromCompTerm mtp $ + applyOpenTermMulti (globalOpenTerm "Prelude.invariantHint") + [toCompType mtp, toArgTerm atrm_cond, toCompTerm mtrm] + +-- | The macro for @asserting@ or @assuming@, which converts @asserting@ to +-- @assertingM@ or @assuming@ to @assumingM@ (depending on whether the given +-- 'Bool' is true or false, respectively) and which contains any binds in the +-- body to the body +assertingOrAssumingMacro :: Bool -> MonMacro +assertingOrAssumingMacro doAsserting = MonMacro 3 $ \_ args -> + do let (tp, cond, m) = + case args of + [t1, t2, t3] -> (t1, t2, t3) + _ -> error "assertingOrAssumingMacro: wrong number of arguments!" + atrm_cond <- monadifyArg (Just boolMonType) cond + mtp <- monadifyTypeM tp + mtrm <- resetMonadifyM (toArgType mtp) $ monadifyTerm (Just mtp) m + let ident = if doAsserting then "Prelude.assertingM" + else "Prelude.assumingM" + return $ fromCompTerm mtp $ + applyOpenTermMulti (globalOpenTerm ident) + [toArgType mtp, toArgTerm atrm_cond, toCompTerm mtrm] + -- | Make a 'MonMacro' that maps a named global whose first argument is @n:Num@ -- to a global of semi-pure type that takes an additional argument of type -- @isFinite n@ @@ -1050,7 +1085,6 @@ lrtFromMonType (MTyArrow mtp1 mtp2) = lrtFromMonType mtp = ctorOpenTerm "Prelude.LRT_Ret" [toArgType mtp] - -- | The macro for fix -- -- FIXME: does not yet handle mutual recursion @@ -1104,6 +1138,9 @@ defaultMonEnv = , mmCustom "Prelude.ite" iteMacro , mmCustom "Prelude.fix" fixMacro , mmCustom "Prelude.either" eitherMacro + , mmCustom "Prelude.invariantHint" invariantHintMacro + , mmCustom "Prelude.asserting" (assertingOrAssumingMacro True) + , mmCustom "Prelude.assuming" (assertingOrAssumingMacro False) -- Top-level sequence functions , mmArg "Cryptol.seqMap" "CryptolM.seqMapM" @@ -1176,6 +1213,7 @@ defaultMonEnv = , mmSemiPureFin1 "Cryptol.ecReverse" "CryptolM.ecReverseM" , mmSemiPure "Cryptol.ecTranspose" "CryptolM.ecTransposeM" , mmArg "Cryptol.ecAt" "CryptolM.ecAtM" + , mmArg "Cryptol.ecUpdate" "CryptolM.ecUpdateM" -- , mmArgFin1 "Cryptol.ecAtBack" "CryptolM.ecAtBackM" -- , mmSemiPureFin2 "Cryptol.ecFromTo" "CryptolM.ecFromToM" , mmSemiPureFin1 "Cryptol.ecFromToLessThan" "CryptolM.ecFromToLessThanM" diff --git a/heapster-saw/examples/SpecPrims.cry b/heapster-saw/examples/SpecPrims.cry new file mode 100644 index 0000000000..3fc5d022fd --- /dev/null +++ b/heapster-saw/examples/SpecPrims.cry @@ -0,0 +1,37 @@ +module SpecPrims where + +/* Specification primitives */ + +// The specification that holds for f x for some input x +exists : {a, b} (a -> b) -> b +exists f = error "Cannot run exists" + +// The specification that holds for f x for all inputs x +forall : {a, b} (a -> b) -> b +forall f = error "Cannot run forall" + +// The specification that a computation returns some value with no errors +returnsSpec : {a} a +returnsSpec = exists (\x -> x) + +// The specification that matches any computation. This calls exists at the +// function type () -> a, which is monadified to () -> CompM a. This means that +// the exists does not just quantify over all values of type a like noErrors, +// but it quantifies over all computations of type a, including those that +// contain errors. +anySpec : {a} a +anySpec = exists (\f -> f ()) + +// The specification which asserts that the first argument is True and then +// returns the second argument +asserting : {a} Bit -> a -> a +asserting b x = if b then x else error "Assertion failed" + +// The specification which assumes that the first argument is True and then +// returns the second argument +assuming : {a} Bit -> a -> a +assuming b x = if b then x else anySpec + +// A hint to Mr Solver that a recursive function has the given loop invariant +invariantHint : {a} Bit -> a -> a +invariantHint b x = x diff --git a/heapster-saw/examples/arrays.cry b/heapster-saw/examples/arrays.cry new file mode 100644 index 0000000000..4b7ce92922 --- /dev/null +++ b/heapster-saw/examples/arrays.cry @@ -0,0 +1,15 @@ + +module Arrays where + +import SpecPrims + +zero_array_loop_spec : {n} Literal n [64] => [n][64] -> [n][64] +zero_array_loop_spec ys = loop 0 ys + where loop : [64] -> [n][64] -> [n][64] + loop i xs = invariantHint (i <= 0x0fffffffffffffff) + (if i < `n then loop (i+1) (update xs i 0) + else xs) + +zero_array_spec : {n} Literal n [64] => [n][64] -> [n][64] +zero_array_spec xs = assuming (`n <= 0x0fffffffffffffff) + [ 0 | _ <- xs ] diff --git a/heapster-saw/examples/arrays.sawcore b/heapster-saw/examples/arrays.sawcore index 6b1f16867b..ff7ef8cdbb 100644 --- a/heapster-saw/examples/arrays.sawcore +++ b/heapster-saw/examples/arrays.sawcore @@ -21,7 +21,7 @@ noErrorsContains0H len_top i_top v_top = (\ (f : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) -> ((\ (len:Vec 64 Bool) (i:Vec 64 Bool) (v:BVVec 64 len (Vec 64 Bool)) -> - precondHint + invariantHint (CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) (and (bvsle 64 0x0000000000000000 i) (bvsle 64 i 0x0fffffffffffffff)) diff --git a/heapster-saw/examples/arrays_mr_solver.saw b/heapster-saw/examples/arrays_mr_solver.saw index eaa38a79f7..386c4f095a 100644 --- a/heapster-saw/examples/arrays_mr_solver.saw +++ b/heapster-saw/examples/arrays_mr_solver.saw @@ -21,3 +21,11 @@ contains0 <- parse_core_mod "arrays" "contains0"; noErrorsContains0 <- parse_core_mod "arrays" "noErrorsContains0"; run_test "contains0 |= noErrorsContains0" (mr_solver_debug 0 contains0 noErrorsContains0) true; + +include "specPrims.saw"; +import "arrays.cry"; + +zero_array <- parse_core_mod "arrays" "zero_array"; +run_test "zero_array |= zero_array_spec" +// (mr_solver_debug 0 zero_array {{ zero_array_loop_spec }}) true; + (mr_solver_debug 0 zero_array {{ zero_array_spec }}) true; diff --git a/heapster-saw/examples/exp_explosion.cry b/heapster-saw/examples/exp_explosion.cry new file mode 100644 index 0000000000..6a2b5cc5a0 --- /dev/null +++ b/heapster-saw/examples/exp_explosion.cry @@ -0,0 +1,23 @@ + +module ExpExplosion where + +op : [64] -> [64] -> [64] +op x y = x ^ (y << (1 : [6])) + +exp_explosion_spec : [64] -> [64] +exp_explosion_spec x0 = x15 + where x1 = op x0 x0 + x2 = op x1 x1 + x3 = op x2 x2 + x4 = op x3 x3 + x5 = op x4 x4 + x6 = op x5 x5 + x7 = op x6 x6 + x8 = op x7 x7 + x9 = op x8 x8 + x10 = op x9 x9 + x11 = op x10 x10 + x12 = op x11 x11 + x13 = op x12 x12 + x14 = op x13 x13 + x15 = op x14 x14 diff --git a/heapster-saw/examples/exp_explosion.saw b/heapster-saw/examples/exp_explosion.saw index 64856fb92c..d27dd66b27 100644 --- a/heapster-saw/examples/exp_explosion.saw +++ b/heapster-saw/examples/exp_explosion.saw @@ -4,6 +4,6 @@ env <- heapster_init_env "exp_explosion" "exp_explosion.bc"; heapster_define_perm env "int64" " " "llvmptr 64" "exists x:bv 64.eq(llvmword(x))"; heapster_typecheck_fun env "exp_explosion" - "(). arg0:int64<> -o arg0:int64<>, ret:int64<>"; + "(). arg0:int64<> -o ret:int64<>"; heapster_export_coq env "exp_explosion_gen.v"; diff --git a/heapster-saw/examples/exp_explosion_mr_solver.saw b/heapster-saw/examples/exp_explosion_mr_solver.saw new file mode 100644 index 0000000000..03c97256c2 --- /dev/null +++ b/heapster-saw/examples/exp_explosion_mr_solver.saw @@ -0,0 +1,23 @@ +include "exp_explosion.saw"; + +let eq_bool b1 b2 = + if b1 then + if b2 then true else false + else + if b2 then false else true; + +let fail = do { print "Test failed"; exit 1; }; +let run_test name test expected = + do { if expected then print (str_concat "Test: " name) else + print (str_concat (str_concat "Test: " name) " (expecting failure)"); + actual <- test; + if eq_bool actual expected then print "Success\n" else + do { print "Test failed\n"; exit 1; }; }; + + + +import "exp_explosion.cry"; +monadify_term {{ op }}; + +exp_explosion <- parse_core_mod "exp_explosion" "exp_explosion"; +run_test "exp_explosion |= exp_explosion_spec" (mr_solver exp_explosion {{ exp_explosion_spec }}) true; diff --git a/heapster-saw/examples/linked_list_mr_solver.saw b/heapster-saw/examples/linked_list_mr_solver.saw index c741d7e890..98f1196b02 100644 --- a/heapster-saw/examples/linked_list_mr_solver.saw +++ b/heapster-saw/examples/linked_list_mr_solver.saw @@ -51,5 +51,19 @@ run_test "is_head |= is_head" (mr_solver is_head is_head) true; */ is_elem <- parse_core_mod "linked_list" "is_elem"; -run_test "is_elem |= is_elem_spec" (mr_solver_debug 2 is_elem {{ is_elem_spec }}) true; -//run_test "is_elem |= is_elem" (mr_solver_debug 1 is_elem is_elem) true; +// run_test "is_elem |= is_elem" (mr_solver_debug 0 is_elem is_elem) true; + +/* +is_elem_noErrorsSpec <- parse_core + "\\ (x:Vec 64 Bool) (y:List (Vec 64 Bool)) -> \ + \ fixM (Vec 64 Bool * List (Vec 64 Bool)) \ + \ (\\ (pr : Vec 64 Bool * List (Vec 64 Bool)) -> Vec 64 Bool) \ + \ (\\ (rec : (x : Vec 64 Bool * List (Vec 64 Bool)) -> CompM (Vec 64 Bool)) \ + \ (x : Vec 64 Bool * List (Vec 64 Bool)) -> \ + \ orM (Vec 64 Bool) \ + \ (existsM (Vec 64 Bool) (Vec 64 Bool) (returnM (Vec 64 Bool))) \ + \ (rec x)) (x, y)"; +run_test "is_elem |= noErrorsSpec" (mr_solver is_elem is_elem_noErrorsSpec) true; +*/ + +run_test "is_elem |= is_elem_spec" (mr_solver is_elem {{ is_elem_spec }}) true; diff --git a/heapster-saw/examples/specPrims.saw b/heapster-saw/examples/specPrims.saw new file mode 100644 index 0000000000..0f54d7deef --- /dev/null +++ b/heapster-saw/examples/specPrims.saw @@ -0,0 +1,10 @@ +/* Helper SAW script for using specification primitives in Cryptol */ + +import "SpecPrims.cry"; + +set_monadification "exists" "Prelude.existsM"; +set_monadification "forall" "Prelude.forallM"; +set_monadification "anySpec" "Prelude.anySpec"; +set_monadification "asserting" "Prelude.asserting"; +set_monadification "assuming" "Prelude.assuming"; +set_monadification "invariantHint" "Prelude.invariantHint"; diff --git a/saw-core-coq/coq/handwritten/CryptolToCoq/CompM.v b/saw-core-coq/coq/handwritten/CryptolToCoq/CompM.v index d0cb2ae90d..2830372fdf 100644 --- a/saw-core-coq/coq/handwritten/CryptolToCoq/CompM.v +++ b/saw-core-coq/coq/handwritten/CryptolToCoq/CompM.v @@ -888,6 +888,9 @@ Qed. Definition assertM (P:Prop) : CompM unit := existsM (fun pf:P => returnM tt). +Definition assertingM {A} (P:Prop) (m:CompM A) : CompM A := + assertM P >> m. + Definition assertM_eq (P:Prop) (pf:P) : assertM P ~= returnM tt. Proof. intro opt_a; split. diff --git a/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs b/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs index d8e61537de..44c0892e57 100644 --- a/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs +++ b/saw-core-coq/src/Verifier/SAW/Translation/Coq/SpecialTreatment.hs @@ -454,6 +454,10 @@ sawCorePreludeSpecialTreatmentMap configuration = , ("existsM", mapsToExpl compMModule "existsM") , ("forallM", mapsToExpl compMModule "forallM") , ("orM", mapsToExpl compMModule "orM") + , ("assertingM", mapsToExpl compMModule "assertingM") + , ("assumingM", mapsToExpl compMModule "assumingM") + , ("asserting", skip) + , ("assuming", skip) , ("fixM", replace (Coq.App (Coq.ExplVar "fixM") [Coq.Var "CompM", Coq.Var "_"])) , ("LetRecType", mapsTo compMModule "LetRecType") diff --git a/saw-core/prelude/Prelude.sawcore b/saw-core/prelude/Prelude.sawcore index 8eeaaf033b..d8c4950a78 100644 --- a/saw-core/prelude/Prelude.sawcore +++ b/saw-core/prelude/Prelude.sawcore @@ -2152,15 +2152,39 @@ primitive existsM : (a b:sort 0) -> (a -> CompM b) -> CompM b; orM : (a : sort 0) -> CompM a -> CompM a -> CompM a; orM a m1 m2 = existsM Bool a (\ (b:Bool) -> ite (CompM a) b m1 m2); +-- The specification that matches any computation +anySpec : (a : sort 0) -> CompM a; +anySpec a = existsM (CompM a) a (\ (m:CompM a) -> m); + -- The specification formed from the intersection of all computations f x for -- all possible inputs x. Computationally, this is sort of like running f for -- all possible inputs x at the same time and then raising an error if any of -- those computations diverge from each other. primitive forallM : (a b:sort 0) -> (a -> CompM b) -> CompM b; --- A hint to Mr Solver that a recursive function has the given precondition -precondHint : (a : sort 0) -> Bool -> a -> a; -precondHint _ _ a = a; +-- The specification which asserts that the first argument is True and then +-- runs the second argument +assertingM : (a : sort 0) -> Bool -> CompM a -> CompM a; +assertingM a b m = ite (CompM a) b m (errorM a "Assertion failed"); + +-- The specification which assumes that the first argument is True and then +-- runs the second argument +assumingM : (a : sort 0) -> Bool -> CompM a -> CompM a; +assumingM a b m = ite (CompM a) b m (anySpec a); + +-- A hint to Mr Solver that a recursive function has the given loop invariant +invariantHint : (a : sort 0) -> Bool -> a -> a; +invariantHint _ _ a = a; + +-- The version of assertingM which appears in un-monadified Cryptol (this gets +-- converted to assertingM during monadification, see assertingOrAssumingMacro) +asserting : (a : isort 0) -> Bool -> a -> a; +asserting a b x = ite a b x (error a "Assertion failed"); + +-- The version of assumingM which appears in un-monadified Cryptol (this gets +-- converted to assumingM during monadification, see assertingOrAssumingMacro) +assuming : (a : isort 0) -> Bool -> a -> a; +assuming a b x = ite a b x (error a "Assuming failed"); -- NOTE: for the simplicity and efficiency of MR solver, we define all -- fixed-point computations in CompM via a primitive multiFixM, defined below. diff --git a/saw-core/src/Verifier/SAW/Simulator/Prims.hs b/saw-core/src/Verifier/SAW/Simulator/Prims.hs index e2a16c7614..8648f52ac3 100644 --- a/saw-core/src/Verifier/SAW/Simulator/Prims.hs +++ b/saw-core/src/Verifier/SAW/Simulator/Prims.hs @@ -499,8 +499,8 @@ selectV mux maxValue valueFn v = impl len 0 bvNatOp :: (VMonad l, Show (Extra l)) => BasePrims l -> Prim l bvNatOp bp = natFun $ \w -> - natFun $ \x -> - Prim (VWord <$> bpBvLit bp (fromIntegral w) (toInteger x)) -- FIXME check for overflow on w + strictFun $ \v -> + Prim (VWord <$> natToWord bp (fromIntegral w) v) -- FIXME check for overflow on w -- bvToNat : (n : Nat) -> Vec n Bool -> Nat; bvToNatOp :: VMonad l => Prim l diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index 4edc37ef2d..08ffc0ace0 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -1627,17 +1627,22 @@ setMonadification sc cry_str saw_str = _ -> fail ("Could not find type for Cryptol name: " ++ cry_str) cry_mon_tp <- liftIO $ Monadify.monadifyCompleteArgType sc cry_saw_tp - -- Step 3: convert the second string to a typed SAW core term, and check - -- that it has the same type as the monadified type for the Cryptol name + -- Step 3: convert the second string to a typed SAW core term, and if it + -- has an existing macro, check that it has the same type as the type for + -- the cryptol name, or if no macro exists, check that it has the same + -- type as the monadified type for the Cryptol name and generate a macro + -- which maps the Cryptol name to the SAW core term let saw_ident = parseIdent saw_str saw_trm <- liftIO $ scGlobalDef sc saw_ident saw_tp <- liftIO $ scTypeOf sc saw_trm - liftIO $ scCheckSubtype sc Nothing (TC.TypedTerm saw_trm saw_tp) cry_mon_tp - - -- Step 4: Add a mapping from the Cryptol name to the SAW core term - put (rw { rwMonadify = - Map.insert cry_nmi (Monadify.argGlobalMacro - cry_nmi saw_ident) (rwMonadify rw) }) + let (tp_to_check, macro) = + case Map.lookup (ModuleIdentifier saw_ident) (rwMonadify rw) of + Just existing_macro -> (cry_saw_tp, existing_macro) + Nothing -> (cry_mon_tp, Monadify.argGlobalMacro cry_nmi saw_ident) + liftIO $ scCheckSubtype sc Nothing (TC.TypedTerm saw_trm saw_tp) tp_to_check + + -- Step 4: Add the generated macro + put (rw { rwMonadify = Map.insert cry_nmi macro (rwMonadify rw) }) parseSharpSATResult :: String -> Maybe Integer parseSharpSATResult s = parse (lines s) diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index 71e79735ba..0fe9915bec 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -73,7 +73,9 @@ data MRFailure | MalformedDefsFun Term | MalformedComp Term | NotCompFunType Term - | PrecondNotProvable FunName FunName Term + | AssertionNotProvable Term + | AssumptionNotProvable Term + | InvariantNotProvable FunName FunName Term -- | A local variable binding | MRFailureLocalVar LocalName MRFailure -- | Information about the context of the failure @@ -130,8 +132,12 @@ instance PrettyInCtx MRFailure where ppWithPrefix "Could not handle computation:" t prettyInCtx (NotCompFunType tp) = ppWithPrefix "Not a computation or computational function type:" tp - prettyInCtx (PrecondNotProvable f g pre) = - prettyAppList [return "Could not prove precondition for functions", + prettyInCtx (AssertionNotProvable cond) = + ppWithPrefix "Failed to prove assertion:" cond + prettyInCtx (AssumptionNotProvable cond) = + ppWithPrefix "Failed to prove condition for `assuming`:" cond + prettyInCtx (InvariantNotProvable f g pre) = + prettyAppList [return "Could not prove loop invariant for functions", prettyInCtx f, return "and", prettyInCtx g, return ":", prettyInCtx pre] prettyInCtx (MRFailureLocalVar x err) = @@ -195,12 +201,12 @@ data CoIndHyp = CoIndHyp { coIndHypLHS :: [Term], -- | The RHS argument expressions @y1, ..., ym@ over the 'coIndHypCtx' uvars coIndHypRHS :: [Term], - -- | The precondition for the left-hand arguments, as a closed function from + -- | The invariant for the left-hand arguments, as a closed function from -- the left-hand arguments to @Bool@ - coIndHypPrecondLHS :: Maybe Term, - -- | The precondition for the right-hand arguments, as a closed function from + coIndHypInvariantLHS :: Maybe Term, + -- | The invariant for the right-hand arguments, as a closed function from -- the left-hand arguments to @Bool@ - coIndHypPrecondRHS :: Maybe Term + coIndHypInvariantRHS :: Maybe Term } deriving Show -- | Extract the @i@th argument on either the left- or right-hand side of a @@ -214,13 +220,13 @@ coIndHypArg hyp (Right i) = (coIndHypRHS hyp) !! i type CoIndHyps = Map (FunName, FunName) CoIndHyp instance PrettyInCtx CoIndHyp where - prettyInCtx (CoIndHyp ctx f1 f2 args1 args2 pre1 pre2) = + prettyInCtx (CoIndHyp ctx f1 f2 args1 args2 invar1 invar2) = local (const $ map fst $ reverse ctx) $ prettyAppList [return (ppCtx ctx <> "."), - (case pre1 of + (case invar1 of Just f -> prettyTermApp f args1 Nothing -> return "True"), return "=>", - (case pre2 of + (case invar2 of Just f -> prettyTermApp f args2 Nothing -> return "True"), return "=>", prettyInCtx (FunBind f1 args1 CompFunReturn), @@ -260,6 +266,10 @@ asIsFinite (asApp -> Just (isGlobalDef "CryptolM.isFinite" -> Just (), n)) = Just n asIsFinite _ = Nothing +-- | Create a term representing the type @IsFinite n@ +mrIsFinite :: Term -> MRM Term +mrIsFinite n = liftSC2 scGlobalApply "CryptolM.isFinite" [n] + -- | A map from 'Term's to 'DataTypeAssump's over that term type DataTypeAssumps = HashMap Term DataTypeAssump @@ -345,6 +355,10 @@ mrDataTypeAssumps = mriDataTypeAssumps <$> ask mrDebugLevel :: MRM Int mrDebugLevel = mriDebugLevel <$> ask +-- | Get the current value of 'mriEnv' +mrEnv :: MRM MREnv +mrEnv = mriEnv <$> ask + -- | Get the current value of 'mrsVars' mrVars :: MRM MRVarMap mrVars = mrsVars <$> get @@ -484,7 +498,7 @@ mrTypeOf :: Term -> MRM Term mrTypeOf t = -- NOTE: scTypeOf' wants the type context in the most recently bound var -- first, i.e., in the mrUVarCtxRev order - mrDebugPPPrefix 2 "mrTypeOf:" t >> + mrDebugPPPrefix 3 "mrTypeOf:" t >> mrUVarCtxRev >>= \ctx -> liftSC2 scTypeOf' (map snd ctx) t -- | Check if two 'Term's are convertible in the 'MRM' monad @@ -496,17 +510,13 @@ mrConvertible = liftSC4 scConvertibleEval scTypeCheckWHNF True -- type @[args/vars]a@ that @CompM@ is applied to. mrFunOutType :: FunName -> [Term] -> MRM Term mrFunOutType fname args = - funNameType fname >>= \case - (asPiList -> (vars, asCompM -> Just tp)) - | length vars == length args -> substTermLike 0 (reverse args) tp - ftype@(asPiList -> (vars, _)) -> - do pp_ftype <- mrPPInCtx ftype - pp_fname <- mrPPInCtx fname - debugPrint 0 "mrFunOutType: function applied to the wrong number of args" - debugPrint 0 ("Expected: " ++ show (length vars) ++ - ", found: " ++ show (length args)) - debugPretty 0 ("For function: " <> pp_fname <> " with type: " <> pp_ftype) - error "mrFunOutType" + mrApplyAll (funNameTerm fname) args >>= mrTypeOf >>= \case + (asCompM -> Just tp) -> liftSC1 scWhnf tp + _ -> do pp_ftype <- funNameType fname >>= mrPPInCtx + pp_fname <- mrPPInCtx fname + debugPrint 0 "mrFunOutType: function does not have CompM return type" + debugPretty 0 ("Function:" <> pp_fname <> " with type: " <> pp_ftype) + error "mrFunOutType" -- | Turn a 'LocalName' into one not in a list, adding a suffix if necessary uniquifyName :: LocalName -> [LocalName] -> LocalName @@ -562,7 +572,8 @@ withNoUVars m = local (\info -> info { mriUVars = [], mriAssumptions = true_tm, mriDataTypeAssumps = HashMap.empty }) m --- | Run a MR Solver in a context of only the specified UVars, no others +-- | Run a MR Solver in a context of only the specified UVars, no others - +-- note that this also clears all assumptions withOnlyUVars :: [(LocalName,Term)] -> MRM a -> MRM a withOnlyUVars vars m = withNoUVars $ withUVars vars $ const m @@ -628,6 +639,18 @@ extCnsToFunName ec = let var = MRVar ec in mrVarInfo var >>= \case return $ GlobalName glob [] _ -> error "extCnsToFunName: unreachable" +-- | Get the 'FunName' of a global definition +mrGlobalDef :: Ident -> MRM FunName +mrGlobalDef ident = asTypedGlobalDef <$> liftSC1 scGlobalDef ident >>= \case + Just glob -> return $ GlobalName glob [] + _ -> error $ "mrGlobalDef: could not get GlobalDef of: " ++ show ident + +-- | Get the body of a global definition, raising an 'error' if none is found +mrGlobalDefBody :: Ident -> MRM Term +mrGlobalDefBody ident = asConstant <$> liftSC1 scGlobalDef ident >>= \case + Just (_, Just body) -> return body + _ -> error $ "mrGlobalDefBody: global has no definition: " ++ show ident + -- | Get the body of a function @f@ if it has one mrFunNameBody :: FunName -> MRM (Maybe Term) mrFunNameBody (LetRecName var) = @@ -837,24 +860,24 @@ instantiateCoIndHyp (CoIndHyp {..}) = rhs <- substTermLike 0 evars coIndHypRHS return (lhs, rhs) --- | Apply the preconditions of a 'CoIndHyp' to their respective arguments, --- yielding @Bool@ conditions, using the constant @True@ value when a --- precondition is absent -applyCoIndHypPreconds :: CoIndHyp -> MRM (Term, Term) -applyCoIndHypPreconds hyp = - let apply_precond :: Maybe Term -> [Term] -> MRM Term - apply_precond (Just (asLambdaList -> (vars, phi))) args +-- | Apply the invariants of a 'CoIndHyp' to their respective arguments, +-- yielding @Bool@ conditions, using the constant @True@ value when an +-- invariant is absent +applyCoIndHypInvariants :: CoIndHyp -> MRM (Term, Term) +applyCoIndHypInvariants hyp = + let apply_invariant :: Maybe Term -> [Term] -> MRM Term + apply_invariant (Just (asLambdaList -> (vars, phi))) args | length vars == length args -- NOTE: applying to a list of arguments == substituting the reverse -- of that list, because the first argument corresponds to the -- greatest deBruijn index = substTerm 0 (reverse args) phi - apply_precond (Just _) _ = - error "applyCoIndHypPreconds: wrong number of arguments for precondition!" - apply_precond Nothing _ = liftSC1 scBool True in - do pre1 <- apply_precond (coIndHypPrecondLHS hyp) (coIndHypLHS hyp) - pre2 <- apply_precond (coIndHypPrecondRHS hyp) (coIndHypRHS hyp) - return (pre1, pre2) + apply_invariant (Just _) _ = + error "applyCoIndHypInvariants: wrong number of arguments for invariant!" + apply_invariant Nothing _ = liftSC1 scBool True in + do invar1 <- apply_invariant (coIndHypInvariantLHS hyp) (coIndHypLHS hyp) + invar2 <- apply_invariant (coIndHypInvariantRHS hyp) (coIndHypRHS hyp) + return (invar1, invar2) -- | Look up the 'FunAssump' for a 'FunName', if there is one mrGetFunAssump :: FunName -> MRM (Maybe FunAssump) @@ -883,21 +906,44 @@ instantiateFunAssump fassump = rhs <- substTermLike 0 evars $ fassumpRHS fassump return (args, rhs) --- | Get the precondition hint associated with a function name, by unfolding the +-- | Get the invariant hint associated with a function name, by unfolding the -- name and checking if its body has the form -- --- > \ x1 ... xn -> precondHint a phi m +-- > \ x1 ... xn -> invariantHint a phi m -- -- If so, return @\ x1 ... xn -> phi@ as a term with the @xi@ variables free. --- Otherwise, return 'Nothing'. -mrGetPrecond :: FunName -> MRM (Maybe Term) -mrGetPrecond nm = +-- Otherwise, return 'Nothing'. Note that this function will also look past +-- any initial @bindM ... (assertFiniteM ...)@ applications. +mrGetInvariant :: FunName -> MRM (Maybe Term) +mrGetInvariant nm = mrFunNameBody nm >>= \case - Just (asLambdaList -> - (args, - asApplyAll -> (isGlobalDef "Prelude.precondHint" -> Just (), - [_, phi, _]))) -> - Just <$> liftSC2 scLambdaList args phi + Just body -> mrGetInvariantBody body + _ -> return Nothing + +-- | The main loop of 'mrGetInvariant', which operates on a function body +mrGetInvariantBody :: Term -> MRM (Maybe Term) +mrGetInvariantBody tm = case asApplyAll tm of + -- go inside any top-level lambdas + (asLambda -> Just (nm, tp, body), []) -> + do body' <- liftSC1 betaNormalize body + mb_phi <- mrGetInvariantBody body' + liftSC3 scLambda nm tp `mapM` mb_phi + -- always beta-reduce + (f@(asLambda -> Just _), args) -> + do tm' <- mrApplyAll f args + mrGetInvariantBody tm' + -- go inside any top-level applications of of bindM ... (assertFiniteM ...) + (isGlobalDef "Prelude.bindM" -> Just (), + [_, _, + asApp -> Just (isGlobalDef "CryptolM.assertFiniteM" -> Just (), + asCtor -> Just (primName -> "Cryptol.TCNum", _)), + k]) -> + do pf <- liftSC1 scGlobalDef "Prelude.TrueI" + body <- mrApplyAll k [pf] + mrGetInvariantBody body + -- otherwise, return Just iff there is a top-level invariant hint + (isGlobalDef "Prelude.invariantHint" -> Just (), + [_, phi, _]) -> return $ Just phi _ -> return Nothing -- | Add an assumption of type @Bool@ to the current path condition while diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index efb196f0bc..b42322f8a3 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -19,6 +19,7 @@ module SAWScript.Prover.MRSolver.SMT where import qualified Data.Vector as V import Control.Monad.Except +import qualified Control.Exception as X import Data.Map (Map) import qualified Data.Map as Map @@ -30,6 +31,7 @@ import Verifier.SAW.SharedTerm import Verifier.SAW.Recognizer import Verifier.SAW.OpenTerm +import Verifier.SAW.Prim (EvalError(..)) import qualified Verifier.SAW.Prim as Prim import Verifier.SAW.Simulator.TermModel import Verifier.SAW.Simulator.Prims @@ -190,12 +192,21 @@ mrProvableRaw prop_term = debugPrint 2 ("Calling SMT solver with proposition: " ++ prettyProp defaultPPOpts prop) sym <- liftIO $ setupWhat4_sym True - (smt_res, _) <- - liftIO $ proveWhat4_solver z3Adapter sym unints sc prop (return ()) + -- If there are any saw-core `error`s in the term, this will throw a + -- Haskell error - in this case we want to just return False, not stop + -- execution + smt_res <- liftIO $ + (Right <$> proveWhat4_solver z3Adapter sym unints sc prop (return ())) + `X.catch` \case + UserError msg -> return $ Left msg + e -> X.throw e case smt_res of - Just _ -> + Left msg -> + debugPrint 2 ("SMT solver encountered a saw-core error term: " ++ msg) + >> return False + Right (Just _, _) -> debugPrint 2 "SMT solver response: not provable" >> return False - Nothing -> + Right (Nothing, _) -> debugPrint 2 "SMT solver response: provable" >> return True -- | Test if a Boolean term over the current uvars is provable given the current diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index f6b711a1e7..90a3c4aca8 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -2,6 +2,13 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ViewPatterns #-} +-- This is to stop GHC 8.8.4's pattern match checker exceeding its limit when +-- checking the pattern match in the 'CompTerm' case of 'normComp' +{-# LANGUAGE CPP #-} +#if __GLASGOW_HASKELL__ <= 808 +{-# OPTIONS_GHC -fno-warn-incomplete-patterns -fno-warn-overlapping-patterns #-} +#endif + {- | Module : SAWScript.Prover.MRSolver.Solver Copyright : Galois, Inc. 2022 @@ -20,6 +27,8 @@ errorM str >>= k = errorM (m >>= k1) >>= k2 = m >>= \x -> k1 x >>= k2 (existsM f) >>= k = existsM (\x -> f x >>= k) (forallM f) >>= k = forallM (\x -> f x >>= k) +(assumingM b m) >>= k = assumingM b (m >>= k) +(assertingM b m) >>= k = assertingM b (m >>= k) (orM m1 m2) >>= k = orM (m1 >>= k) (m2 >>= k) (if b then m1 else m2) >>= k = if b then m1 >>= k else m2 >>1 k (either f1 f2 e) >>= k = either (\x -> f1 x >= k) (\x -> f2 x >= k) e @@ -118,12 +127,16 @@ import Data.Maybe import Data.Either import Data.List (findIndices, intercalate) import Control.Monad.Except +import qualified Data.Map as Map +import qualified Data.Text as Text import Prettyprinter import Verifier.SAW.Term.Functor +import Verifier.SAW.Term.CtxTerm (substTerm) import Verifier.SAW.SharedTerm import Verifier.SAW.Recognizer +import Verifier.SAW.Cryptol.Monadify import SAWScript.Prover.MRSolver.Term import SAWScript.Prover.MRSolver.Monad @@ -195,6 +208,7 @@ normComp (CompBind m f) = do norm <- normComp m normBind norm f normComp (CompTerm t) = + (>>) (mrDebugPPPrefix 3 "normCompTerm:" t) $ withFailureCtx (FailCtxMNF t) $ case asApplyAll t of (f@(asLambda -> Just _), args) -> @@ -202,7 +216,7 @@ normComp (CompTerm t) = (isGlobalDef "Prelude.returnM" -> Just (), [_, x]) -> return $ ReturnM x (isGlobalDef "Prelude.bindM" -> Just (), [_, _, m, f]) -> - do norm <- normComp (CompTerm m) + do norm <- normCompTerm m normBind norm (CompFunTerm f) (isGlobalDef "Prelude.errorM" -> Just (), [_, str]) -> return (ErrorM str) @@ -211,15 +225,24 @@ normComp (CompTerm t) = (isGlobalDef "Prelude.either" -> Just (), [ltp, rtp, _, f, g, eith]) -> return $ Either (Type ltp) (Type rtp) (CompFunTerm f) (CompFunTerm g) eith (isGlobalDef "Prelude.maybe" -> Just (), [tp, _, m, f, mayb]) -> - return $ MaybeElim (Type tp) (CompTerm m) (CompFunTerm f) mayb + do tp' <- case asApplyAll tp of + -- Always unfold: is_bvult, is_bvule + (tpf@(asGlobalDef -> Just ident), args) + | ident `elem` ["Prelude.is_bvult", "Prelude.is_bvule"] + , Just (_, Just body) <- asConstant tpf -> + mrApplyAll body args + _ -> return tp + return $ MaybeElim (Type tp') (CompTerm m) (CompFunTerm f) mayb (isGlobalDef "Prelude.orM" -> Just (), [_, m1, m2]) -> return $ OrM (CompTerm m1) (CompTerm m2) + (isGlobalDef "Prelude.assertingM" -> Just (), [_, cond, body_tm]) -> + return $ AssertingM cond (CompTerm body_tm) + (isGlobalDef "Prelude.assumingM" -> Just (), [_, cond, body_tm]) -> + return $ AssumingM cond (CompTerm body_tm) (isGlobalDef "Prelude.existsM" -> Just (), [tp, _, body_tm]) -> return $ ExistsM (Type tp) (CompFunTerm body_tm) (isGlobalDef "Prelude.forallM" -> Just (), [tp, _, body_tm]) -> return $ ForallM (Type tp) (CompFunTerm body_tm) - (isGlobalDef "Prelude.precondHint" -> Just (), [_, _, body_tm]) -> - normCompTerm body_tm (isGlobalDef "Prelude.letRecM" -> Just (), [lrts, _, defs_f, body_f]) -> do -- Bind fresh function vars for the letrec-bound functions @@ -227,22 +250,7 @@ normComp (CompTerm t) = -- Apply the body function to our function vars and recursively -- normalize the resulting computation body_tm <- mrApplyAll body_f fun_tms - normComp (CompTerm body_tm) - - -- Only unfold constants that are not recursive functions, i.e., whose - -- bodies do not contain letrecs - {- FIXME: this should be handled by mrRefines; we want it to be handled there - so that we use refinement assumptions before unfolding constants, to give - the user control over refinement proofs - ((asConstant -> Just (_, body)), args) - | not (containsLetRecM body) -> - mrApplyAll body args >>= normCompTerm - -} - - -- Recognize and unfold a multiArgFixM - (f@(isGlobalDef "Prelude.multiArgFixM" -> Just ()), args) - | Just (_, Just body) <- asConstant f -> - mrApplyAll body args >>= normCompTerm + normCompTerm body_tm -- Recognize (multiFixM lrts (\ f1 ... fn -> (body1, ..., bodyn))).i args (asTupleSelector -> @@ -258,8 +266,56 @@ normComp (CompTerm t) = if i > 0 && i <= length fun_tms then mrApplyAll (fun_tms !! (i-1)) args else throwMRFailure (MalformedComp t) - normComp (CompTerm body_tm) + normCompTerm body_tm + + -- Convert `vecMapM (bvToNat ...)` into `bvVecMapInvarM`, with the + -- invariant being the current set of assumptions + (asGlobalDef -> Just "CryptolM.vecMapM", [a, b, asBvToNat -> Just (w, n), + f, xs]) -> + do invar <- mrAssumptions + liftSC2 scGlobalApply "CryptolM.bvVecMapInvarM" + [a, b, w, n, f, xs, invar] >>= normCompTerm + + -- Convert `atM (bvToNat ...)` into the unfolding of `bvVecAtM` + (asGlobalDef -> Just "CryptolM.atM", [asBvToNat -> Just (w1, n), a, xs, + asBvToNat -> Just (w2, i)]) -> + do body <- mrGlobalDefBody "CryptolM.bvVecAtM" + ws_are_eq <- mrConvertible w1 w2 + if ws_are_eq then + mrApplyAll body [w1, n, a, xs, i] >>= normCompTerm + else throwMRFailure (MalformedComp t) + + -- Convert `updateM (bvToNat ...)` into the unfolding of `bvVecUpdateM` + (asGlobalDef -> Just "CryptolM.updateM", [asBvToNat -> Just (w1, n), a, xs, + asBvToNat -> Just (w2, i), x]) -> + do body <- mrGlobalDefBody "CryptolM.bvVecUpdateM" + ws_are_eq <- mrConvertible w1 w2 + if ws_are_eq then + mrApplyAll body [w1, n, a, xs, i, x] >>= normCompTerm + else throwMRFailure (MalformedComp t) + + -- Always unfold: sawLet, multiArgFixM, invariantHint, Num_rec + (f@(asGlobalDef -> Just ident), args) + | ident `elem` ["Prelude.sawLet", "Prelude.multiArgFixM", + "Prelude.invariantHint", "Cryptol.Num_rec"] + , Just (_, Just body) <- asConstant f -> + mrApplyAll body args >>= normCompTerm + -- Always unfold recursors applied to constructors + (asRecursorApp -> Just (rc, crec, _, arg), args) + | Just (c, _, cargs) <- asCtorParams arg -> + do hd' <- liftSC4 scReduceRecursor rc crec c cargs + >>= liftSC1 betaNormalize + t' <- mrApplyAll hd' args + normCompTerm t' + + -- Always unfold record selectors applied to record values (after scWhnf) + (asRecordSelector -> Just (r, fld), args) -> + do r' <- liftSC1 scWhnf r + case asRecordValue r' of + Just (Map.lookup fld -> Just f) -> do t' <- mrApplyAll f args + normCompTerm t' + _ -> throwMRFailure (MalformedComp t) -- For an ExtCns, we have to check what sort of variable it is -- FIXME: substitute for evars if they have been instantiated @@ -285,10 +341,25 @@ normBind (MaybeElim tp m f t) k = return $ MaybeElim tp (CompBind m k) (compFunComp f k) t normBind (OrM comp1 comp2) k = return $ OrM (CompBind comp1 k) (CompBind comp2 k) +normBind (AssertingM cond comp) k = return $ AssertingM cond (CompBind comp k) +normBind (AssumingM cond comp) k = return $ AssumingM cond (CompBind comp k) normBind (ExistsM tp f) k = return $ ExistsM tp (compFunComp f k) normBind (ForallM tp f) k = return $ ForallM tp (compFunComp f k) -normBind (FunBind f args k1) k2 = - return $ FunBind f args (compFunComp k1 k2) +normBind (FunBind f args k1) k2 + -- Turn `bvVecMapInvarM ... >>= k` into `bvVecMapInvarBindM ... k` + | GlobalName (globalDefString -> "CryptolM.bvVecMapInvarM") [] <- f + , not (isCompFunReturn (compFunComp k1 k2)) = + do f' <- mrGlobalDef "CryptolM.bvVecMapInvarBindM" + cont <- compFunToTerm (compFunComp k1 k2) + return $ FunBind f' (args ++ [cont]) CompFunReturn + -- Turn `bvVecMapInvarBindM ... k1 >>= k2` into + -- `bvVecMapInvarBindM ... (composeM ... k1 k2)` + | GlobalName (globalDefString -> "CryptolM.bvVecMapInvarBindM") [] <- f + , (args_pre, [cont]) <- splitAt 8 args + , not (isCompFunReturn (compFunComp k1 k2)) = + do cont' <- compFunToTerm (compFunComp (compFunComp (CompFunTerm cont) k1) k2) + return $ FunBind f (args_pre ++ [cont']) CompFunReturn + | otherwise = return $ FunBind f args (compFunComp k1 k2) -- | Bind a 'Term' for a computation with a function and normalize normBindTerm :: Term -> CompFun -> MRM NormComp @@ -304,6 +375,36 @@ applyCompFun CompFunReturn t = return $ CompReturn t applyCompFun (CompFunTerm f) t = CompTerm <$> mrApplyAll f [t] +-- | Convert a 'CompFun' which is not a 'CompFunReturn' into a 'Term' +compFunToTerm :: CompFun -> MRM Term +compFunToTerm (CompFunTerm t) = return t +compFunToTerm (CompFunComp f g) = + do f' <- compFunToTerm f + g' <- compFunToTerm g + f_tp <- mrTypeOf f' + g_tp <- mrTypeOf g' + case (f_tp, g_tp) of + (asPi -> Just (_, a, asCompM -> Just b), + asPi -> Just (_, _, asCompM -> Just c)) -> + liftSC2 scGlobalApply "Prelude.composeM" [a, b, c, f', g'] + _ -> error "compFunToTerm: type(s) not of the form: a -> CompM b" +compFunToTerm CompFunReturn = error "compFunToTerm: got a CompFunReturn" + +-- | Convert a 'Comp' into a 'Term' +compToTerm :: Comp -> MRM Term +compToTerm (CompTerm t) = return t +compToTerm (CompReturn t) = + do tp <- mrTypeOf t + liftSC2 scGlobalApply "Prelude.returnM" [tp, t] +compToTerm (CompBind m CompFunReturn) = compToTerm m +compToTerm (CompBind m f) = + do m' <- compToTerm m + f' <- compFunToTerm f + mrTypeOf f' >>= \case + (asPi -> Just (_, a, asCompM -> Just b)) -> + liftSC2 scGlobalApply "Prelude.bindM" [a, b, m', f'] + _ -> error "compToTerm: type not of the form: a -> CompM b" + -- | Apply a 'CompFun' to a term and normalize the resulting computation applyNormCompFun :: CompFun -> Term -> MRM NormComp applyNormCompFun f arg = applyCompFun f arg >>= normComp @@ -344,15 +445,15 @@ handling the recursive ones -- * Handling Coinductive Hypotheses ---------------------------------------------------------------------- --- | Prove the precondition of a coinductive hypothesis -proveCoIndHypPreCond :: CoIndHyp -> MRM () -proveCoIndHypPreCond hyp = - do (pre1, pre2) <- applyCoIndHypPreconds hyp - pre <- liftSC2 scAnd pre1 pre2 - success <- mrProvable pre +-- | Prove the invariant of a coinductive hypothesis +proveCoIndHypInvariant :: CoIndHyp -> MRM () +proveCoIndHypInvariant hyp = + do (invar1, invar2) <- applyCoIndHypInvariants hyp + invar <- liftSC2 scAnd invar1 invar2 + success <- mrProvable invar if success then return () else throwMRFailure $ - PrecondNotProvable (coIndHypLHSFun hyp) (coIndHypRHSFun hyp) pre + InvariantNotProvable (coIndHypLHSFun hyp) (coIndHypRHSFun hyp) invar -- | Co-inductively prove the refinement -- @@ -361,21 +462,21 @@ proveCoIndHypPreCond hyp = -- -- where @F@ and @G@ are the given 'FunName's, @y1, ..., ym@ and @z1, ..., zl@ -- are the given argument lists, @x1, ..., xn@ is the current context of uvars, --- and @preF@ and @preG@ are the preconditions associated with @F@ and @G@, +-- and @invarF@ and @invarG@ are the invariants associated with @F@ and @G@, -- respectively. This proof is performed by coinductively assuming the -- refinement holds and proving the refinement with the definitions of @F@ and -- @G@ unfolded to their bodies. Note that this refinement is performed with --- /only/ the preconditions @preF@ and @preG@ as assumptions; all other +-- /only/ the invariants @invarF@ and @invarG@ as assumptions; all other -- assumptions are thrown away. If while running the refinement computation a -- 'CoIndHypMismatchWidened' error is reached with the given names, the state is -- restored and the computation is re-run with the widened hypothesis. mrRefinesCoInd :: FunName -> [Term] -> FunName -> [Term] -> MRM () mrRefinesCoInd f1 args1 f2 args2 = do ctx <- mrUVarCtx - preF1 <- mrGetPrecond f1 - preF2 <- mrGetPrecond f2 + preF1 <- mrGetInvariant f1 + preF2 <- mrGetInvariant f2 let hyp = CoIndHyp ctx f1 f2 args1 args2 preF1 preF2 - proveCoIndHypPreCond hyp + proveCoIndHypInvariant hyp proveCoIndHyp hyp -- | Prove the refinement represented by a 'CoIndHyp' coinductively. This is the @@ -389,9 +490,9 @@ proveCoIndHyp hyp = debugPretty 1 ("proveCoIndHyp" <+> ppInEmptyCtx hyp) lhs <- fromMaybe (error "proveCoIndHyp") <$> mrFunBody f1 args1 rhs <- fromMaybe (error "proveCoIndHyp") <$> mrFunBody f2 args2 - (pre1, pre2) <- applyCoIndHypPreconds hyp - pre <- liftSC2 scAnd pre1 pre2 - (withOnlyUVars (coIndHypCtx hyp) $ withOnlyAssumption pre $ + (invar1, invar2) <- applyCoIndHypInvariants hyp + invar <- liftSC2 scAnd invar1 invar2 + (withOnlyUVars (coIndHypCtx hyp) $ withOnlyAssumption invar $ withCoIndHyp hyp $ mrRefines lhs rhs) `catchError` \case MRExnWiden nm1' nm2' new_vars | f1 == nm1' && f2 == nm2' -> @@ -415,7 +516,7 @@ matchCoIndHyp hyp args1 args2 = if and (eqs1 ++ eqs2) then return () else throwError $ MRExnWiden (coIndHypLHSFun hyp) (coIndHypRHSFun hyp) (map Left (findIndices not eqs1) ++ map Right (findIndices not eqs2)) - proveCoIndHypPreCond hyp + proveCoIndHypInvariant hyp -- | Generalize some of the arguments of a coinductive hypothesis @@ -446,7 +547,7 @@ generalizeCoIndHyp hyp all_specs@(arg_spec:arg_specs) = -- | Add a new variable of the given type to the context of a coinductive -- hypothesis and set the specified arguments to that new variable generalizeCoIndHypArgs :: CoIndHyp -> Term -> [Either Int Int] -> MRM CoIndHyp -generalizeCoIndHypArgs (CoIndHyp ctx f1 f2 args1 args2 pre1 pre2) tp specs = +generalizeCoIndHypArgs (CoIndHyp ctx f1 f2 args1 args2 invar1 invar2) tp specs = do let set_arg i args = take i args ++ (Unshared $ LocalVar 0) : drop (i+1) args let (specs1, specs2) = partitionEithers specs @@ -455,7 +556,7 @@ generalizeCoIndHypArgs (CoIndHyp ctx f1 f2 args1 args2 pre1 pre2) tp specs = args2' <- liftTermLike 0 1 args2 let args1'' = foldr set_arg args1' specs1 args2'' = foldr set_arg args2' specs2 - return $ CoIndHyp (ctx ++ [("z",tp)]) f1 f2 args1'' args2'' pre1 pre2 + return $ CoIndHyp (ctx ++ [("z",tp)]) f1 f2 args1'' args2'' invar1 invar2 ---------------------------------------------------------------------- @@ -490,34 +591,37 @@ mrRefines' (ErrorM _) (ErrorM _) = return () mrRefines' (ReturnM e) (ErrorM _) = throwMRFailure (ReturnNotError e) mrRefines' (ErrorM _) (ReturnM e) = throwMRFailure (ReturnNotError e) --- A maybe eliminator on an equality type on the left +-- maybe elimination on equality types mrRefines' (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m1 f1 _) m2 = do cond <- mrEq' tp e1 e2 not_cond <- liftSC1 scNot cond cond_pf <- liftSC1 scEqTrue cond >>= mrDummyProof m1' <- applyNormCompFun f1 cond_pf cond_holds <- mrProvable cond - if cond_holds then mrRefines m1' m2 else - withAssumption cond (mrRefines m1' m2) >> - withAssumption not_cond (mrRefines m1 m2) - --- A maybe eliminator on an equality type on the right + not_cond_holds <- mrProvable not_cond + case (cond_holds, not_cond_holds) of + (True, _) -> mrRefines m1' m2 + (_, True) -> mrRefines m1 m2 + _ -> withAssumption cond (mrRefines m1' m2) >> + withAssumption not_cond (mrRefines m1 m2) mrRefines' m1 (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m2 f2 _) = do cond <- mrEq' tp e1 e2 not_cond <- liftSC1 scNot cond cond_pf <- liftSC1 scEqTrue cond >>= mrDummyProof m2' <- applyNormCompFun f2 cond_pf cond_holds <- mrProvable cond - if cond_holds then mrRefines m1 m2' else - withAssumption cond (mrRefines m1 m2') >> - withAssumption not_cond (mrRefines m1 m2) - --- A maybe eliminator on an isFinite type on the left + not_cond_holds <- mrProvable not_cond + case (cond_holds, not_cond_holds) of + (True, _) -> mrRefines m1 m2' + (_, True) -> mrRefines m1 m2 + _ -> withAssumption cond (mrRefines m1 m2') >> + withAssumption not_cond (mrRefines m1 m2) + +-- maybe elimination on isFinite types mrRefines' (MaybeElim (Type (asIsFinite -> Just n1)) m1 f1 _) m2 = do n1_norm <- mrNormOpenTerm n1 maybe_assump <- mrGetDataTypeAssump n1_norm - fin_pf <- - liftSC2 scGlobalApply "CryptolM.isFinite" [n1_norm] >>= mrDummyProof + fin_pf <- mrIsFinite n1_norm >>= mrDummyProof case (maybe_assump, asNum n1_norm) of (_, Just (Left _)) -> applyNormCompFun f1 fin_pf >>= flip mrRefines m2 (_, Just (Right _)) -> mrRefines m1 m2 @@ -529,13 +633,10 @@ mrRefines' (MaybeElim (Type (asIsFinite -> Just n1)) m1 f1 _) m2 = (withUVarLift "n" (Type nat_tp) (n1_norm, f1, m2) $ \ n (n1', f1', m2') -> withDataTypeAssump n1' (IsNum n) (applyNormCompFun f1' n >>= flip mrRefines m2')) - --- A maybe eliminator on an isFinite type on the right mrRefines' m1 (MaybeElim (Type (asIsFinite -> Just n2)) m2 f2 _) = do n2_norm <- mrNormOpenTerm n2 maybe_assump <- mrGetDataTypeAssump n2_norm - fin_pf <- - liftSC2 scGlobalApply "CryptolM.isFinite" [n2_norm] >>= mrDummyProof + fin_pf <- mrIsFinite n2_norm >>= mrDummyProof case (maybe_assump, asNum n2_norm) of (_, Just (Left _)) -> applyNormCompFun f2 fin_pf >>= mrRefines m1 (_, Just (Right _)) -> mrRefines m1 m2 @@ -600,6 +701,11 @@ mrRefines' m1 (Either ltp2 rtp2 f2 g2 t2) = applyNormCompFun g2' x >>= withDataTypeAssump t2'' (IsRight x) . mrRefines m1') +mrRefines' m1 (AssumingM cond2 m2) = + withAssumption cond2 $ mrRefines m1 m2 +mrRefines' (AssertingM cond1 m1) m2 = + withAssumption cond1 $ mrRefines m1 m2 + mrRefines' m1 (ForallM tp f2) = let nm = maybe "x" id (compFunVarName f2) in withUVarLift nm tp (m1,f2) $ \x (m1',f2') -> @@ -691,23 +797,6 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = mrDebugPPPrefixSep 1 "mrRefines: bind types not equal:" tp1 "/=" tp2 >> throwMRFailure (CompsDoNotRefine m1 m2) -{- FIXME: handle FunBind on just one side -mrRefines' m1@(FunBind f@(GlobalName _) args k1) m2 = - mrGetFunAssump f >>= \case - Just fassump -> - -- If we have an assumption that f args' refines some rhs, then prove that - -- args = args' and then that rhs refines m2 - do (assump_args, assump_rhs) <- instantiateFunAssump fassump - zipWithM_ mrAssertProveEq assump_args args - m1' <- normBind assump_rhs k1 - mrRefines m1' m2 - Nothing -> - -- We don't want to do inter-procedural proofs, so if we don't know anything - -- about f already then give up - throwMRFailure (CompsDoNotRefine m1 m2) --} - - mrRefines' m1@(FunBind f1 args1 k1) m2 = mrGetFunAssump f1 >>= \case @@ -730,9 +819,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2 = -- Otherwise we would have to somehow split m2 into some computation of the -- form m2' >>= k2 where f1 args1 |= m2' and k1 |= k2, but we don't know how -- to do this splitting, so give up - _ -> - throwMRFailure (CompsDoNotRefine m1 m2) - + _ -> mrRefines'' m1 m2 mrRefines' m1 m2@(FunBind f2 args2 k2) = mrFunBodyRecInfo f2 args2 >>= \case @@ -753,25 +840,37 @@ mrRefines' m1 m2@(FunBind f2 args2 k2) = -- Otherwise we would have to somehow split m1 into some computation of the -- form m1' >>= k1 where m1' |= f2 args2 and k1 |= k2, but we don't know how -- to do this splitting, so give up - _ -> - throwMRFailure (CompsDoNotRefine m1 m2) + _ -> mrRefines'' m1 m2 + +mrRefines' m1 m2 = mrRefines'' m1 m2 +-- | The cases of 'mrRefines' which must occur after the ones in 'mrRefines''. +-- For example, the rules that introduce existential variables need to go last, +-- so that they can quantify over as many universals as possible +mrRefines'' :: NormComp -> NormComp -> MRM () --- NOTE: the rules that introduce existential variables need to go last, so that --- they can quantify over as many universals as possible -mrRefines' m1 (ExistsM tp f2) = +mrRefines'' m1 (AssertingM cond2 m2) = + mrProvable cond2 >>= \cond2_pv -> + if cond2_pv then mrRefines m1 m2 + else throwMRFailure (AssertionNotProvable cond2) +mrRefines'' (AssumingM cond1 m1) m2 = + mrProvable cond1 >>= \cond1_pv -> + if cond1_pv then mrRefines m1 m2 + else throwMRFailure (AssumptionNotProvable cond1) + +mrRefines'' m1 (ExistsM tp f2) = do let nm = maybe "x" id (compFunVarName f2) evar <- mrFreshEVar nm tp m2' <- applyNormCompFun f2 evar mrRefines m1 m2' -mrRefines' (ForallM tp f1) m2 = +mrRefines'' (ForallM tp f1) m2 = do let nm = maybe "x" id (compFunVarName f1) evar <- mrFreshEVar nm tp m1' <- applyNormCompFun f1 evar mrRefines m1' m2 -- If none of the above cases match, then fail -mrRefines' m1 m2 = throwMRFailure (CompsDoNotRefine m1 m2) +mrRefines'' m1 m2 = throwMRFailure (CompsDoNotRefine m1 m2) -- | Prove that one function refines another for all inputs @@ -791,6 +890,77 @@ mrRefinesFun _ _ = error "mrRefinesFun: unreachable!" -- * External Entrypoints ---------------------------------------------------------------------- +-- | The main loop of 'askMRSolver'. The first argument is an accumulator of +-- variables to introduce, innermost first. +askMRSolverH :: [Term] -> Term -> Term -> Term -> Term -> MRM MREnv + +-- If we need to introduce a bitvector on one side and a Num on the other, +-- introduce a bitvector variable and substitute `TCNum` of `bvToNat` of that +-- variable on the Num side +askMRSolverH vars (asPi -> Just (nm1, tp@(asBitvectorType -> Just n), body1)) t1 + (asPi -> Just (nm2, asDataType -> Just (primName -> "Cryptol.Num", _), body2)) t2 = + let nm = if Text.head nm2 == '_' then nm1 else nm2 in + withUVarLift nm (Type tp) (vars, t1, t2) $ \var (vars', t1', t2') -> + do nat_tm <- liftSC2 scBvToNat n var + num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] + body2' <- substTerm 0 (num_tm : vars') body2 + t1'' <- mrApplyAll t1' [var] + t2'' <- mrApplyAll t2' [num_tm] + askMRSolverH (var : vars') body1 t1'' body2' t2'' +askMRSolverH vars (asPi -> Just (nm1, asDataType -> Just (primName -> "Cryptol.Num", _), body1)) t1 + (asPi -> Just (nm2, tp@(asBitvectorType -> Just n), body2)) t2 = + let nm = if Text.head nm2 == '_' then nm1 else nm2 in + withUVarLift nm (Type tp) (vars, t1, t2) $ \var (vars', t1', t2') -> + do nat_tm <- liftSC2 scBvToNat n var + num_tm <- liftSC2 scCtorApp "Cryptol.TCNum" [nat_tm] + body1' <- substTerm 0 (num_tm : vars') body1 + t1'' <- mrApplyAll t1' [num_tm] + t2'' <- mrApplyAll t2' [var] + askMRSolverH (var : vars') body1' t1'' body2 t2'' + +-- Introduce variables of the same type together +askMRSolverH vars tp11@(asPi -> Just (nm1, tp1, body1)) t1 + tp22@(asPi -> Just (nm2, tp2, body2)) t2 = + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp11) (Type tp22)) + let nm = if Text.head nm2 == '_' then nm1 else nm2 + withUVarLift nm (Type tp1) (vars, t1, t2) $ \var (vars', t1', t2') -> + do t1'' <- mrApplyAll t1' [var] + t2'' <- mrApplyAll t2' [var] + askMRSolverH (var : vars') body1 t1'' body2 t2'' + +-- Error if we don't have the same number of arguments on both sides +askMRSolverH _ tp1@(asPi -> Just _) _ tp2 _ = + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) +askMRSolverH _ tp1 _ tp2@(asPi -> Just _) _ = + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + +-- The base case: both sides are CompM of the same type +askMRSolverH _ tp1@(asCompM -> Just _) t1 tp2@(asCompM -> Just _) t2 = + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + m1 <- normCompTerm t1 + m2 <- normCompTerm t2 + mrRefines m1 m2 + -- If t1 is a named function, add forall xs. f1 xs |= m2 to the env + case asApplyAll t1 of + ((asGlobalFunName -> Just f1), args) -> + mrUVarCtx >>= \uvar_ctx -> + let fassump = FunAssump { fassumpCtx = uvar_ctx, + fassumpArgs = args, + fassumpRHS = m2 } in + mrEnvAddFunAssump f1 fassump <$> mrEnv + _ -> mrEnv + +-- Error if we don't have CompM at the end +askMRSolverH _ (asCompM -> Just _) _ tp2 _ = + throwMRFailure (NotCompFunType tp2) +askMRSolverH _ tp1 _ _ _ = + throwMRFailure (NotCompFunType tp1) + + -- | Test two monadic, recursive terms for refinement. On success, if the -- left-hand term is a named function, add the refinement to the 'MREnv' -- environment. @@ -804,23 +974,6 @@ askMRSolver :: askMRSolver sc dlvl env timeout t1 t2 = do tp1 <- scTypeOf sc t1 >>= scWhnf sc tp2 <- scTypeOf sc t2 >>= scWhnf sc - case asPiList tp1 of - (uvar_ctx, asCompM -> Just _) -> - runMRM sc timeout dlvl env $ - withUVars uvar_ctx $ \vars -> - do tps_are_eq <- mrConvertible tp1 tp2 - if tps_are_eq then return () else - throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) - mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 - m1 <- mrApplyAll t1 vars >>= normCompTerm - m2 <- mrApplyAll t2 vars >>= normCompTerm - mrRefines m1 m2 - -- If t1 is a named function, add forall xs. f1 xs |= m2 to the env - case asGlobalFunName t1 of - Just f1 -> - let fassump = FunAssump { fassumpCtx = uvar_ctx, - fassumpArgs = vars, - fassumpRHS = m2 } in - return $ mrEnvAddFunAssump f1 fassump env - Nothing -> return env - _ -> return $ Left $ NotCompFunType tp1 + runMRM sc timeout dlvl env $ + mrDebugPPPrefixSep 1 "mr_solver" t1 "|=" t2 >> + askMRSolverH [] tp1 t1 tp2 t2 diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index cd7a10c86d..a498c61717 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -118,6 +118,8 @@ data NormComp | Either Type Type CompFun CompFun Term -- ^ A sum elimination | MaybeElim Type Comp CompFun Term -- ^ A maybe elimination | OrM Comp Comp -- ^ an @orM@ computation + | AssertingM Term Comp -- ^ an @assertingM@ computation + | AssumingM Term Comp -- ^ an @assumingM@ computation | ExistsM Type CompFun -- ^ an @existsM@ computation | ForallM Type CompFun -- ^ a @forallM@ computation | FunBind FunName [Term] CompFun @@ -154,6 +156,11 @@ compFunInputType (CompFunTerm (asLambda -> Just (_, tp, _))) = Just $ Type tp compFunInputType (CompFunComp f _) = compFunInputType f compFunInputType _ = Nothing +-- | Returns true iff the given 'CompFun' is 'CompFunReturn' +isCompFunReturn :: CompFun -> Bool +isCompFunReturn CompFunReturn = True +isCompFunReturn _ = False + -- | A computation of type @CompM a@ for some @a@ data Comp = CompTerm Term | CompBind Comp CompFun | CompReturn Term deriving (Generic, Show) @@ -170,6 +177,12 @@ isCompFunType sc t = scWhnf sc t >>= \case (asPiList -> (_, asCompM -> Just _)) -> return True _ -> return False +-- | Recognize a 'Term' as an application of `bvToNat` +asBvToNat :: Recognizer Term (Term, Term) +asBvToNat (asApplyAll -> ((isGlobalDef "Prelude.bvToNat" -> Just ()), + [n, x])) = Just (n, x) +asBvToNat _ = Nothing + ---------------------------------------------------------------------- -- * Mr Solver Environments @@ -426,6 +439,12 @@ instance PrettyInCtx NormComp where prettyInCtx (OrM t1 t2) = prettyAppList [return "orM", return "_", parens <$> prettyInCtx t1, parens <$> prettyInCtx t2] + prettyInCtx (AssertingM cond t) = + prettyAppList [return "assertingM", parens <$> prettyInCtx cond, + parens <$> prettyInCtx t] + prettyInCtx (AssumingM cond t) = + prettyAppList [return "assumingM", parens <$> prettyInCtx cond, + parens <$> prettyInCtx t] prettyInCtx (ExistsM tp f) = prettyAppList [return "existsM", prettyInCtx tp, return "_", parens <$> prettyInCtx f]