diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs index 2e549a0969..39b32169d9 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Monadify.hs @@ -92,7 +92,7 @@ import Verifier.SAW.SharedTerm import Verifier.SAW.OpenTerm import Verifier.SAW.TypedTerm import Verifier.SAW.Cryptol (Env) --- import Verifier.SAW.SCTypeCheck +import Verifier.SAW.SCTypeCheck import Verifier.SAW.Recognizer -- import Verifier.SAW.Position import Verifier.SAW.Cryptol.PreludeM @@ -608,7 +608,9 @@ fromCompTerm mtp t = ArgMonTerm $ fromArgTerm mtp t -- | Take a function of type @A1 -> ... -> An -> SpecM E emptyFunStack B@ and -- lift the stack of the output type to an arbitrary @stack@ parameter using --- @liftStackS@ +-- @liftStackS@. Note that @liftStackS@ is only added if the stack of the +-- output type is non-empty, i.e. not @emptyFunStack@. Otherwise, this operation +-- leaves the function unchanged. class LiftCompStack a where liftCompStack :: HasSpecMParams => a -> a @@ -623,10 +625,14 @@ instance LiftCompStack ArgMonTerm where instance LiftCompStack MonTerm where liftCompStack (ArgMonTerm amtrm) = ArgMonTerm $ liftCompStack amtrm - liftCompStack (CompMonTerm mtp trm) = - CompMonTerm mtp $ - applyGlobalOpenTerm "Prelude.liftStackS" - [specMEvType ?specMParams, specMStack ?specMParams, toArgType mtp, trm] + liftCompStack (CompMonTerm mtp trm) = CompMonTerm mtp $ OpenTerm $ do + -- Only add @liftStackS@ when the stack is not @emptyFunStack@ + empty_stk <- typedVal <$> unOpenTerm emptyStackOpenTerm + curr_stk <- typedVal <$> unOpenTerm (specMStack ?specMParams) + curr_stk_empty <- liftTCM scConvertible False empty_stk curr_stk + unOpenTerm $ if curr_stk_empty then trm else + applyGlobalOpenTerm "Prelude.liftStackS" + [specMEvType ?specMParams, specMStack ?specMParams, toArgType mtp, trm] -- | Test if a monadification type @tp@ is pure, meaning @MT(tp)=tp@ monTypeIsPure :: MonType -> Bool @@ -708,8 +714,8 @@ applyMonTermMulti :: HasCallStack => MonTerm -> [Either MonType ArgMonTerm] -> applyMonTermMulti = foldl applyMonTerm -- | Build a 'MonTerm' from a global of a given argument type, applying it to --- the current 'SpecMParams' if the 'Bool' flag is 'True' and lifting it using --- @liftStackS@ if it is 'False' +-- the current 'SpecMParams' if the 'Bool' flag is 'True' or lifting it using +-- @liftStackS@ if it is 'False' and the stack is non-empty mkGlobalArgMonTerm :: HasSpecMParams => MonType -> Ident -> Bool -> ArgMonTerm mkGlobalArgMonTerm tp ident params_p = (if params_p then id else liftCompStack) $ @@ -747,9 +753,11 @@ data MonMacro = MonMacro { macroNumArgs :: Int, macroApply :: GlobalDef -> [Term] -> MonadifyM MonTerm } --- | Make a simple 'MonMacro' that inspects 0 arguments and just returns a term +-- | Make a simple 'MonMacro' that inspects 0 arguments and just returns a term, +-- lifted with @liftStackS@ if the outer stack is non-empty monMacro0 :: MonTerm -> MonMacro -monMacro0 mtrm = MonMacro 0 (\_ _ -> return mtrm) +monMacro0 mtrm = MonMacro 0 $ \_ _ -> usingSpecMParams $ + return $ liftCompStack mtrm -- | Make a 'MonMacro' that maps a named global to a global of semi-pure type. -- (See 'fromSemiPureTermFun'.) Because we can't get access to the type of the @@ -773,7 +781,7 @@ semiPureGlobalMacro from to params_p = -- indicates whether the "to" global is polymorphic in the event type and -- function stack; if so, the current 'SpecMParams' are passed as its first two -- arguments, and otherwise the returned computation is lifted with --- @liftStackS@. +-- @liftStackS@ if the outer stack is non-empty. argGlobalMacro :: NameInfo -> Ident -> Bool -> MonMacro argGlobalMacro from to params_p = MonMacro 0 $ \glob args -> usingSpecMParams $ diff --git a/heapster-saw/examples/Makefile b/heapster-saw/examples/Makefile index ae44946fef..7f9b6e2ebf 100644 --- a/heapster-saw/examples/Makefile +++ b/heapster-saw/examples/Makefile @@ -41,7 +41,7 @@ endif $(SAW) $< # Lists all the Mr Solver tests, without their ".saw" suffix -MR_SOLVER_TESTS = exp_explosion_mr_solver # arrays_mr_solver linked_list_mr_solver sha512_mr_solver +MR_SOLVER_TESTS = exp_explosion_mr_solver linked_list_mr_solver arrays_mr_solver sha512_mr_solver .PHONY: mr-solver-tests $(MR_SOLVER_TESTS) mr-solver-tests: $(MR_SOLVER_TESTS) diff --git a/heapster-saw/examples/SpecPrims.cry b/heapster-saw/examples/SpecPrims.cry index 3fc5d022fd..f62cf9782a 100644 --- a/heapster-saw/examples/SpecPrims.cry +++ b/heapster-saw/examples/SpecPrims.cry @@ -2,25 +2,25 @@ module SpecPrims where /* Specification primitives */ -// The specification that holds for f x for some input x -exists : {a, b} (a -> b) -> b -exists f = error "Cannot run exists" +// The specification that holds for some element of type a +exists : {a} a +exists = error "Cannot run exists" -// The specification that holds for f x for all inputs x -forall : {a, b} (a -> b) -> b -forall f = error "Cannot run forall" +// The specification that holds for all elements of type a +forall : {a} a +forall = error "Cannot run forall" -// The specification that a computation returns some value with no errors -returnsSpec : {a} a -returnsSpec = exists (\x -> x) +// The specification that a computation has no errors +noErrors : {a} a +noErrors = exists // The specification that matches any computation. This calls exists at the -// function type () -> a, which is monadified to () -> CompM a. This means that +// function type () -> a, which is monadified to () -> SpecM a. This means that // the exists does not just quantify over all values of type a like noErrors, // but it quantifies over all computations of type a, including those that // contain errors. anySpec : {a} a -anySpec = exists (\f -> f ()) +anySpec = exists () // The specification which asserts that the first argument is True and then // returns the second argument diff --git a/heapster-saw/examples/arrays.sawcore b/heapster-saw/examples/arrays.sawcore index ff7ef8cdbb..bdb4dedf3d 100644 --- a/heapster-saw/examples/arrays.sawcore +++ b/heapster-saw/examples/arrays.sawcore @@ -3,39 +3,41 @@ module arrays where import Prelude; +-- The LetRecType of noErrorsContains0 +noErrorsContains0LRT : LetRecType; +noErrorsContains0LRT = + LRT_Fun (Vec 64 Bool) (\ (len:Vec 64 Bool) -> + LRT_Fun (Vec 64 Bool) (\ (_:Vec 64 Bool) -> + LRT_Fun (BVVec 64 len (Vec 64 Bool)) (\ (_:BVVec 64 len (Vec 64 Bool)) -> + LRT_Ret (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)))); + -- The helper function for noErrorsContains0 -- -- noErrorsContains0H len i v = --- orM (exists x. returnM x) (noErrorsContains0H len (i+1) v) +-- orS (existsS x. x) (noErrorsContains0H len (i+1) v) noErrorsContains0H : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> - CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool); -noErrorsContains0H len_top i_top v_top = - letRecM - (LRT_Cons - (LRT_Fun (Vec 64 Bool) (\ (len:Vec 64 Bool) -> - LRT_Fun (Vec 64 Bool) (\ (_:Vec 64 Bool) -> - LRT_Fun (BVVec 64 len (Vec 64 Bool)) (\ (_:BVVec 64 len (Vec 64 Bool)) -> - LRT_Ret (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool))))) - LRT_Nil) - (BVVec 64 len_top (Vec 64 Bool) * Vec 64 Bool) - (\ (f : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> - CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) -> - ((\ (len:Vec 64 Bool) (i:Vec 64 Bool) (v:BVVec 64 len (Vec 64 Bool)) -> - invariantHint - (CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) - (and (bvsle 64 0x0000000000000000 i) - (bvsle 64 i 0x0fffffffffffffff)) - (orM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) - (existsM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) - (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) - (returnM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool))) - (f len (bvAdd 64 i 0x0000000000000001) v))), ())) + SpecM VoidEv emptyFunStack + (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool); +noErrorsContains0H = + multiArgFixS VoidEv emptyFunStack noErrorsContains0LRT (\ (f : (len i:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> - CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) -> - f len_top i_top v_top); + SpecM VoidEv (pushFunStack (singletonFrame noErrorsContains0LRT) emptyFunStack) + (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) -> + (\ (len:Vec 64 Bool) (i:Vec 64 Bool) (v:BVVec 64 len (Vec 64 Bool)) -> + invariantHint + (SpecM VoidEv (pushFunStack (singletonFrame noErrorsContains0LRT) emptyFunStack) + (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) + (and (bvsle 64 0x0000000000000000 i) + (bvsle 64 i 0x0fffffffffffffff)) + (orS VoidEv (pushFunStack (singletonFrame noErrorsContains0LRT) emptyFunStack) + (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool) + (existsS VoidEv (pushFunStack (singletonFrame noErrorsContains0LRT) emptyFunStack) + (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool)) + (f len (bvAdd 64 i 0x0000000000000001) v)))); -- The specification that contains0 has no errors noErrorsContains0 : (len:Vec 64 Bool) -> BVVec 64 len (Vec 64 Bool) -> - CompM (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool); + SpecM VoidEv emptyFunStack + (BVVec 64 len (Vec 64 Bool) * Vec 64 Bool); noErrorsContains0 len v = noErrorsContains0H len 0x0000000000000000 v; diff --git a/heapster-saw/examples/arrays_mr_solver.saw b/heapster-saw/examples/arrays_mr_solver.saw index 2f24a7d9b0..c70ab7c986 100644 --- a/heapster-saw/examples/arrays_mr_solver.saw +++ b/heapster-saw/examples/arrays_mr_solver.saw @@ -11,6 +11,9 @@ prove_extcore mrsolver (refines [] contains0 noErrorsContains0); include "specPrims.saw"; import "arrays.cry"; -zero_array <- parse_core_mod "arrays" "zero_array"; -prove_extcore mrsolver (refines [] zero_array {{ zero_array_loop_spec }}); -prove_extcore mrsolver (refines [] zero_array {{ zero_array_spec }}); +monadify_term {{ zero_array_spec }}; + +// FIXME: Uncomment once FunStacks are removed +// zero_array <- parse_core_mod "arrays" "zero_array"; +// prove_extcore mrsolver (refines [] zero_array {{ zero_array_loop_spec }}); +// prove_extcore mrsolver (refines [] zero_array {{ zero_array_spec }}); diff --git a/heapster-saw/examples/specPrims.saw b/heapster-saw/examples/specPrims.saw index 0f54d7deef..cd57da322a 100644 --- a/heapster-saw/examples/specPrims.saw +++ b/heapster-saw/examples/specPrims.saw @@ -2,9 +2,8 @@ import "SpecPrims.cry"; -set_monadification "exists" "Prelude.existsM"; -set_monadification "forall" "Prelude.forallM"; -set_monadification "anySpec" "Prelude.anySpec"; -set_monadification "asserting" "Prelude.asserting"; -set_monadification "assuming" "Prelude.assuming"; -set_monadification "invariantHint" "Prelude.invariantHint"; +set_monadification "exists" "Prelude.existsS" true; +set_monadification "forall" "Prelude.forallS" true; +set_monadification "asserting" "Prelude.asserting" true; +set_monadification "assuming" "Prelude.assuming" true; +set_monadification "invariantHint" "Prelude.invariantHint" true; diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index aff5b94ed9..f7b4663444 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -546,10 +546,6 @@ liftSC5 f a b c d e = mrSC >>= \sc -> liftIO (f sc a b c d e) -- * Functions for Building Terms ---------------------------------------------------------------------- --- | Create a term representing the type @IsFinite n@ -mrIsFinite :: Term -> MRM t Term -mrIsFinite n = liftSC2 scGlobalApply "CryptolM.isFinite" [n] - -- | Create a term representing an application of @Prelude.error@ mrErrorTerm :: Term -> T.Text -> MRM t Term mrErrorTerm a str = @@ -1094,7 +1090,8 @@ mrGetFunAssump nm = lookupFunAssump nm <$> mrRefnset withFunAssump :: FunName -> [Term] -> Term -> MRM t a -> MRM t a withFunAssump fname args rhs m = do k <- mkCompFunReturn <$> mrFunOutType fname args - mrDebugPPPrefixSep 1 "withFunAssump" (FunBind fname args k) "|=" rhs + mrDebugPPPrefixSep 1 "withFunAssump" (FunBind fname args Unlifted k) + "|=" rhs ctx <- mrUVars rs <- mrRefnset let assump = FunAssump ctx fname args (RewriteFunAssump rhs) Nothing diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 65a233abee..20f2d2254f 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -1,5 +1,6 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} -- This is to stop GHC 8.8.4's pattern match checker exceeding its limit when @@ -19,32 +20,47 @@ Portability : non-portable (language extensions) This module implements a monadic-recursive solver, for proving that one monadic term refines another. The algorithm works on the "monadic normal form" of -computations, which uses the following laws to simplify binds in computations, -where either is the sum elimination function defined in the SAW core prelude: - -returnM x >>= k = k x -errorM str >>= k = errorM -(m >>= k1) >>= k2 = m >>= \x -> k1 x >>= k2 -(existsM f) >>= k = existsM (\x -> f x >>= k) -(forallM f) >>= k = forallM (\x -> f x >>= k) -(assumingM b m) >>= k = assumingM b (m >>= k) -(assertingM b m) >>= k = assertingM b (m >>= k) -(orM m1 m2) >>= k = orM (m1 >>= k) (m2 >>= k) -(if b then m1 else m2) >>= k = if b then m1 >>= k else m2 >>1 k -(either f1 f2 e) >>= k = either (\x -> f1 x >= k) (\x -> f2 x >= k) e -(letrecM funs body) >>= k = letrecM funs (\F1 ... Fn -> body F1 ... Fn >>= k) - -The resulting computations of one of the following forms: - -returnM e | errorM str | existsM f | forallM f | orM m1 m2 | -if b then m1 else m2 | either f1 f2 e | F e1 ... en | F e1 ... en >>= k | -letrecM lrts B (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> m) - -The form F e1 ... en refers to a recursively-defined function or a function -variable that has been locally bound by a letrecM. Either way, monadic +computations, which uses the following laws to simplify binds and calls to +@liftStackS@ in computations, where @either@ is the sum elimination function +defined in the SAW core prelude: + +> retS x >>= k = k x +> errorS str >>= k = errorM +> (m >>= k1) >>= k2 = m >>= \x -> k1 x >>= k2 +> (existsS f) >>= k = existsM (\x -> f x >>= k) +> (forallS f) >>= k = forallM (\x -> f x >>= k) +> (assumingS b m) >>= k = assumingM b (m >>= k) +> (assertingS b m) >>= k = assertingM b (m >>= k) +> (orS m1 m2) >>= k = orM (m1 >>= k) (m2 >>= k) +> (if b then m1 else m2) >>= k = if b then m1 >>= k else m2 >>= k +> (either f1 f2 e) >>= k = either (\x -> f1 x >>= k) (\x -> f2 x >>= k) e +> (multiFixS funs body) >>= k = multiFixS funs (\F1 ... Fn -> body F1 ... Fn >>= k) +> +> liftStackS (retS x) = retS x +> liftStackS (errorS str) = errorS str +> liftStackS (m >>= k) = liftStackS m >>= \x -> liftStackS (k x) +> liftStackS (existsS f) = existsM (\x -> liftStackS (f x)) +> liftStackS (forallS f) = forallM (\x -> liftStackS (f x)) +> liftStackS (assumingS b m) = assumingM b (liftStackS m) +> liftStackS (assertingS b m) = assertingM b (liftStackS m) +> liftStackS (orS m1 m2) = orM (liftStackS m1) (liftStackS m2) +> liftStackS (if b then m1 else m2) = if b then liftStackS m1 else liftStackS m2 +> liftStackS (either f1 f2 e) = either (\x -> liftStackS f1 x) (\x -> liftStackS f2 x) e +> liftStackS (multiFixS funs body) = multiFixS funs (\F1 ... Fn -> liftStackS (body F1 ... Fn)) + +The resulting computations are in one of the following forms: + +> returnM e | errorM str | existsM f | forallM f | assumingS b m | +> assertingS b m | orM m1 m2 | if b then m1 else m2 | either f1 f2 e | +> F e1 ... en | liftStackS (F e1 ... en) | +> F e1 ... en >>= k | liftStackS (F e1 ... en) >>= k | +> multiFixS (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> m) + +The form @F e1 ... en@ refers to a recursively-defined function or a function +variable that has been locally bound by a @multiFixS@. Either way, monadic normalization does not attempt to normalize these functions. -The algorithm maintains a context of three sorts of variables: letrec-bound +The algorithm maintains a context of three sorts of variables: @multiFixS@-bound variables, existential variables, and universal variables. Universal variables are represented as free SAW core variables, while the other two forms of variable are represented as SAW core 'ExtCns's terms, which are essentially @@ -52,73 +68,78 @@ axioms that have been generated internally. These 'ExtCns's are Skolemized, meaning that they take in as arguments all universal variables that were in scope when they were created. The context also maintains a partial substitution for the existential variables, as they become instantiated with values, and it -additionally remembers the bodies / unfoldings of the letrec-bound variables. - -The goal of the solver at any point is of the form C |- m1 |= m2, meaning that -we are trying to prove m1 refines m2 in context C. This proceed by cases: - -C |- returnM e1 |= returnM e2: prove C |- e1 = e2 - -C |- errorM str1 |= errorM str2: vacuously true - -C |- if b then m1' else m1'' |= m2: prove C,b=true |- m1' |= m2 and -C,b=false |- m1'' |= m2, skipping either case where C,b=X is unsatisfiable; - -C |- m1 |= if b then m2' else m2'': similar to the above - -C |- either T U (SpecM V) f1 f2 e |= m: prove C,x:T,e=inl x |- f1 x |= m and -C,y:U,e=inl y |- f2 y |= m, again skippping any case with unsatisfiable context; - -C |- m |= either T U (SpecM V) f1 f2 e: similar to previous - -C |- m |= forallM f: make a new universal variable x and recurse - -C |- existsM f |= m: make a new universal variable x and recurse (existential -elimination uses universal variables and vice-versa) - -C |- m |= existsM f: make a new existential variable x and recurse - -C |- forall f |= m: make a new existential variable x and recurse - -C |- m |= orM m1 m2: try to prove C |- m |= m1, and if that fails, backtrack and -prove C |- m |= m2 - -C |- orM m1 m2 |= m: prove both C |- m1 |= m and C |- m2 |= m - -C |- letrec (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body) |= m: create -letrec-bound variables F1 through Fn in the context bound to their unfoldings f1 -through fn, respectively, and recurse on body |= m - -C |- m |= letrec (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body): similar to -previous case - -C |- F e1 ... en >>= k |= F e1' ... en' >>= k': prove C |- ei = ei' for each i -and then prove k x |= k' x for new universal variable x - -C |- F e1 ... en >>= k |= F' e1' ... em' >>= k': - -* If we have an assumption that forall x1 ... xj, F a1 ... an |= F' a1' .. am', - prove ei = ai and ei' = ai' and then that C |- k x |= k' x for fresh uvar x - -* If we have an assumption that forall x1, ..., xn, F e1'' ... en'' |= m' for - some ei'' and m', match the ei'' against the ei by instantiating the xj with - fresh evars, and if this succeeds then recursively prove C |- m' >>= k |= RHS - -(We don't do this one right now) -* If we have an assumption that forall x1', ..., xn', m |= F e1'' ... en'' for - some ei'' and m', match the ei'' against the ei by instantiating the xj with - fresh evars, and if this succeeds then recursively prove C |- LHS |= m' >>= k' - -* If either side is a definition whose unfolding does not contain letrecM, fixM, - or any related operations, unfold it - -* If F and F' have the same return type, add an assumption forall uvars in scope - that F e1 ... en |= F' e1' ... em' and unfold both sides, recursively proving - that F_body e1 ... en |= F_body' e1' ... em'. Then also prove k x |= k' x for - fresh uvar x. - -* Otherwise we don't know to "split" one of the sides into a bind whose - components relate to the two components on the other side, so just fail +additionally remembers the bodies / unfoldings of the @multiFixS@-bound variables. + +The goal of the solver at any point is of the form @C |- m1 |= m2@, meaning that +we are trying to prove @m1@ refines @m2@ in context @C@. This proceeds by cases: + +> C |- retS e1 |= retS e2: prove C |- e1 = e2 +> +> C |- errorS str1 |= errorS str2: vacuously true +> +> C |- if b then m1' else m1'' |= m2: prove C,b=true |- m1' |= m2 and +> C,b=false |- m1'' |= m2, skipping either case where C,b=X is unsatisfiable; +> +> C |- m1 |= if b then m2' else m2'': similar to the above +> +> C |- either T U (SpecM V) f1 f2 e |= m: prove C,x:T,e=inl x |- f1 x |= m and +> C,y:U,e=inl y |- f2 y |= m, again skippping any case with unsatisfiable context; +> +> C |- m |= either T U (SpecM V) f1 f2 e: similar to previous +> +> C |- m |= forallS f: make a new universal variable x and recurse +> +> C |- existsS f |= m: make a new universal variable x and recurse (existential +> elimination uses universal variables and vice-versa) +> +> C |- m |= existsS f: make a new existential variable x and recurse +> +> C |- forallS f |= m: make a new existential variable x and recurse +> +> C |- m |= orS m1 m2: try to prove C |- m |= m1, and if that fails, backtrack and +> prove C |- m |= m2 +> +> C |- orS m1 m2 |= m: prove both C |- m1 |= m and C |- m2 |= m +> +> C |- multiFixS (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body) |= m: create +> multiFixS-bound variables F1 through Fn in the context bound to their unfoldings +> f1 through fn, respectively, and recurse on body |= m +> +> C |- m |= multiFixS (\F1 ... Fn -> (f1, ..., fn)) (\F1 ... Fn -> body): similar to +> previous case +> +> C |- F e1 ... en >>= k |= F e1' ... en' >>= k': prove C |- ei = ei' for each i +> and then prove k x |= k' x for new universal variable x +> +> C |- F e1 ... en >>= k |= F' e1' ... em' >>= k': +> +> * If we have an assumption that forall x1 ... xj, F a1 ... an |= F' a1' .. am', +> prove ei = ai and ei' = ai' and then that C |- k x |= k' x for fresh uvar x +> +> * If we have an assumption that forall x1, ..., xn, F e1'' ... en'' |= m' for +> some ei'' and m', match the ei'' against the ei by instantiating the xj with +> fresh evars, and if this succeeds then recursively prove C |- m' >>= k |= RHS +> +> (We don't do this one right now) +> * If we have an assumption that forall x1', ..., xn', m |= F e1'' ... en'' for +> some ei'' and m', match the ei'' against the ei by instantiating the xj with +> fresh evars, and if this succeeds then recursively prove C |- LHS |= m' >>= k' +> +> * If either side is a definition whose unfolding does not contain multiFixS, or +> any related operations, unfold it +> +> * If F and F' have the same return type, add an assumption forall uvars in scope +> that F e1 ... en |= F' e1' ... em' and unfold both sides, recursively proving +> that F_body e1 ... en |= F_body' e1' ... em'. Then also prove k x |= k' x for +> fresh uvar x. +> +> * Otherwise we don't know to "split" one of the sides into a bind whose +> components relate to the two components on the other side, so just fail + +Note that if either side of the final case is wrapped in a @liftStackS@, the +behavior is identical, just with a @liftStackS@ wrapped around the appropriate +unfolded function body or bodies. The only exception is the second to final case, +which also requires the both functions either be lifted or unlifted. -} module SAWScript.Prover.MRSolver.Solver where @@ -128,7 +149,7 @@ import Data.Either import Numeric.Natural (Natural) import Data.List (find, findIndices) import Data.Foldable (foldlM) --- import Data.Bits (shiftL) +import Data.Bits (shiftL) import Control.Monad.Reader import Control.Monad.Except import qualified Data.Map as Map @@ -229,6 +250,7 @@ mrFreshCallVars ev stack frame defs_tm = do -- First, make fresh function constants for all the recursive functions, -- noting that each constant must abstract out the current uvar context + -- (see mrFreshVar) new_stack <- liftSC2 scGlobalApply "Prelude.pushFunStack" [frame, stack] lrts <- liftSC1 scWhnf frame >>= \case (asList1 -> Just lrts) -> return lrts @@ -238,14 +260,14 @@ mrFreshCallVars ev stack frame defs_tm = fun_vars <- mapM (mrFreshVar "F") fun_tps -- Next, match on the tuple of recursive function definitions and convert - -- each definition to a function body, by lambda-abstracting all the current - -- uvars and then replacing all recursive calls in each function body with - -- our new variable terms (which are applied to the current uvars; see - -- mrVarTerm) + -- each definition to a function body, by replacing all recursive calls in + -- each function body with our new variable terms (which are applied to the + -- current uvars; see mrVarTerm) and then lambda-abstracting all the + -- current uvars fun_tms <- mapM mrVarTerm fun_vars defs_tm' <- liftSC1 scWhnf defs_tm bodies <- case asNestedPairs defs_tm' of - Just defs -> mapM (lambdaUVarsM >=> mrReplaceCallsWithTerms fun_tms) defs + Just defs -> mapM (mrReplaceCallsWithTerms fun_tms >=> lambdaUVarsM) defs Nothing -> throwMRFailure (MalformedDefs defs_tm) -- Remember the body associated with each fresh function constant @@ -280,6 +302,8 @@ normComp (CompTerm t) = normBind norm (CompFunTerm (SpecMParams e stack) f) (isGlobalDef "Prelude.errorS" -> Just (), [_, _, _, str]) -> return (ErrorS str) + (isGlobalDef "Prelude.liftStackS" -> Just (), [ev, stk, _, t']) -> + normCompTerm t' >>= liftStackNormComp (SpecMParams ev stk) (isGlobalDef "Prelude.ite" -> Just (), [_, cond, then_tm, else_tm]) -> return $ Ite cond (CompTerm then_tm) (CompTerm else_tm) (isGlobalDef "Prelude.either" -> Just (), @@ -326,7 +350,8 @@ normComp (CompTerm t) = -- that it must be applied to all of the uvars as well as the args let var = CallSName (fun_vars !! (fromIntegral i)) all_args <- (++ args) <$> getAllUVarTerms - FunBind var all_args <$> mkCompFunReturn <$> mrFunOutType var all_args + FunBind var all_args Unlifted <$> mkCompFunReturn <$> + mrFunOutType var all_args (isGlobalDef "Prelude.multiArgFixS" -> Just (), _ev:_stack:_lrt:body:args) -> do @@ -347,7 +372,8 @@ normComp (CompTerm t) = -- well as the args let var = CallSName fun_var all_args <- (++ args) <$> getAllUVarTerms - FunBind var all_args <$> mkCompFunReturn <$> mrFunOutType var all_args + FunBind var all_args Unlifted <$> mkCompFunReturn <$> + mrFunOutType var all_args -- Convert `vecMapM (bvToNat ...)` into `bvVecMapInvarM`, with the -- invariant being the current set of assumptions @@ -362,59 +388,57 @@ normComp (CompTerm t) = -- Convert `atM (bvToNat ...) ... (bvToNat ...)` into the unfolding of -- `bvVecAtM` - (asGlobalDef -> Just "CryptolM.atM", [(asBvToNat -> Just (_w1, _n)), _a, _xs, - (asBvToNat -> Just (_w2, _i))]) -> - error "FIXME HERE NOW: need SpecM version of atM" - {- + (asGlobalDef -> Just "CryptolM.atM", [ev, stack, + (asBvToNat -> Just (w1, n)), a, xs, + (asBvToNat -> Just (w2, i))]) -> do body <- mrGlobalDefBody "CryptolM.bvVecAtM" ws_are_eq <- mrConvertible w1 w2 if ws_are_eq then - mrApplyAll body [w1, n, a, xs, i] >>= normCompTerm - else throwMRFailure (MalformedComp t) -} + mrApplyAll body [ev, stack, w1, n, a, xs, i] >>= normCompTerm + else throwMRFailure (MalformedComp t) -- Convert `atM n ... xs (bvToNat ...)` for a constant `n` into the -- unfolding of `bvVecAtM` after converting `n` to a bitvector constant -- and applying `genBVVecFromVec` to `xs` - (asGlobalDef -> Just "CryptolM.atM", [_n_tm@(asNat -> Just _n), _a, _xs, + (asGlobalDef -> Just "CryptolM.atM", [ev, stack, + n_tm@(asNat -> Just n), a, xs, (asBvToNat -> - Just (_w_tm@(asNat -> Just _w), - _i))]) -> - error "FIXME HERE NOW: need SpecM version of atM" - {- + Just (w_tm@(asNat -> Just w), + i))]) -> do body <- mrGlobalDefBody "CryptolM.bvVecAtM" if n < 1 `shiftL` fromIntegral w then do n' <- liftSC2 scBvLit w (toInteger n) xs' <- mrGenBVVecFromVec n_tm a xs "normComp (atM)" w_tm n' - mrApplyAll body [w_tm, n', a, xs', i] >>= normCompTerm - else throwMRFailure (MalformedComp t) -} + mrApplyAll body [ev, stack, w_tm, n', a, xs', i] >>= normCompTerm + else throwMRFailure (MalformedComp t) -- Convert `updateM (bvToNat ...) ... (bvToNat ...)` into the unfolding of -- `bvVecUpdateM` - (asGlobalDef -> Just "CryptolM.updateM", [(asBvToNat -> Just (_w1, _n)), _a, _xs, - (asBvToNat -> Just (_w2, _i)), _x]) -> - error "FIXME HERE NOW: need SpecM version of updateM" - {- + (asGlobalDef -> Just "CryptolM.updateM", [ev, stack, + (asBvToNat -> Just (w1, n)), a, xs, + (asBvToNat -> Just (w2, i)), x]) -> do body <- mrGlobalDefBody "CryptolM.bvVecUpdateM" ws_are_eq <- mrConvertible w1 w2 if ws_are_eq then - mrApplyAll body [w1, n, a, xs, i, x] >>= normCompTerm - else throwMRFailure (MalformedComp t) -} + mrApplyAll body [ev, stack, w1, n, a, xs, i, x] >>= normCompTerm + else throwMRFailure (MalformedComp t) -- Convert `updateM n ... xs (bvToNat ...)` for a constant `n` into the -- unfolding of `bvVecUpdateM` after converting `n` to a bitvector constant -- and applying `genBVVecFromVec` to `xs` - (asGlobalDef -> Just "CryptolM.updateM", - [_n_tm@(asNat -> Just _n), _a, _xs, (asBvToNat -> - Just (_w_tm@(asNat -> Just _w), _i)), _x]) -> - error "FIXME HERE NOW: need SpecM version of updateM" - {- + (asGlobalDef -> Just "CryptolM.updateM", [ev, stack, + n_tm@(asNat -> Just n), a, xs, + (asBvToNat -> + Just (w_tm@(asNat -> Just w), + i)), x]) -> do body <- mrGlobalDefBody "CryptolM.fromBVVecUpdateM" if n < 1 `shiftL` fromIntegral w then do n' <- liftSC2 scBvLit w (toInteger n) xs' <- mrGenBVVecFromVec n_tm a xs "normComp (updateM)" w_tm n' err_tm <- mrErrorTerm a "normComp (updateM)" - mrApplyAll body [w_tm, n', a, xs', i, x, err_tm, n_tm] >>= normCompTerm - else throwMRFailure (MalformedComp t) -} + mrApplyAll body [ev, stack, w_tm, n', a, xs', i, x, err_tm, n_tm] + >>= normCompTerm + else throwMRFailure (MalformedComp t) -- Always unfold: sawLet, multiArgFixM, invariantHint, Num_rec (f@(asGlobalDef -> Just ident), args) @@ -444,11 +468,11 @@ normComp (CompTerm t) = -- FIXME: substitute for evars if they have been instantiated ((asExtCns -> Just ec), args) -> do fun_name <- extCnsToFunName ec - FunBind fun_name args <$> mkCompFunReturn <$> + FunBind fun_name args Unlifted <$> mkCompFunReturn <$> mrFunOutType fun_name args ((asGlobalFunName -> Just f), args) -> - FunBind f args <$> mkCompFunReturn <$> mrFunOutType f args + FunBind f args Unlifted <$> mkCompFunReturn <$> mrFunOutType f args _ -> throwMRFailure (MalformedComp t) @@ -471,7 +495,7 @@ normBind (AssumeBoolBind cond f) k = return $ AssumeBoolBind cond (compFunComp f k) normBind (ExistsBind tp f) k = return $ ExistsBind tp (compFunComp f k) normBind (ForallBind tp f) k = return $ ForallBind tp (compFunComp f k) -normBind (FunBind f args k1) k2 +normBind (FunBind f args isLifted k1) k2 -- Turn `bvVecMapInvarM ... >>= k` into `bvVecMapInvarBindM ... k` {- | GlobalName (globalDefString -> "CryptolM.bvVecMapInvarM") [] <- f @@ -488,18 +512,80 @@ normBind (FunBind f args k1) k2 do cont' <- compFunToTerm (compFunComp (compFunComp (CompFunTerm cont) k1) k2) c <- compFunReturnType k2 return $ FunBind f (args_pre ++ [cont']) (CompFunReturn (Type c)) - | otherwise -} = return $ FunBind f args (compFunComp k1 k2) + | otherwise -} = return $ FunBind f args isLifted (compFunComp k1 k2) + +-- | Bind a computation in whnf with a function, normalize, and then call +-- 'liftStackNormComp' if the first argument is 'Lifted'. If the first argument +-- is 'Unlifted', this function is the same as 'normBind'. +normBindLiftStack :: IsLifted -> NormComp -> CompFun -> MRM t NormComp +normBindLiftStack Unlifted t f = normBind t f +normBindLiftStack Lifted t f = + liftStackNormComp (compFunSpecMParams f) t >>= \t' -> normBind t' f + +-- | Bind a 'Term' for a computation with with a function, normalize, and then +-- call 'liftStackNormComp' if the first argument is 'Lifted'. See: +-- 'normBindLiftStack'. +normBindTermLiftStack :: IsLifted -> Term -> CompFun -> MRM t NormComp +normBindTermLiftStack isLifted t f = + normCompTerm t >>= \m -> normBindLiftStack isLifted m f + + +-- | Apply @liftStackS@ to a computation in whnf, and normalize +liftStackNormComp :: SpecMParams Term -> NormComp -> MRM t NormComp +liftStackNormComp _ (RetS t) = return (RetS t) +liftStackNormComp _ (ErrorS msg) = return (ErrorS msg) +liftStackNormComp params (Ite cond comp1 comp2) = + Ite cond <$> liftStackComp params comp1 <*> liftStackComp params comp2 +liftStackNormComp params (Eithers elims t) = + Eithers <$> mapM (\(tp,f) -> (tp,) <$> liftStackCompFun params f) elims + <*> return t +liftStackNormComp params (MaybeElim tp m f t) = + MaybeElim tp <$> liftStackComp params m + <*> liftStackCompFun params f <*> return t +liftStackNormComp params (OrS comp1 comp2) = + OrS <$> liftStackComp params comp1 <*> liftStackComp params comp2 +liftStackNormComp params (AssertBoolBind cond f) = + AssertBoolBind cond <$> liftStackCompFun params f +liftStackNormComp params (AssumeBoolBind cond f) = + AssumeBoolBind cond <$> liftStackCompFun params f +liftStackNormComp params (ExistsBind tp f) = + ExistsBind tp <$> liftStackCompFun params f +liftStackNormComp params (ForallBind tp f) = + ForallBind tp <$> liftStackCompFun params f +liftStackNormComp params (FunBind f args _ k) = + FunBind f args Lifted <$> liftStackCompFun params k + +-- | Apply @liftStackS@ to a computation +liftStackComp :: SpecMParams Term -> Comp -> MRM t Comp +liftStackComp (SpecMParams ev stk) (CompTerm t) = mrTypeOf t >>= \case + (asSpecM -> Just (_, tp)) -> + CompTerm <$> liftSC2 scGlobalApply "Prelude.liftStackS" [ev, stk, tp, t] + _ -> error "liftStackComp: type not of the form: SpecM a" +liftStackComp _ (CompReturn t) = return $ CompReturn t +liftStackComp params (CompBind c f) = + CompBind <$> liftStackComp params c <*> liftStackCompFun params f + +-- | Apply @liftStackS@ to the bodies of a composition of functions +liftStackCompFun :: SpecMParams Term -> CompFun -> MRM t CompFun +liftStackCompFun params@(SpecMParams ev stk) (CompFunTerm _ f) = mrTypeOf f >>= \case + (asPi -> Just (_, _, asSpecM -> Just (_, tp))) -> + let nm = maybe "ret_val" id (asLambdaName f) in + CompFunTerm params <$> + mrLambdaLift1 (nm, tp) (ev, stk, tp, f) (\arg (ev', stk', tp', f') -> + do app <- mrApplyAll f' [arg] + liftSC2 scGlobalApply "Prelude.liftStackS" [ev', stk', tp', app]) + _ -> error "liftStackCompFun: type not of the form: a -> SpecM b" +liftStackCompFun params (CompFunReturn _ tp) = return $ CompFunReturn params tp +liftStackCompFun params (CompFunComp f g) = + CompFunComp <$> liftStackCompFun params f <*> liftStackCompFun params g --- | Bind a 'Term' for a computation with a function and normalize -normBindTerm :: Term -> CompFun -> MRM t NormComp -normBindTerm t f = normCompTerm t >>= \m -> normBind m f {- -- | Get the return type of a 'CompFun' compFunReturnType :: CompFun -> MRM t Term compFunReturnType (CompFunTerm _ t) = mrTypeOf t compFunReturnType (CompFunComp _ g) = compFunReturnType g -compFunReturnType (CompFunReturn (Type t)) = return t +compFunReturnType (CompFunReturn _ _) = error "FIXME" -} -- | Apply a computation function to a term argument to get a computation @@ -560,7 +646,7 @@ applyNormCompFun f arg = applyCompFun f arg >>= normComp -- | Convert a 'FunAssumpRHS' to a 'NormComp' mrFunAssumpRHSAsNormComp :: FunAssumpRHS -> MRM t NormComp mrFunAssumpRHSAsNormComp (OpaqueFunAssump f args) = - FunBind f args <$> mkCompFunReturn <$> mrFunOutType f args + FunBind f args Unlifted <$> mkCompFunReturn <$> mrFunOutType f args mrFunAssumpRHSAsNormComp (RewriteFunAssump rhs) = normCompTerm rhs @@ -810,10 +896,10 @@ mrRefines' (RetS e) (ErrorS _) = throwMRFailure (ReturnNotError e) mrRefines' (ErrorS _) (RetS e) = throwMRFailure (ReturnNotError e) -- maybe elimination on equality types -mrRefines' (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m1 f1 _) m2 = +mrRefines' (MaybeElim (Type cond_tp@(asEq -> Just (tp,e1,e2))) m1 f1 _) m2 = do cond <- mrEq' tp e1 e2 not_cond <- liftSC1 scNot cond - cond_pf <- liftSC1 scEqTrue cond >>= mrDummyProof + cond_pf <- mrDummyProof cond_tp m1' <- applyNormCompFun f1 cond_pf cond_holds <- mrProvable cond not_cond_holds <- mrProvable not_cond @@ -822,10 +908,10 @@ mrRefines' (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m1 f1 _) m2 = (_, True) -> mrRefines m1 m2 _ -> withAssumption cond (mrRefines m1' m2) >> withAssumption not_cond (mrRefines m1 m2) -mrRefines' m1 (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m2 f2 _) = +mrRefines' m1 (MaybeElim (Type cond_tp@(asEq -> Just (tp,e1,e2))) m2 f2 _) = do cond <- mrEq' tp e1 e2 not_cond <- liftSC1 scNot cond - cond_pf <- liftSC1 scEqTrue cond >>= mrDummyProof + cond_pf <- mrDummyProof cond_tp m2' <- applyNormCompFun f2 cond_pf cond_holds <- mrProvable cond not_cond_holds <- mrProvable not_cond @@ -836,10 +922,10 @@ mrRefines' m1 (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m2 f2 _) = withAssumption not_cond (mrRefines m1 m2) -- maybe elimination on isFinite types -mrRefines' (MaybeElim (Type (asIsFinite -> Just n1)) m1 f1 _) m2 = +mrRefines' (MaybeElim (Type fin_tp@(asIsFinite -> Just n1)) m1 f1 _) m2 = do n1_norm <- mrNormOpenTerm n1 maybe_assump <- mrGetDataTypeAssump n1_norm - fin_pf <- mrIsFinite n1_norm >>= mrDummyProof + fin_pf <- mrDummyProof fin_tp case (maybe_assump, asNum n1_norm) of (_, Just (Left _)) -> applyNormCompFun f1 fin_pf >>= flip mrRefines m2 (_, Just (Right _)) -> mrRefines m1 m2 @@ -851,10 +937,10 @@ mrRefines' (MaybeElim (Type (asIsFinite -> Just n1)) m1 f1 _) m2 = (withUVarLift "n" (Type nat_tp) (n1_norm, f1, m2) $ \ n (n1', f1', m2') -> withDataTypeAssump n1' (IsNum n) (applyNormCompFun f1' n >>= flip mrRefines m2')) -mrRefines' m1 (MaybeElim (Type (asIsFinite -> Just n2)) m2 f2 _) = +mrRefines' m1 (MaybeElim (Type fin_tp@(asIsFinite -> Just n2)) m2 f2 _) = do n2_norm <- mrNormOpenTerm n2 maybe_assump <- mrGetDataTypeAssump n2_norm - fin_pf <- mrIsFinite n2_norm >>= mrDummyProof + fin_pf <- mrDummyProof fin_tp case (maybe_assump, asNum n2_norm) of (_, Just (Left _)) -> applyNormCompFun f2 fin_pf >>= mrRefines m1 (_, Just (Right _)) -> mrRefines m1 m2 @@ -956,9 +1042,9 @@ mrRefines' (OrS m1 m1') m2 = -- FIXME: the following cases don't work unless we either allow evars to be set -- to NormComps or we can turn NormComps back into terms -mrRefines' m1@(FunBind (EVarFunName _) _ _) m2 = +mrRefines' m1@(FunBind (EVarFunName _) _ _ _) m2 = throwMRFailure (CompsDoNotRefine m1 m2) -mrRefines' m1 m2@(FunBind (EVarFunName _) _ _) = +mrRefines' m1 m2@(FunBind (EVarFunName _) _ _ _) = throwMRFailure (CompsDoNotRefine m1 m2) {- mrRefines' (FunBind (EVarFunName evar) args (CompFunReturn _)) m2 = @@ -969,13 +1055,15 @@ mrRefines' (FunBind (EVarFunName evar) args (CompFunReturn _)) m2 = Nothing -> mrTrySetAppliedEVar evar args m2 -} -mrRefines' (FunBind (CallSName f) args1 k1) (FunBind (CallSName f') args2 k2) - | f == f' && length args1 == length args2 = +mrRefines' (FunBind (CallSName f) args1 isLifted k1) + (FunBind (CallSName f') args2 isLifted' k2) + | f == f' && isLifted == isLifted' && length args1 == length args2 = zipWithM_ mrAssertProveEq args1 args2 >> mrFunOutType (CallSName f) args1 >>= \(_, tp) -> mrRefinesFun tp k1 tp k2 -mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = +mrRefines' m1@(FunBind f1 args1 isLifted1 k1) + m2@(FunBind f2 args2 isLifted2 k2) = mrFunOutType f1 args1 >>= \(_, tp1) -> mrFunOutType f2 args2 >>= \(_, tp2) -> findInjConvs tp1 Nothing tp2 Nothing >>= \mb_convs -> @@ -1014,7 +1102,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- unfolds and is not recursive in itself, unfold f2 and recurse (_, Just fa@(FunAssump _ _ _ (OpaqueFunAssump _ _) _)) | Just (f2_body, False) <- maybe_f2_body -> - normBindTerm f2_body k2 >>= \m2' -> + normBindTermLiftStack isLifted2 f2_body k2 >>= \m2' -> recordUsedFunAssump fa >> mrRefines m1 m2' -- If we have a rewrite FunAssump, or we have an opaque FunAssump that @@ -1036,25 +1124,27 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = evars <- mrFreshEVars ctx (args1'', rhs'') <- substTermLike 0 evars (args1', rhs') zipWithM_ mrAssertProveEq args1'' args1 - m1' <- normBind rhs'' k1 + m1' <- normBindLiftStack isLifted1 rhs'' k1 recordUsedFunAssump fa >> mrRefines m1' m2 -- If f1 unfolds and is not recursive in itself, unfold it and recurse _ | Just (f1_body, False) <- maybe_f1_body -> - normBindTerm f1_body k1 >>= \m1' -> mrRefines m1' m2 + normBindTermLiftStack isLifted1 f1_body k1 >>= \m1' -> mrRefines m1' m2 -- If f2 unfolds and is not recursive in itself, unfold it and recurse _ | Just (f2_body, False) <- maybe_f2_body -> - normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' + normBindTermLiftStack isLifted2 f2_body k2 >>= \m2' -> mrRefines m1 m2' -- If we don't have a co-inducitve hypothesis for f1 and f2, don't have an - -- assumption that f1 refines some specification, and both f1 and f2 are - -- recursive and have return types which are heterogeneously related, then - -- try to coinductively prove that f1 args1 |= f2 args2 under the assumption - -- that f1 args1 |= f2 args2, and then try to prove that k1 |= k2 + -- assumption that f1 refines some specification, both are either lifted or + -- unlifted, and both f1 and f2 are recursive and have return types which are + -- heterogeneously related, then try to coinductively prove that + -- f1 args1 |= f2 args2 under the assumption that f1 args1 |= f2 args2, and + -- then try to prove that k1 |= k2 _ | Just _ <- mb_convs , Just _ <- maybe_f1_body - , Just _ <- maybe_f2_body -> + , Just _ <- maybe_f2_body + , isLifted1 == isLifted2 -> mrRefinesCoInd f1 args1 f2 args2 >> mrRefinesFun tp1 k1 tp2 k2 -- If we cannot line up f1 and f2, then making progress here would require us @@ -1063,10 +1153,12 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- continuation on the other side, but we don't know how to do that, so give -- up _ -> - mrDebugPPPrefixSep 1 "mrRefines: bind types not equal:" tp1 "/=" tp2 >> - throwMRFailure (CompsDoNotRefine m1 m2) + do if isLifted1 /= isLifted2 + then debugPrint 1 "mrRefines: isLifted cases do not match" + else mrDebugPPPrefixSep 1 "mrRefines: bind types not equal:" tp1 "/=" tp2 + throwMRFailure (CompsDoNotRefine m1 m2) -mrRefines' m1@(FunBind f1 args1 k1) m2 = +mrRefines' m1@(FunBind f1 args1 isLifted1 k1) m2 = mrGetFunAssump f1 >>= \case -- If we have an assumption that f1 args' refines some rhs, then prove that @@ -1076,7 +1168,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2 = evars <- mrFreshEVars ctx (args1'', rhs'') <- substTermLike 0 evars (args1', rhs') zipWithM_ mrAssertProveEq args1'' args1 - m1' <- normBind rhs'' k1 + m1' <- normBindLiftStack isLifted1 rhs'' k1 recordUsedFunAssump fa >> mrRefines m1' m2 -- Otherwise, see if we can unfold f1 @@ -1085,19 +1177,19 @@ mrRefines' m1@(FunBind f1 args1 k1) m2 = -- If f1 unfolds and is not recursive in itself, unfold it and recurse Just (f1_body, False) -> - normBindTerm f1_body k1 >>= \m1' -> mrRefines m1' m2 + normBindTermLiftStack isLifted1 f1_body k1 >>= \m1' -> mrRefines m1' m2 -- Otherwise we would have to somehow split m2 into some computation of the -- form m2' >>= k2 where f1 args1 |= m2' and k1 |= k2, but we don't know how -- to do this splitting, so give up _ -> mrRefines'' m1 m2 -mrRefines' m1 m2@(FunBind f2 args2 k2) = +mrRefines' m1 m2@(FunBind f2 args2 isLifted2 k2) = mrFunBodyRecInfo f2 args2 >>= \case -- If f2 unfolds and is not recursive in itself, unfold it and recurse Just (f2_body, False) -> - normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' + normBindTermLiftStack isLifted2 f2_body k2 >>= \m2' -> mrRefines m1 m2' -- If f2 unfolds but is recursive, and k2 is the trivial continuation, meaning -- m2 is just f2 args2, use the law of coinduction to prove m1 |= f2 args2 by diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index f6d533d4f1..3b6d2399b3 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -7,6 +7,7 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE EmptyCase #-} @@ -171,6 +172,10 @@ mrVarCtxFromOuterToInner = mrVarCtxFromInnerToOuter . reverse specMParamsArgs :: SpecMParams Term -> [Term] specMParamsArgs (SpecMParams ev stack) = [ev, stack] +-- | A datatype indicating whether an application of a 'FunName' is wrapped in +-- a call to @liftStackS@ - used in the 'FunBind' constructor of 'NormComp' +data IsLifted = Lifted | Unlifted deriving (Generic, Eq, Show) + -- | A Haskell representation of a @SpecM@ in "monadic normal form" data NormComp = RetS Term -- ^ A term @retS _ _ a x@ @@ -183,8 +188,9 @@ data NormComp | AssumeBoolBind Term CompFun -- ^ the bind of an @assumeBoolS@ computation | ExistsBind Type CompFun -- ^ the bind of an @existsS@ computation | ForallBind Type CompFun -- ^ the bind of a @forallS@ computation - | FunBind FunName [Term] CompFun - -- ^ Bind a monadic function with @N@ arguments in an @a -> SpecM b@ term + | FunBind FunName [Term] IsLifted CompFun + -- ^ Bind a monadic function with @N@ arguments, possibly wrapped in a call + -- to @liftStackS@, in an @a -> SpecM b@ term deriving (Generic, Show) -- | An eliminator for an @Eithers@ type is a pair of the type of the disjunct @@ -240,6 +246,8 @@ data Comp = CompTerm Term | CompBind Comp CompFun | CompReturn Term asSpecM :: Term -> Maybe (SpecMParams Term, Term) asSpecM (asApplyAll -> (isGlobalDef "Prelude.SpecM" -> Just (), [ev, stack, tp])) = return (SpecMParams { specMEvType = ev, specMStack = stack }, tp) +asSpecM (asApplyAll -> (isGlobalDef "Prelude.CompM" -> Just (), _)) = + error "CompM found instead of SpecM" asSpecM _ = fail "not a SpecM type!" -- | Test if a type normalizes to a monadic function type of 0 or more arguments @@ -418,6 +426,7 @@ instance TermLike Natural where deriving anyclass instance TermLike Type deriving instance TermLike (SpecMParams Term) +deriving instance TermLike IsLifted deriving instance TermLike NormComp deriving instance TermLike CompFun deriving instance TermLike Comp @@ -532,7 +541,8 @@ instance PrettyInCtx Comp where prettyInCtx (CompBind c f) = prettyAppList [prettyInCtx c, return ">>=", prettyInCtx f] prettyInCtx (CompReturn t) = - prettyAppList [ return "returnM", return "_", parens <$> prettyInCtx t] + prettyAppList [return "retS", return "_", return "_", + parens <$> prettyInCtx t] instance PrettyInCtx CompFun where prettyInCtx (CompFunTerm _ t) = prettyInCtx t @@ -576,10 +586,21 @@ instance PrettyInCtx NormComp where prettyInCtx (ForallBind tp k) = prettyAppList [return "forallS", return "_", return "_", prettyInCtx tp, return ">>=", parens <$> prettyInCtx k] - prettyInCtx (FunBind f args (CompFunReturn _ _)) = - prettyTermApp (funNameTerm f) args - prettyInCtx (FunBind f [] k) = - prettyAppList [prettyInCtx f, return ">>=", prettyInCtx k] - prettyInCtx (FunBind f args k) = - prettyAppList [parens <$> prettyTermApp (funNameTerm f) args, - return ">>=", prettyInCtx k] + prettyInCtx (FunBind f args isLifted (CompFunReturn _ _)) = + snd $ prettyInCtxFunBindH f args isLifted + prettyInCtx (FunBind f args isLifted k) + | (g, m) <- prettyInCtxFunBindH f args isLifted = + prettyAppList [g <$> m, return ">>=", prettyInCtx k] + +-- | A helper function for the 'FunBind' case of 'prettyInCtx'. Returns the +-- string you would get if the associated 'CompFun' is 'CompFunReturn', as well +-- as a 'SawDoc' function (which is either 'id' or 'parens') to apply in the +-- case where the associated 'CompFun' is something else. +prettyInCtxFunBindH :: FunName -> [Term] -> IsLifted -> + (SawDoc -> SawDoc, PPInCtxM SawDoc) +prettyInCtxFunBindH f [] Unlifted = (id, prettyInCtx f) +prettyInCtxFunBindH f args Unlifted = (parens,) $ + prettyTermApp (funNameTerm f) args +prettyInCtxFunBindH f args Lifted = (parens,) $ + prettyAppList [return "liftStackS", return "_", return "_", return "_", + parens <$> prettyTermApp (funNameTerm f) args] \ No newline at end of file