diff --git a/cryptol-saw-core/saw/CryptolM.sawcore b/cryptol-saw-core/saw/CryptolM.sawcore index 01eed7bd27..59462730db 100644 --- a/cryptol-saw-core/saw/CryptolM.sawcore +++ b/cryptol-saw-core/saw/CryptolM.sawcore @@ -172,6 +172,17 @@ bvVecUpdateM n len a xs i x = (updBVVec n len a xs i x)) (bvultWithProof n i len); +fromBVVecUpdateM : (n : Nat) -> (len : Vec n Bool) -> (a : isort 0) -> + BVVec n len a -> Vec n Bool -> a -> + a -> (m : Nat) -> CompM (Vec m a); +fromBVVecUpdateM n len a xs i x def m = + maybe (is_bvult n i len) (CompM (Vec m a)) + (errorM (Vec m a) "bvVecUpdateM: invalid sequence index") + (\ (_:is_bvult n i len) -> returnM (Vec m a) + (genFromBVVec n len a + (updBVVec n len a xs i x) def m)) + (bvultWithProof n i len); + updateM : (n : Nat) -> (a : isort 0) -> Vec n a -> Nat -> a -> CompM (Vec n a); updateM n a xs i x = maybe (IsLtNat i n) (CompM (Vec n a)) @@ -340,7 +351,7 @@ ecShiftRM : (m : Num) -> (ix a : sort 0) -> PIntegral ix -> PZero a -> ecShiftRM = Num_rec (\ (m:Num) -> (ix a : sort 0) -> PIntegral ix -> PZero a -> mseq m a -> ix -> mseq m a) - (\ (m:Nat) -> ecShiftL (TCNum m)) + (\ (m:Nat) -> ecShiftR (TCNum m)) (\ (ix a : sort 0) (pix:PIntegral ix) (pa:PZero a) -> ecShiftR TCInf ix (CompM a) pix (PZeroCompM a pa)); diff --git a/heapster-saw/examples/sha512.cry b/heapster-saw/examples/sha512.cry index da3704db33..c118a7c12b 100644 --- a/heapster-saw/examples/sha512.cry +++ b/heapster-saw/examples/sha512.cry @@ -73,14 +73,13 @@ round_00_15_spec i a b c d e f g h T1 = round_16_80_spec : [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> [w] -> - [16][w] -> - [w] -> [w] -> [w] -> + [16][w] -> [w] -> ([w], [w], [w], [w], [w], [w], [w], [w], [16][w], [w], [w], [w]) -round_16_80_spec i j a b c d e f g h X s0 s1 T1 = +round_16_80_spec i j a b c d e f g h X T1 = (a', b', c', d', e', f', g', h', X', s0', s1', T1'') where s0' = sigma_0 (X @ ((j + 1) && 15)) - s1' = sigma_1 (X @ ((j + 4) && 15)) - T1' = X @ (j && 15) + s0' + s1' + X @ ((j + 9) && 15) + s1' = sigma_1 (X @ ((j + 14) && 15)) + T1' = (X @ (j && 15)) + s0' + s1' + (X @ ((j + 9) && 15)) X' = update X (j && 15) T1' (a', b', c', d', e', f', g', h', T1'') = round_00_15_spec (i + j) a b c d e f g h T1' diff --git a/heapster-saw/examples/sha512_mr_solver.saw b/heapster-saw/examples/sha512_mr_solver.saw index 69169abb2e..372a3f0731 100644 --- a/heapster-saw/examples/sha512_mr_solver.saw +++ b/heapster-saw/examples/sha512_mr_solver.saw @@ -7,7 +7,10 @@ heapster_define_perm env "int64" " " "llvmptr 64" "exists x:bv 64.eq(llvmword(x) heapster_define_perm env "int32" " " "llvmptr 32" "exists x:bv 32.eq(llvmword(x))"; heapster_define_perm env "int8" " " "llvmptr 8" "exists x:bv 8.eq(llvmword(x))"; -heapster_define_perm env "int64_ptr" " " "llvmptr 64" "ptr((W,0) |-> int64<>)"; +// FIXME: We always have rw=W, but without the rw arguments below Heapster +// doesn't realize the perm is not copyable (it needs to unfold named perms). +heapster_define_perm env "int64_ptr" "rw:rwmodality" "llvmptr 64" "ptr((rw,0) |-> int64<>)"; +heapster_define_perm env "true_ptr" "rw:rwmodality" "llvmptr 64" "ptr((rw,0) |-> true)"; heapster_assume_fun env "CRYPTO_load_u64_be" "(). arg0:ptr((R,0) |-> int64<>) -o \ @@ -16,23 +19,23 @@ heapster_assume_fun env "CRYPTO_load_u64_be" heapster_typecheck_fun env "round_00_15" "(). arg0:int64<>, \ - \ arg1:int64_ptr<>, arg2:int64_ptr<>, arg3:int64_ptr<>, arg4:int64_ptr<>, \ - \ arg5:int64_ptr<>, arg6:int64_ptr<>, arg7:int64_ptr<>, arg8:int64_ptr<>, \ - \ arg9:int64_ptr<> -o \ - \ arg1:int64_ptr<>, arg2:int64_ptr<>, arg3:int64_ptr<>, arg4:int64_ptr<>, \ - \ arg5:int64_ptr<>, arg6:int64_ptr<>, arg7:int64_ptr<>, arg8:int64_ptr<>, \ - \ arg9:int64_ptr<>, ret:true"; + \ arg1:int64_ptr, arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, \ + \ arg5:int64_ptr, arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, \ + \ arg9:int64_ptr -o \ + \ arg1:int64_ptr, arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, \ + \ arg5:int64_ptr, arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, \ + \ arg9:int64_ptr, ret:true"; heapster_typecheck_fun env "round_16_80" "(). arg0:int64<>, arg1:int64<>, \ - \ arg2:int64_ptr<>, arg3:int64_ptr<>, arg4:int64_ptr<>, arg5:int64_ptr<>, \ - \ arg6:int64_ptr<>, arg7:int64_ptr<>, arg8:int64_ptr<>, arg9:int64_ptr<>, \ + \ arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ + \ arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, arg9:int64_ptr, \ \ arg10:array(W,0,<16,*8,fieldsh(int64<>)), \ - \ arg11:ptr((W,0) |-> true), arg12:ptr((W,0) |-> true), arg13:int64_ptr<> -o \ - \ arg2:int64_ptr<>, arg3:int64_ptr<>, arg4:int64_ptr<>, arg5:int64_ptr<>, \ - \ arg6:int64_ptr<>, arg7:int64_ptr<>, arg8:int64_ptr<>, arg9:int64_ptr<>, \ + \ arg11:true_ptr, arg12:true_ptr, arg13:int64_ptr -o \ + \ arg2:int64_ptr, arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ + \ arg6:int64_ptr, arg7:int64_ptr, arg8:int64_ptr, arg9:int64_ptr, \ \ arg10:array(W,0,<16,*8,fieldsh(int64<>)), \ - \ arg11:int64_ptr<>, arg12:int64_ptr<>, arg13:int64_ptr<>, ret:true"; + \ arg11:int64_ptr, arg12:int64_ptr, arg13:int64_ptr, ret:true"; heapster_typecheck_fun env "return_X" "(). arg0:array(W,0,<16,*8,fieldsh(int64<>)) -o \ @@ -40,13 +43,13 @@ heapster_typecheck_fun env "return_X" heapster_set_translation_checks env false; heapster_typecheck_fun env "processBlock" - "(). arg0:int64_ptr<>, arg1:int64_ptr<>, arg2:int64_ptr<>, \ - \ arg3:int64_ptr<>, arg4:int64_ptr<>, arg5:int64_ptr<>, \ - \ arg6:int64_ptr<>, arg7:int64_ptr<>, \ + "(). arg0:int64_ptr, arg1:int64_ptr, arg2:int64_ptr, \ + \ arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ + \ arg6:int64_ptr, arg7:int64_ptr, \ \ arg8:array(R,0,<16,*8,fieldsh(int64<>)) -o \ - \ arg0:int64_ptr<>, arg1:int64_ptr<>, arg2:int64_ptr<>, \ - \ arg3:int64_ptr<>, arg4:int64_ptr<>, arg5:int64_ptr<>, \ - \ arg6:int64_ptr<>, arg7:int64_ptr<>, \ + \ arg0:int64_ptr, arg1:int64_ptr, arg2:int64_ptr, \ + \ arg3:int64_ptr, arg4:int64_ptr, arg5:int64_ptr, \ + \ arg6:int64_ptr, arg7:int64_ptr, \ \ arg8:array(R,0,<16,*8,fieldsh(int64<>)), ret:true"; // FIXME: This translation contains errors @@ -103,6 +106,4 @@ monadify_term {{ Maj }}; monadify_term {{ round_00_15_spec }}; run_test "round_00_15 |= round_00_15_spec" (mr_solver round_00_15 {{ round_00_15_spec }}) true; - -// FIXME: Need to add heterogenous equality on output types for this to work -// run_test "round_16_80 |= round_16_80_spec" (mr_solver_debug 0 round_16_80 {{ round_16_80_spec }}) true; +run_test "round_16_80 |= round_16_80_spec" (mr_solver round_16_80 {{ round_16_80_spec }}) true; diff --git a/saw-core/src/Verifier/SAW/Recognizer.hs b/saw-core/src/Verifier/SAW/Recognizer.hs index ad951c573e..e7be1a8b61 100644 --- a/saw-core/src/Verifier/SAW/Recognizer.hs +++ b/saw-core/src/Verifier/SAW/Recognizer.hs @@ -60,6 +60,7 @@ module Verifier.SAW.Recognizer -- * Prelude recognizers. , asBool , asBoolType + , asNatType , asIntegerType , asIntModType , asBitvectorType @@ -357,6 +358,11 @@ asBool _ = Nothing asBoolType :: Recognizer Term () asBoolType = isGlobalDef "Prelude.Bool" +asNatType :: Recognizer Term () +asNatType (asDataType -> Just (o, [])) + | primName o == preludeNatIdent = return () +asNatType _ = Nothing + asIntegerType :: Recognizer Term () asIntegerType = isGlobalDef "Prelude.Integer" diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index 0fe9915bec..bb5a5b9148 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -4,6 +4,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DerivingStrategies #-} @@ -22,8 +23,10 @@ monadic combinators for operating on terms. module SAWScript.Prover.MRSolver.Monad where -import Data.List (find, findIndex) +import Data.List (find, findIndex, foldl') import qualified Data.Text as T +import Numeric.Natural (Natural) +import Data.Bits (testBit) import System.IO (hPutStrLn, stderr) import Control.Monad.Reader import Control.Monad.State @@ -62,7 +65,7 @@ data FailCtx -- | That's MR. Failure to you data MRFailure - = TermsNotEq Term Term + = TermsNotRel Bool Term Term | TypesNotEq Type Type | CompsDoNotRefine NormComp NormComp | ReturnNotError Term @@ -84,6 +87,9 @@ data MRFailure | MRFailureDisj MRFailure MRFailure deriving Show +pattern TermsNotEq :: Term -> Term -> MRFailure +pattern TermsNotEq t1 t2 = TermsNotRel False t1 t2 + -- | Pretty-print an object prefixed with a 'String' that describes it ppWithPrefix :: PrettyInCtx a => String -> a -> PPInCtxM SawDoc ppWithPrefix str a = (pretty str <>) <$> nest 2 <$> (line <>) <$> prettyInCtx a @@ -109,8 +115,10 @@ instance PrettyInCtx FailCtx where prettyInCtx t] instance PrettyInCtx MRFailure where - prettyInCtx (TermsNotEq t1 t2) = + prettyInCtx (TermsNotRel False t1 t2) = ppWithPrefixSep "Could not prove terms equal:" t1 "and" t2 + prettyInCtx (TermsNotRel True t1 t2) = + ppWithPrefixSep "Could not prove terms heterogeneously related:" t1 "and" t2 prettyInCtx (TypesNotEq tp1 tp2) = ppWithPrefixSep "Types not equal:" tp1 "and" tp2 prettyInCtx (CompsDoNotRefine m1 m2) = @@ -245,27 +253,6 @@ instance PrettyInCtx DataTypeAssump where prettyInCtx (IsNum x) = prettyInCtx x >>= ppWithPrefix "TCNum" prettyInCtx IsInf = return "TCInf" --- | Recognize a term as a @Left@ or @Right@ -asEither :: Recognizer Term (Either Term Term) -asEither (asCtor -> Just (c, [_, _, x])) - | primName c == "Prelude.Left" = return $ Left x - | primName c == "Prelude.Right" = return $ Right x -asEither _ = Nothing - --- | Recognize a term as a @TCNum n@ or @TCInf@ -asNum :: Recognizer Term (Either Term ()) -asNum (asCtor -> Just (c, [n])) - | primName c == "Cryptol.TCNum" = return $ Left n -asNum (asCtor -> Just (c, [])) - | primName c == "Cryptol.TCInf" = return $ Right () -asNum _ = Nothing - --- | Recognize a term as being of the form @isFinite n@ -asIsFinite :: Recognizer Term Term -asIsFinite (asApp -> Just (isGlobalDef "CryptolM.isFinite" -> Just (), n)) = - Just n -asIsFinite _ = Nothing - -- | Create a term representing the type @IsFinite n@ mrIsFinite :: Term -> MRM Term mrIsFinite n = liftSC2 scGlobalApply "CryptolM.isFinite" [n] @@ -481,6 +468,22 @@ funNameType (GlobalName gd projs) = mrApplyAll :: Term -> [Term] -> MRM Term mrApplyAll f args = liftSC2 scApplyAllBeta f args +-- | Like 'scBvNat', but if given a bitvector literal it is converted to a +-- natural number literal +mrBvToNat :: Term -> Term -> MRM Term +mrBvToNat _ (asArrayValue -> Just (asBoolType -> Just _, + mapM asBool -> Just bits)) = + liftSC1 scNat $ foldl' (\n bit -> if bit then 2*n+1 else 2*n) 0 bits +mrBvToNat n len = liftSC2 scBvNat n len + +-- | Like 'scBvConst', but returns a bitvector literal +mrBvConst :: Natural -> Integer -> MRM Term +mrBvConst n x = + do bool_tp <- liftSC0 scBoolType + bits <- mapM (liftSC1 scBool . testBit x) + [(fromIntegral n - 1), (fromIntegral n - 2) .. 0] + liftSC2 scVector bool_tp bits + -- | Get the current context of uvars as a list of variable names and their -- types as SAW core 'Term's, with the least recently bound uvar first, i.e., in -- the order as seen "from the outside" diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index 3ece60ca01..70bf769020 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -50,15 +50,6 @@ import SAWScript.Prover.MRSolver.Monad -- * Various SMT-specific Functions on Terms ---------------------------------------------------------------------- --- | Test if a 'Term' is a 'BVVec' type -asBVVecType :: Recognizer Term (Term, Term, Term) -asBVVecType (asApplyAll -> - (isGlobalDef "Prelude.Vec" -> Just _, - [(asApplyAll -> - (isGlobalDef "Prelude.bvToNat" -> Just _, [n, len])), a])) = - Just (n, len, a) -asBVVecType _ = Nothing - -- | Apply @genBVVec@ to arguments @n@, @len@, and @a@, along with a function of -- type @Vec n Bool -> a@ genBVVecTerm :: SharedContext -> Term -> Term -> Term -> Term -> IO Term @@ -205,32 +196,48 @@ readBackValueNoConfig err_str sc tv v = -- | Implementations of primitives for normalizing Mr Solver terms smtNormPrims :: SharedContext -> Map Ident TmPrim smtNormPrims sc = Map.fromList - [ + [ -- Don't unfold @genBVVec@ when normalizing ("Prelude.genBVVec", Prim (do tp <- scTypeOfGlobal sc "Prelude.genBVVec" VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> scGlobalDef sc "Prelude.genBVVec") ), + -- Normalize applications of @genBVVecFromVec@ to a @genFromBVVec@ term or + -- a vector literal into the body of the @genFromBVVec@ term or @genBVVec@ + -- of an sequence of @ite@s defined by the literal, respectively ("Prelude.genBVVecFromVec", natFun $ \_m -> tvalFun $ \a -> primFromBVVecOrLit sc a $ \eith -> PrimFun $ \_def -> natFun $ \n -> primBVTermFun sc $ \len -> Prim (do n' <- scNat sc n - a' <- readBackTValueNoConfig "smtNormPrims (genBVVecFromVec)" - sc a + a' <- readBackTValueNoConfig "smtNormPrims (genBVVecFromVec)" sc a tp <- scGlobalApply sc "Prelude.BVVec" [n', len, a'] VExtra <$> VExtraTerm (VTyTerm (mkSort 0) tp) <$> bvVecFromBVVecOrLit sc n n' len a' eith) ), + -- Don't normalize applications of @genFromBVVec@ ("Prelude.genFromBVVec", - Prim (do tp <- scTypeOfGlobal sc "Prelude.genFromBVVec" - VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> - scGlobalDef sc "Prelude.genFromBVVec") + natFun $ \n -> PrimStrict $ \len -> tvalFun $ \a -> PrimStrict $ \v -> + PrimStrict $ \def -> natFun $ \m -> + Prim (do n' <- scNat sc n + let len_tp = VVecType n VBoolType + len' <- readBackValueNoConfig "smtNormPrims (genFromBVVec)" sc len_tp len + a' <- readBackTValueNoConfig "smtNormPrims (genFromBVVec)" sc a + bvToNat_len <- scGlobalApply sc "Prelude.bvToNat" [n', len'] + v_tp <- VTyTerm (mkSort 0) <$> scVecType sc bvToNat_len a' + v' <- readBackValueNoConfig "smtNormPrims (genFromBVVec)" sc v_tp v + def' <- readBackValueNoConfig "smtNormPrims (genFromBVVec)" sc a def + m' <- scNat sc m + tm <- scGlobalApply sc "Prelude.genFromBVVec" [n', len', a', v', def', m'] + return $ VExtra $ VExtraTerm (VVecType m a) tm) ), + -- Normalize applications of @atBVVec@ to a @genBVVec@ term into an + -- application of the body of the @genBVVec@ term to the index ("Prelude.atBVVec", PrimFun $ \_n -> PrimFun $ \_len -> tvalFun $ \a -> primGenBVVec sc $ \f -> primBVTermFun sc $ \ix -> PrimFun $ \_pf -> - Prim (VExtra <$> VExtraTerm a <$> scApply sc f ix) + Prim (VExtra <$> VExtraTerm a <$> scApplyBeta sc f ix) ), + -- Don't normalize applications of @CompM@ ("Prelude.CompM", PrimFilterFun "CompM" (\case TValue tv -> return tv @@ -305,7 +312,7 @@ mrProvable (asBool -> Just b) = return b mrProvable bool_tm = do assumps <- mrAssumptions prop <- liftSC2 scImplies assumps bool_tm >>= liftSC1 scEqTrue - prop_inst <- instantiateUVarsM instUVar prop + prop_inst <- mrSubstEVars prop >>= instantiateUVarsM instUVar mrNormTerm prop_inst >>= mrProvableRaw where -- | Given a UVar name and type, generate a 'Term' to be passed to -- SMT, with special cases for BVVec and pair types @@ -343,8 +350,7 @@ mrEq t1 t2 = mrTypeOf t1 >>= \tp -> mrEq' tp t1 t2 -- are equal, where the first 'Term' gives their type (which we assume is the -- same for both). This is like 'scEq' except that it works on open terms. mrEq' :: Term -> Term -> Term -> MRM Term -mrEq' (asDataType -> Just (pn, [])) t1 t2 - | primName pn == "Prelude.Nat" = liftSC2 scEqualNat t1 t2 +mrEq' (asNatType -> Just _) t1 t2 = liftSC2 scEqualNat t1 t2 mrEq' (asBoolType -> Just _) t1 t2 = liftSC2 scBoolEq t1 t2 mrEq' (asIntegerType -> Just _) t1 t2 = liftSC2 scIntEq t1 t2 mrEq' (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = @@ -394,100 +400,173 @@ mrProveEqSimple eqf t1 t2 = TermInCtx [] <$> eqf t1' t2' -- | Prove that two terms are equal, instantiating evars if necessary, --- returning true on success +-- returning true on success - the same as @mrProveRel False@ mrProveEq :: Term -> Term -> MRM Bool -mrProveEq t1 t2 = - do mrDebugPPPrefixSep 1 "mrProveEq" t1 "==" t2 - tp <- mrTypeOf t1 >>= mrSubstEVars - varmap <- mrVars - cond_in_ctx <- mrProveEqH varmap tp t1 t2 - res <- withTermInCtx cond_in_ctx mrProvable - debugPrint 1 $ "mrProveEq: " ++ if res then "Success" else "Failure" - return res +mrProveEq = mrProveRel False -- | Prove that two terms are equal, instantiating evars if necessary, or --- throwing an error if this is not possible +-- throwing an error if this is not possible - the same as +-- @mrAssertProveRel False@ mrAssertProveEq :: Term -> Term -> MRM () -mrAssertProveEq t1 t2 = - do success <- mrProveEq t1 t2 - if success then return () else - throwMRFailure (TermsNotEq t1 t2) - --- | The main workhorse for 'mrProveEq'. Build a Boolean term expressing that --- the third and fourth arguments, whose type is given by the second. -mrProveEqH :: Map MRVar MRVarInfo -> Term -> Term -> Term -> MRM TermInCtx +mrAssertProveEq = mrAssertProveRel False + +-- | Prove that two terms are related, heterogeneously iff the first argument +-- is true, instantiating evars if necessary, returning true on success +mrProveRel :: Bool -> Term -> Term -> MRM Bool +mrProveRel het t1 t2 = + do let nm = if het then "mrProveRel" else "mrProveEq" + mrDebugPPPrefixSep 1 nm t1 (if het then "~=" else "==") t2 + tp1 <- mrTypeOf t1 >>= mrSubstEVars + tp2 <- mrTypeOf t2 >>= mrSubstEVars + cond_in_ctx <- mrProveRelH het tp1 tp2 t1 t2 + res <- withTermInCtx cond_in_ctx mrProvable + debugPrint 1 $ nm ++ ": " ++ if res then "Success" else "Failure" + return res -{- -mrProveEqH _ _ t1 t2 - | trace ("mrProveEqH:\n" ++ showTerm t1 ++ "\n==\n" ++ showTerm t2) False = undefined --} +-- | Prove that two terms are related, heterogeneously iff the first argument, +-- is true, instantiating evars if necessary, or throwing an error if this is +-- not possible +mrAssertProveRel :: Bool -> Term -> Term -> MRM () +mrAssertProveRel het t1 t2 = + do success <- mrProveRel het t1 t2 + if success then return () else + throwMRFailure (TermsNotRel het t1 t2) + +-- | The main workhorse for 'mrProveEq' and 'mrProveRel'. Build a Boolean term +-- expressing that the fourth and fifth arguments are related, heterogeneously +-- iff the first argument is true, whose types are given by the second and +-- third arguments, respectively +mrProveRelH :: Bool -> Term -> Term -> Term -> Term -> MRM TermInCtx +mrProveRelH het tp1 tp2 t1 t2 = + do varmap <- mrVars + tp1' <- liftSC1 scWhnf tp1 + tp2' <- liftSC1 scWhnf tp2 + mrProveRelH' varmap het tp1' tp2' t1 t2 + +-- | The body of 'mrProveRelH' +-- NOTE: Don't call this function recursively, call 'mrProveRelH' +mrProveRelH' :: Map MRVar MRVarInfo -> Bool -> + Term -> Term -> Term -> Term -> MRM TermInCtx -- If t1 is an instantiated evar, substitute and recurse -mrProveEqH var_map tp (asEVarApp var_map -> Just (_, args, Just f)) t2 = - mrApplyAll f args >>= \t1' -> mrProveEqH var_map tp t1' t2 - --- If t1 is an uninstantiated evar, instantiate it with t2 -mrProveEqH var_map _tp (asEVarApp var_map -> Just (evar, args, Nothing)) t2 = - do t2' <- mrSubstEVars t2 +mrProveRelH' var_map het tp1 tp2 (asEVarApp var_map -> Just (_, args, Just f)) t2 = + mrApplyAll f args >>= \t1' -> mrProveRelH het tp1 tp2 t1' t2 + +-- If t1 is an uninstantiated evar, ensure the types are equal and instantiate +-- it with t2 +mrProveRelH' var_map _ tp1 tp2 (asEVarApp var_map -> Just (evar, args, Nothing)) t2 = + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + t2' <- mrSubstEVars t2 success <- mrTrySetAppliedEVar evar args t2' + when success $ + mrDebugPPPrefixSep 2 "mrProveRelH setting evar" evar "to" t2 TermInCtx [] <$> liftSC1 scBool success -- If t2 is an instantiated evar, substitute and recurse -mrProveEqH var_map tp t1 (asEVarApp var_map -> Just (_, args, Just f)) = - mrApplyAll f args >>= \t2' -> mrProveEqH var_map tp t1 t2' - --- If t2 is an uninstantiated evar, instantiate it with t1 -mrProveEqH var_map _tp t1 (asEVarApp var_map -> Just (evar, args, Nothing)) = - do t1' <- mrSubstEVars t1 +mrProveRelH' var_map het tp1 tp2 t1 (asEVarApp var_map -> Just (_, args, Just f)) = + mrApplyAll f args >>= \t2' -> mrProveRelH het tp1 tp2 t1 t2' + +-- If t2 is an uninstantiated evar, ensure the types are equal and instantiate +-- it with t1 +mrProveRelH' var_map _ tp1 tp2 t1 (asEVarApp var_map -> Just (evar, args, Nothing)) = + do tps_are_eq <- mrConvertible tp1 tp2 + if tps_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + t1' <- mrSubstEVars t1 success <- mrTrySetAppliedEVar evar args t1' + when success $ + mrDebugPPPrefixSep 2 "mrProveRelH setting evar" evar "to" t1 TermInCtx [] <$> liftSC1 scBool success -- For unit types, always return true -mrProveEqH _ (asTupleType -> Just []) _ _ = +mrProveRelH' _ _ (asTupleType -> Just []) (asTupleType -> Just []) _ _ = TermInCtx [] <$> liftSC1 scBool True --- For the nat, bitvector, Boolean, and integer types, call mrProveEqSimple -mrProveEqH _ (asDataType -> Just (pn, [])) t1 t2 - | primName pn == "Prelude.Nat" = - mrProveEqSimple (liftSC2 scEqualNat) t1 t2 -mrProveEqH _ (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = - -- FIXME: make a better solver for bitvector equalities - mrProveEqSimple (liftSC3 scBvEq n) t1 t2 -mrProveEqH _ (asBoolType -> Just _) t1 t2 = +-- For nat, bitvector, Boolean, and integer types, call mrProveEqSimple +mrProveRelH' _ _ (asNatType -> Just _) (asNatType -> Just _) t1 t2 = + mrProveEqSimple (liftSC2 scEqualNat) t1 t2 +mrProveRelH' _ _ tp1@(asVectorType -> Just (n1, asBoolType -> Just ())) + tp2@(asVectorType -> Just (n2, asBoolType -> Just ())) t1 t2 = + do ns_are_eq <- mrConvertible n1 n2 + if ns_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + mrProveEqSimple (liftSC3 scBvEq n1) t1 t2 +mrProveRelH' _ _ (asBoolType -> Just _) (asBoolType -> Just _) t1 t2 = mrProveEqSimple (liftSC2 scBoolEq) t1 t2 -mrProveEqH _ (asIntegerType -> Just _) t1 t2 = +mrProveRelH' _ _ (asIntegerType -> Just _) (asIntegerType -> Just _) t1 t2 = mrProveEqSimple (liftSC2 scIntEq) t1 t2 --- For pair types, prove both the left and right projections are equal -mrProveEqH var_map (asPairType -> Just (tpL, tpR)) t1 t2 = +-- For pair types, prove both the left and right projections are related +mrProveRelH' _ het (asPairType -> Just (tpL1, tpR1)) + (asPairType -> Just (tpL2, tpR2)) t1 t2 = do t1L <- liftSC1 scPairLeft t1 t2L <- liftSC1 scPairLeft t2 t1R <- liftSC1 scPairRight t1 t2R <- liftSC1 scPairRight t2 - condL <- mrProveEqH var_map tpL t1L t2L - condR <- mrProveEqH var_map tpR t1R t2R + condL <- mrProveRelH het tpL1 tpL2 t1L t2L + condR <- mrProveRelH het tpR1 tpR2 t1R t2R andTermInCtx condL condR --- For non-bitvector vector types, prove all projections are equal by --- quantifying over a universal index variable and proving equality at that --- index -mrProveEqH _ (asBVVecType -> Just (n, len, tp)) t1 t2 = +-- For BVVec types, prove all projections are related by quantifying over an +-- index variable and proving the projections at that index are related +mrProveRelH' _ het tp1@(asBVVecType -> Just (n1, len1, tpA1)) + tp2@(asBVVecType -> Just (n2, len2, tpA2)) t1 t2 = + mrConvertible n1 n2 >>= \ns_are_eq -> + mrConvertible len1 len2 >>= \lens_are_eq -> + (if ns_are_eq && lens_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2))) >> liftSC0 scBoolType >>= \bool_tp -> - liftSC2 scVecType n bool_tp >>= \ix_tp -> - withUVarLift "eq_ix" (Type ix_tp) (n,(len,(tp,(t1,t2)))) $ - \ix' (n',(len',(tp',(t1',t2')))) -> - liftSC2 scGlobalApply "Prelude.is_bvult" [n', ix', len'] >>= \pf_tp -> - withUVarLift "eq_pf" (Type pf_tp) (n',(len',(tp',(ix',(t1',t2'))))) $ - \pf'' (n'',(len'',(tp'',(ix'',(t1'',t2''))))) -> - do t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'', + liftSC2 scVecType n1 bool_tp >>= \ix_tp -> + withUVarLift "eq_ix" (Type ix_tp) (n1,(len1,(tpA1,(tpA2,(t1,t2))))) $ + \ix' (n1',(len1',(tpA1',(tpA2',(t1',t2'))))) -> + liftSC2 scGlobalApply "Prelude.is_bvult" [n1', ix', len1'] >>= \pf_tp -> + withUVarLift "eq_pf" (Type pf_tp) (n1',(len1',(tpA1',(tpA2',(ix',(t1',t2')))))) $ + \pf'' (n1'',(len1'',(tpA1'',(tpA2'',(ix'',(t1'',t2'')))))) -> + do t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1'', len1'', tpA1'', t1'', ix'', pf''] - t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'', + t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n1'', len1'', tpA2'', t2'', ix'', pf''] - var_map <- mrVars extTermInCtx [("eq_ix",ix_tp),("eq_pf",pf_tp)] <$> - mrProveEqH var_map tp'' t1_prj t2_prj + mrProveRelH het tpA1'' tpA2'' t1_prj t2_prj + +-- If our relation is heterogeneous and we have a BVVec on one side and a +-- non-BVVec vector on the other, wrap the non-BVVec vector term in +-- genBVVecFromVec and recurse +mrProveRelH' _ True tp1@(asBVVecType -> Just (n, len, _)) + tp2@(asNonBVVecVectorType -> Just (m, tpA2)) t1 t2 = + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m' m + if ms_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + tp2' <- liftSC2 scVecType len' tpA2 + err_str_tm <- liftSC1 scString "FIXME: mrProveRelH error" + err_tm <- liftSC2 scGlobalApply "Prelude.error" [tpA2, err_str_tm] + t2' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" + [m, tpA2, t2, err_tm, n, len] + -- mrDebugPPPrefixSep 2 "mrProveRelH on BVVec/Vec: " t1 "and" t2' + mrProveRelH True tp1 tp2' t1 t2' +mrProveRelH' _ True tp1@(asNonBVVecVectorType -> Just (m, tpA1)) + tp2@(asBVVecType -> Just (n, len, _)) t1 t2 = + do m' <- mrBvToNat n len + ms_are_eq <- mrConvertible m' m + if ms_are_eq then return () else + throwMRFailure (TypesNotEq (Type tp1) (Type tp2)) + len' <- liftSC2 scGlobalApply "Prelude.bvToNat" [n, len] + tp1' <- liftSC2 scVecType len' tpA1 + err_str_tm <- liftSC1 scString "FIXME: mrProveRelH error" + err_tm <- liftSC2 scGlobalApply "Prelude.error" [tpA1, err_str_tm] + t1' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" + [m, tpA1, t1, err_tm, n, len] + -- mrDebugPPPrefixSep 2 "mrProveRelH on Vec/BVVec: " t1' "and" t2 + mrProveRelH True tp1' tp2 t1' t2 -- As a fallback, for types we can't handle, just check convertibility -mrProveEqH _ _ t1 t2 = +mrProveRelH' _ _ tp1 tp2 t1 t2 = do success <- mrConvertible t1 t2 + if success then return () else + mrDebugPPPrefixSep 2 "mrProveRelH could not match types: " tp1 "and" tp2 >> + mrDebugPPPrefixSep 2 "and could not prove convertible: " t1 "and" t2 TermInCtx [] <$> liftSC1 scBool success diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index ce83c2a5d7..74b65cae1b 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -125,7 +125,7 @@ module SAWScript.Prover.MRSolver.Solver where import Data.Maybe import Data.Either -import Data.List (findIndices, intercalate, foldl') +import Data.List (findIndices, intercalate) import Data.Bits (shiftL) import Control.Monad.Except import qualified Data.Map as Map @@ -148,19 +148,6 @@ import SAWScript.Prover.MRSolver.SMT -- * Normalizing and Matching on Terms ---------------------------------------------------------------------- --- | Like 'asVectorType', but returns 'Nothing' if 'asBVVecType' returns 'Just' -asNonBVVecVectorType :: Recognizer Term (Term, Term) -asNonBVVecVectorType (asBVVecType -> Just _) = Nothing -asNonBVVecVectorType t = asVectorType t - --- | Like 'scBvNat', but if given a bitvector literal it is converted to a --- natural number literal -mrBvToNat :: Term -> Term -> MRM Term -mrBvToNat _ (asArrayValue -> Just (asBoolType -> Just _, - mapM asBool -> Just bits)) = - liftSC1 scNat $ foldl' (\n bit -> if bit then 2*n+1 else 2*n) 0 bits -mrBvToNat n len = liftSC2 scBvNat n len - -- | Pattern-match on a @LetRecTypes@ list in normal form and return a list of -- the types it specifies, each in normal form and with uvars abstracted out asLRTList :: Term -> MRM [Term] @@ -309,7 +296,7 @@ normComp (CompTerm t) = i)]) -> do body <- mrGlobalDefBody "CryptolM.bvVecAtM" if n < 1 `shiftL` fromIntegral w then do - n' <- liftSC2 scBvConst w (toInteger n) + n' <- mrBvConst w (toInteger n) err_str <- liftSC1 scString "FIXME: normComp (atM) error" err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str] xs' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" @@ -334,14 +321,14 @@ normComp (CompTerm t) = asBvToNat -> Just (w_tm@(asNat -> Just w), i), x]) -> - do body <- mrGlobalDefBody "CryptolM.bvVecUpdateM" + do body <- mrGlobalDefBody "CryptolM.fromBVVecUpdateM" if n < 1 `shiftL` fromIntegral w then do - n' <- liftSC2 scBvConst w (toInteger n) + n' <- mrBvConst w (toInteger n) err_str <- liftSC1 scString "FIXME: normComp (updateM) error" err_tm <- liftSC2 scGlobalApply "Prelude.error" [a, err_str] xs' <- liftSC2 scGlobalApply "Prelude.genBVVecFromVec" [n_tm, a, xs, err_tm, w_tm, n'] - mrApplyAll body [w_tm, n', a, xs', i, x] >>= normCompTerm + mrApplyAll body [w_tm, n', a, xs', i, x, err_tm, n_tm] >>= normCompTerm else throwMRFailure (MalformedComp t) -- Always unfold: sawLet, multiArgFixM, invariantHint, Num_rec @@ -636,7 +623,7 @@ mrRefines t1 t2 = -- | The main implementation of 'mrRefines' mrRefines' :: NormComp -> NormComp -> MRM () -mrRefines' (ReturnM e1) (ReturnM e2) = mrAssertProveEq e1 e2 +mrRefines' (ReturnM e1) (ReturnM e2) = mrAssertProveRel True e1 e2 mrRefines' (ErrorM _) (ErrorM _) = return () mrRefines' (ReturnM e) (ErrorM _) = throwMRFailure (ReturnNotError e) mrRefines' (ErrorM _) (ReturnM e) = throwMRFailure (ReturnNotError e) diff --git a/src/SAWScript/Prover/MRSolver/Term.hs b/src/SAWScript/Prover/MRSolver/Term.hs index a498c61717..10f958f67b 100644 --- a/src/SAWScript/Prover/MRSolver/Term.hs +++ b/src/SAWScript/Prover/MRSolver/Term.hs @@ -177,12 +177,52 @@ isCompFunType sc t = scWhnf sc t >>= \case (asPiList -> (_, asCompM -> Just _)) -> return True _ -> return False + +---------------------------------------------------------------------- +-- * Useful 'Recognizer's for 'Term's +---------------------------------------------------------------------- + -- | Recognize a 'Term' as an application of `bvToNat` asBvToNat :: Recognizer Term (Term, Term) asBvToNat (asApplyAll -> ((isGlobalDef "Prelude.bvToNat" -> Just ()), [n, x])) = Just (n, x) asBvToNat _ = Nothing +-- | Recognize a term as a @Left@ or @Right@ +asEither :: Recognizer Term (Either Term Term) +asEither (asCtor -> Just (c, [_, _, x])) + | primName c == "Prelude.Left" = return $ Left x + | primName c == "Prelude.Right" = return $ Right x +asEither _ = Nothing + +-- | Recognize a term as a @TCNum n@ or @TCInf@ +asNum :: Recognizer Term (Either Term ()) +asNum (asCtor -> Just (c, [n])) + | primName c == "Cryptol.TCNum" = return $ Left n +asNum (asCtor -> Just (c, [])) + | primName c == "Cryptol.TCInf" = return $ Right () +asNum _ = Nothing + +-- | Recognize a term as being of the form @isFinite n@ +asIsFinite :: Recognizer Term Term +asIsFinite (asApp -> Just (isGlobalDef "CryptolM.isFinite" -> Just (), n)) = + Just n +asIsFinite _ = Nothing + +-- | Test if a 'Term' is a 'BVVec' type +asBVVecType :: Recognizer Term (Term, Term, Term) +asBVVecType (asApplyAll -> + (isGlobalDef "Prelude.Vec" -> Just _, + [(asApplyAll -> + (isGlobalDef "Prelude.bvToNat" -> Just _, [n, len])), a])) = + Just (n, len, a) +asBVVecType _ = Nothing + +-- | Like 'asVectorType', but returns 'Nothing' if 'asBVVecType' returns 'Just' +asNonBVVecVectorType :: Recognizer Term (Term, Term) +asNonBVVecVectorType (asBVVecType -> Just _) = Nothing +asNonBVVecVectorType t = asVectorType t + ---------------------------------------------------------------------- -- * Mr Solver Environments @@ -393,6 +433,14 @@ instance PrettyInCtx MRVar where instance PrettyInCtx [Term] where prettyInCtx xs = list <$> mapM prettyInCtx xs +instance PrettyInCtx a => PrettyInCtx (Maybe a) where + prettyInCtx (Just x) = (<+>) "Just" <$> prettyInCtx x + prettyInCtx Nothing = return "Nothing" + +instance (PrettyInCtx a, PrettyInCtx b) => PrettyInCtx (a,b) where + prettyInCtx (x, y) = (\x' y' -> parens (x' <> "," <> y')) <$> prettyInCtx x + <*> prettyInCtx y + instance PrettyInCtx TermProj where prettyInCtx TermProjLeft = return (pretty '.' <> "1") prettyInCtx TermProjRight = return (pretty '.' <> "2")