diff --git a/heapster-saw/examples/arrays_mr_solver.saw b/heapster-saw/examples/arrays_mr_solver.saw new file mode 100644 index 0000000000..0326bac2ec --- /dev/null +++ b/heapster-saw/examples/arrays_mr_solver.saw @@ -0,0 +1,3 @@ +include "arrays.saw"; +contains0 <- parse_core_mod "arrays" "contains0"; +mr_solver_debug 1 contains0 contains0; diff --git a/saw-core/src/Verifier/SAW/OpenTerm.hs b/saw-core/src/Verifier/SAW/OpenTerm.hs index d271153071..57c1fd7ad0 100644 --- a/saw-core/src/Verifier/SAW/OpenTerm.hs +++ b/saw-core/src/Verifier/SAW/OpenTerm.hs @@ -27,13 +27,14 @@ module Verifier.SAW.OpenTerm ( unitOpenTerm, unitTypeOpenTerm, stringLitOpenTerm, stringTypeOpenTerm, trueOpenTerm, falseOpenTerm, boolOpenTerm, boolTypeOpenTerm, - arrayValueOpenTerm, bvLitOpenTerm, bvTypeOpenTerm, + arrayValueOpenTerm, vectorTypeOpenTerm, bvLitOpenTerm, bvTypeOpenTerm, pairOpenTerm, pairTypeOpenTerm, pairLeftOpenTerm, pairRightOpenTerm, tupleOpenTerm, tupleTypeOpenTerm, projTupleOpenTerm, tupleOpenTerm', tupleTypeOpenTerm', recordOpenTerm, recordTypeOpenTerm, projRecordOpenTerm, ctorOpenTerm, dataTypeOpenTerm, globalOpenTerm, extCnsOpenTerm, - applyOpenTerm, applyOpenTermMulti, applyPiOpenTerm, piArgOpenTerm, + applyOpenTerm, applyOpenTermMulti, applyGlobalOpenTerm, + applyPiOpenTerm, piArgOpenTerm, lambdaOpenTerm, lambdaOpenTermMulti, piOpenTerm, piOpenTermMulti, arrowOpenTerm, letOpenTerm, sawLetOpenTerm, -- * Monadic operations for building terms with binders @@ -179,6 +180,10 @@ bvLitOpenTerm :: [Bool] -> OpenTerm bvLitOpenTerm bits = arrayValueOpenTerm boolTypeOpenTerm $ map boolOpenTerm bits +-- | Create a SAW core term for a vector type +vectorTypeOpenTerm :: OpenTerm -> OpenTerm -> OpenTerm +vectorTypeOpenTerm n a = applyGlobalOpenTerm "Prelude.Vec" [n,a] + -- | Create a SAW core term for the type of a bitvector bvTypeOpenTerm :: Integral a => a -> OpenTerm bvTypeOpenTerm n = @@ -287,6 +292,10 @@ applyOpenTerm (OpenTerm f) (OpenTerm arg) = applyOpenTermMulti :: OpenTerm -> [OpenTerm] -> OpenTerm applyOpenTermMulti = foldl applyOpenTerm +-- | Apply a named global to 0 or more arguments +applyGlobalOpenTerm :: Ident -> [OpenTerm] -> OpenTerm +applyGlobalOpenTerm ident = applyOpenTermMulti (globalOpenTerm ident) + -- | Compute the output type of applying a function of a given type to an -- argument. That is, given @tp@ and @arg@, compute the type of applying any @f@ -- of type @tp@ to @arg@. diff --git a/src/SAWScript/Builtins.hs b/src/SAWScript/Builtins.hs index f67330163c..22e4565920 100644 --- a/src/SAWScript/Builtins.hs +++ b/src/SAWScript/Builtins.hs @@ -1385,8 +1385,8 @@ tailPrim :: [a] -> TopLevel [a] tailPrim [] = fail "tail: empty list" tailPrim (_ : xs) = return xs -parseCore :: String -> TopLevel Term -parseCore input = +parseCoreMod :: String -> String -> TopLevel Term +parseCoreMod mnm_str input = do sc <- getSharedContext let base = "" path = "" @@ -1397,18 +1397,29 @@ parseCore input = do let msg = show err printOutLnTop Opts.Error msg fail msg - let mnm = Just $ mkModuleName ["Cryptol"] - err_or_t <- io $ runTCM (typeInferComplete uterm) sc mnm [] + let mnm = + mkModuleName $ Text.splitOn (Text.pack ".") $ Text.pack mnm_str + _ <- io $ scFindModule sc mnm -- Check that mnm exists + err_or_t <- io $ runTCM (typeInferComplete uterm) sc (Just mnm) [] case err_or_t of Left err -> fail (show err) Right (TC.TypedTerm x _) -> return x +parseCore :: String -> TopLevel Term +parseCore = parseCoreMod "Cryptol" + parse_core :: String -> TopLevel TypedTerm parse_core input = do t <- parseCore input sc <- getSharedContext io $ mkTypedTerm sc t +parse_core_mod :: String -> String -> TopLevel TypedTerm +parse_core_mod mnm input = do + t <- parseCoreMod mnm input + sc <- getSharedContext + io $ mkTypedTerm sc t + prove_core :: ProofScript () -> String -> TopLevel Theorem prove_core script input = do sc <- getSharedContext @@ -1540,16 +1551,18 @@ monadifyTypedTerm sc t = -- | Ensure that a 'TypedTerm' has been monadified ensureMonadicTerm :: SharedContext -> TypedTerm -> TopLevel TypedTerm -ensureMonadicTerm _ t - | TypedTermOther tp <- ttType t - , Prover.isCompFunType tp = return t +ensureMonadicTerm sc t + | TypedTermOther tp <- ttType t = + io (Prover.isCompFunType sc tp) >>= \case + True -> return t + False -> monadifyTypedTerm sc t ensureMonadicTerm sc t = monadifyTypedTerm sc t mrSolver :: SharedContext -> Int -> TypedTerm -> TypedTerm -> TopLevel Bool mrSolver sc dlvl t1 t2 = do m1 <- ttTerm <$> ensureMonadicTerm sc t1 m2 <- ttTerm <$> ensureMonadicTerm sc t2 - res <- liftIO $ Prover.askMRSolver sc dlvl SBV.z3 Nothing m1 m2 + res <- liftIO $ Prover.askMRSolver sc dlvl Nothing m1 m2 case res of Just err -> io (putStrLn $ Prover.showMRFailure err) >> return False Nothing -> return True diff --git a/src/SAWScript/Interpreter.hs b/src/SAWScript/Interpreter.hs index 6a5aba8d6f..6d5dd09139 100644 --- a/src/SAWScript/Interpreter.hs +++ b/src/SAWScript/Interpreter.hs @@ -2225,6 +2225,13 @@ primitives = Map.fromList [ "Parse a Term from a String in SAWCore syntax." ] + , prim "parse_core_mod" "String -> String -> Term" + (funVal2 parse_core_mod) + Current + [ "Parse a Term from the second supplied String in SAWCore syntax," + , "relative to the module specified by the first String" + ] + , prim "prove_core" "ProofScript () -> String -> TopLevel Theorem" (pureVal prove_core) Current diff --git a/src/SAWScript/Prover/MRSolver.hs b/src/SAWScript/Prover/MRSolver.hs index d29bfa0617..f828ee69a1 100644 --- a/src/SAWScript/Prover/MRSolver.hs +++ b/src/SAWScript/Prover/MRSolver.hs @@ -118,13 +118,11 @@ C |- F e1 ... en >>= k |= F' e1' ... em' >>= k': -} module SAWScript.Prover.MRSolver - (askMRSolver, MRFailure(..), showMRFailure, isCompFunType - , SBV.SMTConfig - , SBV.z3, SBV.cvc4, SBV.yices, SBV.mathSAT, SBV.boolector - ) where + (askMRSolver, MRFailure(..), showMRFailure, isCompFunType) where import Data.List (find, findIndex) import qualified Data.Text as T +import qualified Data.Vector as V import Data.IORef import System.IO (hPutStrLn, stderr) import Control.Monad.Reader @@ -135,6 +133,7 @@ import Control.Monad.Trans.Maybe import qualified Data.IntMap as IntMap import Data.Map (Map) import qualified Data.Map as Map +import qualified Data.Set as Set import Prettyprinter @@ -144,10 +143,19 @@ import Verifier.SAW.Term.Pretty import Verifier.SAW.SCTypeCheck import Verifier.SAW.SharedTerm import Verifier.SAW.Recognizer +import Verifier.SAW.OpenTerm import Verifier.SAW.Cryptol.Monadify -import SAWScript.Proof (termToProp) -import qualified SAWScript.Prover.SBV as SBV +import qualified Verifier.SAW.Prim as Prim +import Verifier.SAW.Simulator.TermModel +import Verifier.SAW.Simulator.Prims +import Verifier.SAW.Simulator.MonadLazy + +import SAWScript.Proof (termToProp, propToTerm, prettyProp) +import What4.Solver +import SAWScript.Prover.What4 + +-- import Debug.Trace ---------------------------------------------------------------------- @@ -193,18 +201,64 @@ newtype MRVar = MRVar { unMRVar :: ExtCns Term } deriving (Eq, Show, Ord) mrVarType :: MRVar -> Term mrVarType = ecType . unMRVar +-- | A tuple or record projection of a 'Term' +data TermProj = TermProjLeft | TermProjRight | TermProjRecord FieldName + deriving (Eq, Ord, Show) + +-- | Apply a 'TermProj' to perform a projection on a 'Term' +doTermProj :: Term -> TermProj -> MRM Term +doTermProj t TermProjLeft = liftSC1 scPairLeft t +doTermProj t TermProjRight = liftSC1 scPairRight t +doTermProj t (TermProjRecord fld) = liftSC2 scRecordSelect t fld + +-- | Apply a 'TermProj' to a type to get the output type of the projection, +-- assuming that the type is already normalized +doTypeProj :: Term -> TermProj -> MRM Term +doTypeProj (asPairType -> Just (tp1, _)) TermProjLeft = return tp1 +doTypeProj (asPairType -> Just (_, tp2)) TermProjRight = return tp2 +doTypeProj (asRecordType -> Just tp_map) (TermProjRecord fld) + | Just tp <- Map.lookup fld tp_map + = return tp +doTypeProj _ _ = + -- FIXME: better error message? This is an error and not an MRFailure because + -- we should only be projecting types for terms that we have already seen... + error "doTypeProj" + +-- | Recognize a 'Term' as 0 or more projections +asProjAll :: Term -> (Term, [TermProj]) +asProjAll (asRecordSelector -> Just ((asProjAll -> (t, projs)), fld)) = + (t, TermProjRecord fld:projs) +asProjAll (asPairSelector -> Just ((asProjAll -> (t, projs)), isRight)) + | isRight = (t, TermProjRight:projs) + | not isRight = (t, TermProjLeft:projs) +asProjAll t = (t, []) + -- | Names of functions to be used in computations, which are either names bound -- by letrec to for recursive calls to fixed-points, existential variables, or --- global named constants +-- (possibly projections of) of global named constants data FunName - = LetRecName MRVar | EVarFunName MRVar | GlobalName GlobalDef + = LetRecName MRVar | EVarFunName MRVar | GlobalName GlobalDef [TermProj] deriving (Eq, Ord, Show) --- | Get the type of a 'FunName' -funNameType :: FunName -> Term -funNameType (LetRecName var) = mrVarType var -funNameType (EVarFunName var) = mrVarType var -funNameType (GlobalName gd) = globalDefType gd +-- | Get and normalize the type of a 'FunName' +funNameType :: FunName -> MRM Term +funNameType (LetRecName var) = liftSC1 scWhnf $ mrVarType var +funNameType (EVarFunName var) = liftSC1 scWhnf $ mrVarType var +funNameType (GlobalName gd projs) = + liftSC1 scWhnf (globalDefType gd) >>= \gd_tp -> + foldM doTypeProj gd_tp projs + +-- | Recognize a 'Term' as (possibly a projection of) a global name +asTypedGlobalProj :: Recognizer Term (GlobalDef, [TermProj]) +asTypedGlobalProj (asProjAll -> ((asTypedGlobalDef -> Just glob), projs)) = + Just (glob, projs) +asTypedGlobalProj _ = Nothing + +-- | Recognize a 'Term' as (possibly a projection of) a global name +asGlobalFunName :: Recognizer Term FunName +asGlobalFunName (asTypedGlobalProj -> Just (glob, projs)) = + Just $ GlobalName glob projs +asGlobalFunName _ = Nothing -- | A term specifically known to be of type @sort i@ for some @i@ newtype Type = Type Term deriving Show @@ -215,6 +269,7 @@ data NormComp | ErrorM Term -- ^ A term @errorM a str@ | Ite Term Comp Comp -- ^ If-then-else computation | Either CompFun CompFun Term -- ^ A sum elimination + | MaybeElim Type Comp CompFun Term -- ^ A maybe elimination | OrM Comp Comp -- ^ an @orM@ computation | ExistsM Type CompFun -- ^ an @existsM@ computation | ForallM Type CompFun -- ^ a @forallM@ computation @@ -287,10 +342,16 @@ instance PrettyInCtx Type where instance PrettyInCtx MRVar where prettyInCtx (MRVar ec) = return $ ppName $ ecName ec +instance PrettyInCtx TermProj where + prettyInCtx TermProjLeft = return (pretty '.' <> "1") + prettyInCtx TermProjRight = return (pretty '.' <> "2") + prettyInCtx (TermProjRecord fld) = return (pretty '.' <> pretty fld) + instance PrettyInCtx FunName where prettyInCtx (LetRecName var) = prettyInCtx var prettyInCtx (EVarFunName var) = prettyInCtx var - prettyInCtx (GlobalName i) = return $ viaShow i + prettyInCtx (GlobalName g projs) = + foldM (\pp proj -> (pp <>) <$> prettyInCtx proj) (viaShow g) projs instance PrettyInCtx Comp where prettyInCtx (CompTerm t) = prettyInCtx t @@ -311,11 +372,16 @@ instance PrettyInCtx NormComp where prettyInCtx (ErrorM str) = prettyAppList [return "errorM", return "_", parens <$> prettyInCtx str] prettyInCtx (Ite cond t1 t2) = - prettyAppList [return "ite", return "_", prettyInCtx cond, + prettyAppList [return "ite", return "_", parens <$> prettyInCtx cond, parens <$> prettyInCtx t1, parens <$> prettyInCtx t2] prettyInCtx (Either f g eith) = prettyAppList [return "either", return "_", return "_", return "_", - prettyInCtx f, prettyInCtx g, prettyInCtx eith] + parens <$> prettyInCtx f, parens <$> prettyInCtx g, + parens <$> prettyInCtx eith] + prettyInCtx (MaybeElim tp m f mayb) = + prettyAppList [return "maybe", parens <$> prettyInCtx tp, + return (parens "CompM _"), parens <$> prettyInCtx m, + parens <$> prettyInCtx f, parens <$> prettyInCtx mayb] prettyInCtx (OrM t1 t2) = prettyAppList [return "orM", return "_", parens <$> prettyInCtx t1, parens <$> prettyInCtx t2] @@ -367,7 +433,11 @@ instance TermLike NormComp where Ite <$> liftTermLike n i cond <*> liftTermLike n i t1 <*> liftTermLike n i t2 liftTermLike n i (Either f g eith) = Either <$> liftTermLike n i f <*> liftTermLike n i g <*> liftTermLike n i eith - liftTermLike n i (OrM t1 t2) = OrM <$> liftTermLike n i t1 <*> liftTermLike n i t2 + liftTermLike n i (MaybeElim tp m f mayb) = + MaybeElim <$> liftTermLike n i tp <*> liftTermLike n i m + <*> liftTermLike n i f <*> liftTermLike n i mayb + liftTermLike n i (OrM t1 t2) = + OrM <$> liftTermLike n i t1 <*> liftTermLike n i t2 liftTermLike n i (ExistsM tp f) = ExistsM <$> liftTermLike n i tp <*> liftTermLike n i f liftTermLike n i (ForallM tp f) = @@ -383,6 +453,9 @@ instance TermLike NormComp where substTermLike n s (Either f g eith) = Either <$> substTermLike n s f <*> substTermLike n s g <*> substTermLike n s eith + substTermLike n s (MaybeElim tp m f mayb) = + MaybeElim <$> substTermLike n s tp <*> substTermLike n s m + <*> substTermLike n s f <*> substTermLike n s mayb substTermLike n s (OrM t1 t2) = OrM <$> substTermLike n s t1 <*> substTermLike n s t2 substTermLike n s (ExistsM tp f) = @@ -557,8 +630,6 @@ data FunAssump = FunAssump { data MRState = MRState { -- | Global shared context for building terms, etc. mrSC :: SharedContext, - -- | Global SMT configuration for the duration of the MR. Solver call - mrSMTConfig :: SBV.SMTConfig, -- | SMT timeout for SMT calls made by Mr. Solver mrSMTTimeout :: Maybe Integer, -- | The context of universal variables, which are free SAW core variables, in @@ -569,7 +640,8 @@ data MRState = MRState { mrVars :: MRVarMap, -- | The current assumptions of function refinement mrFunAssumps :: Map FunName FunAssump, - -- | The current assumptions, which are conjoined into a single Boolean term + -- | The current assumptions, which are conjoined into a single Boolean term; + -- note that these have the current UVars free mrAssumptions :: Term, -- | The debug level, which controls debug printing mrDebugLevel :: Int @@ -577,11 +649,11 @@ data MRState = MRState { -- | Build a default, empty state from SMT configuration parameters and a set of -- function refinement assumptions -mkMRState :: SharedContext -> Map FunName FunAssump -> SBV.SMTConfig -> +mkMRState :: SharedContext -> Map FunName FunAssump -> Maybe Integer -> Int -> IO MRState -mkMRState sc fun_assumps smt_config timeout dlvl = +mkMRState sc fun_assumps timeout dlvl = scBool sc True >>= \true_tm -> - return $ MRState { mrSC = sc, mrSMTConfig = smt_config, + return $ MRState { mrSC = sc, mrSMTTimeout = timeout, mrUVars = [], mrVars = Map.empty, mrFunAssumps = fun_assumps, mrAssumptions = true_tm, mrDebugLevel = dlvl } @@ -626,11 +698,9 @@ catchErrorEither m = catchError (Right <$> m) (return . Left) -- FIXME: replace these individual lifting functions with a more general -- typeclass like LiftTCM -{- -- | Lift a nullary SharedTerm computation into 'MRM' liftSC0 :: (SharedContext -> IO a) -> MRM a liftSC0 f = (mrSC <$> get) >>= \sc -> liftIO (f sc) --} -- | Lift a unary SharedTerm computation into 'MRM' liftSC1 :: (SharedContext -> a -> IO b) -> a -> MRM b @@ -649,6 +719,11 @@ liftSC4 :: (SharedContext -> a -> b -> c -> d -> IO e) -> a -> b -> c -> d -> MRM e liftSC4 f a b c d = (mrSC <$> get) >>= \sc -> liftIO (f sc a b c d) +-- | Lift a quinary SharedTerm computation into 'MRM' +liftSC5 :: (SharedContext -> a -> b -> c -> d -> e -> IO f) -> + a -> b -> c -> d -> e -> MRM f +liftSC5 f a b c d e = (mrSC <$> get) >>= \sc -> liftIO (f sc a b c d e) + -- | Apply a 'Term' to a list of arguments and beta-reduce in Mr. Monad mrApplyAll :: Term -> [Term] -> MRM Term mrApplyAll f args = liftSC2 scApplyAll f args >>= liftSC1 betaNormalize @@ -671,13 +746,18 @@ mrConvertible = liftSC4 scConvertibleEval scTypeCheckWHNF True -- compute the type @CompM [args/vars]a@ of @f@ applied to @args@. Return the -- type @[args/vars]a@ that @CompM@ is applied to. mrFunOutType :: FunName -> [Term] -> MRM Term -mrFunOutType ((asPiList . funNameType) -> (vars, asCompM -> Just tp)) args - | length vars == length args = - substTermLike 0 args tp -mrFunOutType _ _ = - -- NOTE: this is an error because we should only ever call mrFunOutType with a - -- well-formed application at a CompM type - error "mrFunOutType" +mrFunOutType fname args = + funNameType fname >>= \case + (asPiList -> (vars, asCompM -> Just tp)) + | length vars == length args -> substTermLike 0 args tp + ftype@(asPiList -> (vars, _)) -> + do pp_ftype <- mrPPInCtx ftype + pp_fname <- mrPPInCtx fname + debugPrint 0 "mrFunOutType: function applied to the wrong number of args" + debugPrint 0 ("Expected: " ++ show (length vars) ++ + ", found: " ++ show (length args)) + debugPretty 0 ("For function: " <> pp_fname <> " with type: " <> pp_ftype) + error"mrFunOutType" -- | Turn a 'LocalName' into one not in a list, adding a suffix if necessary uniquifyName :: LocalName -> [LocalName] -> LocalName @@ -689,14 +769,18 @@ uniquifyName nm nms = Nothing -> error "uniquifyName" -- | Run a MR Solver computation in a context extended with a universal --- variable, which is passed as a 'Term' to the sub-computation +-- variable, which is passed as a 'Term' to the sub-computation. Note that any +-- assumptions made in the sub-computation will be lost when it completes. withUVar :: LocalName -> Type -> (Term -> MRM a) -> MRM a withUVar nm tp m = do st <- get let nm' = uniquifyName nm (map fst $ mrUVars st) - put (st { mrUVars = (nm',tp) : mrUVars st }) + assumps' <- liftTerm 0 1 $ mrAssumptions st + put (st { mrUVars = (nm',tp) : mrUVars st, + mrAssumptions = assumps' }) ret <- mapFailure (MRFailureLocalVar nm') (liftSC1 scLocalVar 0 >>= m) - modify (\st' -> st' { mrUVars = mrUVars st }) + modify (\st' -> st' { mrUVars = mrUVars st, + mrAssumptions = mrAssumptions st }) return ret -- | Run a MR Solver computation in a context extended with a universal variable @@ -717,8 +801,10 @@ withUVars = helper [] where helper :: [Term] -> [(LocalName,Term)] -> ([Term] -> MRM a) -> MRM a helper vars [] m = m $ reverse vars helper vars ((nm,tp):ctx) m = + -- FIXME: I think substituting here is wrong, but works on closed terms, so + -- it's fine to use at the top level at least... substTerm 0 vars tp >>= \tp' -> - withUVar nm (Type tp') $ \var -> helper (var:vars) ctx m + withUVarLift nm (Type tp') vars $ \var vars' -> helper (var:vars') ctx m -- | Build 'Term's for all the uvars currently in scope, ordered from least to -- most recently bound @@ -737,6 +823,23 @@ lambdaUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scLambdaList ctx t piUVarsM :: Term -> MRM Term piUVarsM t = mrUVarCtx >>= \ctx -> liftSC2 scPiList ctx t +-- | Instantiate all uvars in a term using the supplied function +instantiateUVarsM :: TermLike a => (LocalName -> Term -> MRM Term) -> a -> MRM a +instantiateUVarsM f a = + do ctx <- mrUVarCtx + -- Remember: the uvar context is outermost to innermost, so we bind + -- variables from left to right, substituting earlier ones into the types + -- of later ones, but all substitutions are in reverse order, since + -- substTerm and friends like innermost bindings first + let helper :: [Term] -> [(LocalName,Term)] -> MRM [Term] + helper tms [] = return tms + helper tms ((nm,tp):vars) = + do tp' <- substTerm 0 tms tp + tm <- f nm tp' + helper (tm:tms) vars + ecs <- helper [] ctx + substTermLike 0 ecs a + -- | Convert an 'MRVar' to a 'Term', applying it to all the uvars in scope mrVarTerm :: MRVar -> MRM Term mrVarTerm (MRVar ec) = @@ -755,7 +858,7 @@ extCnsToFunName ec = let var = MRVar ec in mrVarInfo var >>= \case Just (FunVarInfo _) -> return $ LetRecName var Nothing | Just glob <- asTypedGlobalDef (Unshared $ FTermF $ ExtCns ec) -> - return $ GlobalName glob + return $ GlobalName glob [] _ -> error "extCnsToFunName: unreachable" -- | Get the body of a function @f@ if it has one @@ -764,7 +867,10 @@ mrFunNameBody (LetRecName var) = mrVarInfo var >>= \case Just (FunVarInfo body) -> return $ Just body _ -> error "mrFunBody: unknown letrec var" -mrFunNameBody (GlobalName glob) = return $ globalDefBody glob +mrFunNameBody (GlobalName glob projs) + | Just body <- globalDefBody glob + = Just <$> foldM doTermProj body projs +mrFunNameBody (GlobalName _ _) = return Nothing mrFunNameBody (EVarFunName _) = return Nothing -- | Get the body of a function @f@ applied to some arguments, if possible @@ -793,9 +899,9 @@ mrCallsFun f = memoFixTermFun $ \recurse t -> case t of _ | f == g -> return True Just body -> recurse body Nothing -> return False - (asTypedGlobalDef -> Just gdef) -> + (asTypedGlobalProj -> Just (gdef, projs)) -> case globalDefBody gdef of - _ | f == GlobalName gdef -> return True + _ | f == GlobalName gdef projs -> return True Just body -> recurse body Nothing -> return False (unwrapTermF -> tf) -> @@ -844,8 +950,11 @@ mrFreshEVars = helper [] where mrSetEVarClosed :: MRVar -> Term -> MRM () mrSetEVarClosed var val = do val_tp <- mrTypeOf val + -- NOTE: need to instantiate any evars in the type of var, to ensure the + -- following subtyping check will succeed + var_tp <- mrSubstEVars $ mrVarType var -- FIXME: catch subtyping errors and report them as being evar failures - liftSC3 scCheckSubtype Nothing (TypedTerm val val_tp) (mrVarType var) + liftSC3 scCheckSubtype Nothing (TypedTerm val val_tp) var_tp modify $ \st -> st { mrVars = Map.alter @@ -979,13 +1088,13 @@ debugPretty i pp = debugPrint i $ renderSawDoc defaultPPOpts pp -- | Pretty-print an object in the current context if the current debug level is -- at least the supplied 'Int' -_debugPrettyInCtx :: PrettyInCtx a => Int -> a -> MRM () -_debugPrettyInCtx i a = +debugPrettyInCtx :: PrettyInCtx a => Int -> a -> MRM () +debugPrettyInCtx i a = (mrUVars <$> get) >>= \ctx -> debugPrint i (showInCtx (map fst ctx) a) -- | Pretty-print an object relative to the current context -_mrPPInCtx :: PrettyInCtx a => a -> MRM SawDoc -_mrPPInCtx a = +mrPPInCtx :: PrettyInCtx a => a -> MRM SawDoc +mrPPInCtx a = runReader (prettyInCtx a) <$> map fst <$> mrUVars <$> get -- | Pretty-print the result of 'ppWithPrefixSep' relative to the current uvar @@ -1003,18 +1112,128 @@ mrDebugPPPrefixSep i pre a1 sp a2 = -- * Calling Out to SMT ---------------------------------------------------------------------- +-- | Test if a 'Term' is a 'BVVec' type +asBVVecType :: Recognizer Term (Term, Term, Term) +asBVVecType (asApplyAll -> + (isGlobalDef "Prelude.Vec" -> Just _, + [(asApplyAll -> + (isGlobalDef "Prelude.bvToNat" -> Just _, [n, len])), a])) = + Just (n, len, a) +asBVVecType _ = Nothing + +-- | Apply @genBVVec@ to arguments @n@, @len@, and @a@, along with a function of +-- type @Vec n Bool -> a@ +genBVVecTerm :: SharedContext -> Term -> Term -> Term -> Term -> IO Term +genBVVecTerm sc n_tm len_tm a_tm f_tm = + let n = closedOpenTerm n_tm + len = closedOpenTerm len_tm + a = closedOpenTerm a_tm + f = closedOpenTerm f_tm in + completeOpenTerm sc $ + applyOpenTermMulti (globalOpenTerm "Prelude.genBVVec") + [n, len, a, + lambdaOpenTerm "i" (vectorTypeOpenTerm n boolTypeOpenTerm) $ \i -> + lambdaOpenTerm "_" (applyGlobalOpenTerm "Prelude.is_bvult" [n, i, len]) $ \_ -> + applyOpenTerm f i] + +-- | Match a term of the form @genBVVec n len a (\ i _ -> e)@, i.e., where @e@ +-- does not have the proof variable (the underscore) free +asGenBVVecTerm :: Recognizer Term (Term, Term, Term, Term) +asGenBVVecTerm (asApplyAll -> + (isGlobalDef "Prelude.genBVVec" -> Just _, + [n, len, a, + (asLambdaList -> ([_,_], e))])) + | not $ inBitSet 0 $ looseVars e + = Just (n, len, a, e) +asGenBVVecTerm _ = Nothing + +type TmPrim = Prim TermModel + +-- | Convert a Boolean value to a 'Term'; like 'readBackValue' but that function +-- requires a 'SimulatorConfig' which we cannot easily generate here... +boolValToTerm :: SharedContext -> Value TermModel -> IO Term +boolValToTerm _ (VBool (Left tm)) = return tm +boolValToTerm sc (VBool (Right b)) = scBool sc b +boolValToTerm _ (VExtra (VExtraTerm _tp tm)) = return tm +boolValToTerm _ v = error ("boolValToTerm: unexpected value: " ++ show v) + +-- | An implementation of a primitive function that expects a @genBVVec@ term +primGenBVVec :: SharedContext -> (Term -> TmPrim) -> TmPrim +primGenBVVec sc f = + PrimFilterFun "genBVVecPrim" + (\case + VExtra (VExtraTerm _ (asGenBVVecTerm -> Just (n, _, _, e))) -> + -- Generate the function \i -> [i/1,error/0]e + lift $ + do i_tp <- scBoolType sc >>= scVecType sc n + let err_tm = error "primGenBVVec: unexpected variable occurrence" + i_tm <- scLocalVar sc 0 + body <- instantiateVarList sc 0 [err_tm,i_tm] e + scLambda sc "i" i_tp body + _ -> mzero) + f + +-- | An implementation of a primitive function that expects a bitvector term +primBVTermFun :: SharedContext -> (Term -> TmPrim) -> TmPrim +primBVTermFun sc = + PrimFilterFun "primBVTermFun" $ + \case + VExtra (VExtraTerm _ w_tm) -> return w_tm + VWord (Left (_,w_tm)) -> return w_tm + VWord (Right bv) -> + lift $ scBvConst sc (fromIntegral (Prim.width bv)) (Prim.unsigned bv) + VVector vs -> + lift $ + do tms <- traverse (boolValToTerm sc <=< force) (V.toList vs) + tp <- scBoolType sc + scVectorReduced sc tp tms + v -> lift (putStrLn ("primBVTermFun: unhandled value: " ++ show v)) >> mzero + +-- | Implementations of primitives for normalizing SMT terms +smtNormPrims :: SharedContext -> Map Ident TmPrim +smtNormPrims sc = Map.fromList + [ + ("Prelude.genBVVec", + Prim (do tp <- scTypeOfGlobal sc "Prelude.genBVVec" + VExtra <$> VExtraTerm (VTyTerm (mkSort 1) tp) <$> + scGlobalDef sc "Prelude.genBVVec")), + + ("Prelude.atBVVec", + PrimFun $ \_n -> PrimFun $ \_len -> tvalFun $ \a -> + primGenBVVec sc $ \f -> primBVTermFun sc $ \ix -> PrimFun $ \_pf -> + Prim (VExtra <$> VExtraTerm a <$> scApply sc f ix) + ) + ] + +-- | Normalize a 'Term' before building an SMT query for it +normSMTProp :: Term -> MRM Term +normSMTProp t = + debugPrint 2 "Normalizing term:" >> + debugPrettyInCtx 2 t >> + liftSC0 return >>= \sc -> + liftSC0 scGetModuleMap >>= \modmap -> + liftSC5 normalizeSharedTerm modmap (smtNormPrims sc) Map.empty Set.empty t + -- | Test if a closed Boolean term is "provable", i.e., its negation is -- unsatisfiable, using an SMT solver. By "closed" we mean that it contains no -- uvars or 'MRVar's. +-- +-- FIXME: use the timeout! mrProvableRaw :: Term -> MRM Bool mrProvableRaw prop_term = - do smt_conf <- mrSMTConfig <$> get - timeout <- mrSMTTimeout <$> get + do sc <- mrSC <$> get prop <- liftSC1 termToProp prop_term - (smt_res, _) <- liftSC4 SBV.proveUnintSBVIO smt_conf mempty timeout prop + unints <- Set.map ecVarIndex <$> getAllExtSet <$> liftSC1 propToTerm prop + debugPrint 2 ("Calling SMT solver with proposition: " ++ + prettyProp defaultPPOpts prop) + sym <- liftIO $ setupWhat4_sym True + (smt_res, _) <- + liftIO $ proveWhat4_solver z3Adapter sym unints sc prop (return ()) case smt_res of - Just _ -> return False - Nothing -> return True + Just _ -> + debugPrint 2 "SMT solver response: not provable" >> return False + Nothing -> + debugPrint 2 "SMT solver response: provable" >> return True -- | Test if a Boolean term over the current uvars is provable given the current -- assumptions @@ -1022,8 +1241,21 @@ mrProvable :: Term -> MRM Bool mrProvable bool_tm = do assumps <- mrAssumptions <$> get prop <- liftSC2 scImplies assumps bool_tm >>= liftSC1 scEqTrue - forall_prop <- piUVarsM prop - mrProvableRaw forall_prop + prop_inst <- flip instantiateUVarsM prop $ \nm tp -> + liftSC1 scWhnf tp >>= \case + (asBVVecType -> Just (n, len, a)) -> + -- For variables of type BVVec, create a Vec n Bool -> a function as an + -- ExtCns and apply genBVVec to it + do + ec_tp <- + liftSC1 completeOpenTerm $ + arrowOpenTerm "_" (applyOpenTermMulti (globalOpenTerm "Prelude.Vec") + [closedOpenTerm n, boolTypeOpenTerm]) + (closedOpenTerm a) + ec <- liftSC2 scFreshEC nm ec_tp >>= liftSC1 scExtCns + liftSC4 genBVVecTerm n len a ec + tp' -> liftSC2 scFreshEC nm tp' >>= liftSC1 scExtCns + normSMTProp prop_inst >>= mrProvableRaw -- | Build a Boolean 'Term' stating that two 'Term's are equal. This is like -- 'scEq' except that it works on open terms. @@ -1042,73 +1274,137 @@ mrEq' (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = liftSC3 scBvEq n t1 t2 mrEq' _ _ _ = error "mrEq': unsupported type" +-- | A 'Term' in an extended context of universal variables, which are listed +-- "outside in", meaning the highest deBruijn index comes first +data TermInCtx = TermInCtx [(LocalName,Term)] Term + +-- | Conjoin two 'TermInCtx's, assuming they both have Boolean type +andTermInCtx :: TermInCtx -> TermInCtx -> MRM TermInCtx +andTermInCtx (TermInCtx ctx1 t1) (TermInCtx ctx2 t2) = + do + -- Insert the variables in ctx2 into the context of t1 starting at index 0, + -- by lifting its variables starting at 0 by length ctx2 + t1' <- liftTermLike 0 (length ctx2) t1 + -- Insert the variables in ctx1 into the context of t1 starting at index + -- length ctx2, by lifting its variables starting at length ctx2 by length + -- ctx1 + t2' <- liftTermLike (length ctx2) (length ctx1) t2 + TermInCtx (ctx1++ctx2) <$> liftSC2 scAnd t1' t2' + +-- | Extend the context of a 'TermInCtx' with additional universal variables +-- bound "outside" the 'TermInCtx' +extTermInCtx :: [(LocalName,Term)] -> TermInCtx -> TermInCtx +extTermInCtx ctx (TermInCtx ctx' t) = TermInCtx (ctx++ctx') t + +-- | Run an 'MRM' computation in the context of a 'TermInCtx', passing in the +-- 'Term' +withTermInCtx :: TermInCtx -> (Term -> MRM a) -> MRM a +withTermInCtx (TermInCtx [] tm) f = f tm +withTermInCtx (TermInCtx ((nm,tp):ctx) tm) f = + withUVar nm (Type tp) $ const $ withTermInCtx (TermInCtx ctx tm) f + -- | A "simple" strategy for proving equality between two terms, which we assume --- are of the same type. This strategy first checks if either side is an --- uninstantiated evar, in which case it set that evar to the other side. If --- not, it builds an equality proposition by applying the supplied function to --- both sides, and passes this proposition to an SMT solver. -mrProveEqSimple :: (Term -> Term -> MRM Term) -> MRVarMap -> Term -> Term -> - MRM () +-- are of the same type, which builds an equality proposition by applying the +-- supplied function to both sides and passes this proposition to an SMT solver. +mrProveEqSimple :: (Term -> Term -> MRM Term) -> Term -> Term -> + MRM TermInCtx +-- NOTE: The use of mrSubstEVars instead of mrSubstEVarsStrict means that we +-- allow evars in the terms we send to the SMT solver, but we treat them as +-- uvars. +mrProveEqSimple eqf t1 t2 = + do t1' <- mrSubstEVars t1 + t2' <- mrSubstEVars t2 + TermInCtx [] <$> eqf t1' t2' + + +-- | Prove that two terms are equal, instantiating evars if necessary, or +-- throwing an error if this is not possible +mrProveEq :: Term -> Term -> MRM () +mrProveEq t1 t2 = + do mrDebugPPPrefixSep 1 "mrProveEq" t1 "==" t2 + tp <- mrTypeOf t1 + varmap <- mrVars <$> get + cond_in_ctx <- mrProveEqH varmap tp t1 t2 + success <- withTermInCtx cond_in_ctx mrProvable + if success then return () else + throwError (TermsNotEq t1 t2) + +-- | The main workhorse for 'prProveEq'. Build a Boolean term expressing that +-- the third and fourth arguments, whose type is given by the second. This is +-- done in a continuation monad so that the output term can be in a context with +-- additional universal variables. +mrProveEqH :: Map MRVar MRVarInfo -> Term -> Term -> Term -> MRM TermInCtx + +{- +mrProveEqH _ _ t1 t2 + | trace ("mrProveEqH:\n" ++ showTerm t1 ++ "\n==\n" ++ showTerm t2) False = undefined +-} -- If t1 is an instantiated evar, substitute and recurse -mrProveEqSimple eqf var_map (asEVarApp var_map -> Just (_, args, Just f)) t2 = - mrApplyAll f args >>= \t1' -> mrProveEqSimple eqf var_map t1' t2 +mrProveEqH var_map tp (asEVarApp var_map -> Just (_, args, Just f)) t2 = + mrApplyAll f args >>= \t1' -> mrProveEqH var_map tp t1' t2 -- If t1 is an uninstantiated evar, instantiate it with t2 -mrProveEqSimple _ var_map t1@(asEVarApp var_map -> - Just (evar, args, Nothing)) t2 = +mrProveEqH var_map _tp (asEVarApp var_map -> Just (evar, args, Nothing)) t2 = do t2' <- mrSubstEVars t2 success <- mrTrySetAppliedEVar evar args t2' - if success then return () else throwError (TermsNotEq t1 t2) + TermInCtx [] <$> liftSC1 scBool success -- If t2 is an instantiated evar, substitute and recurse -mrProveEqSimple eqf var_map t1 (asEVarApp var_map -> Just (_, args, Just f)) = - mrApplyAll f args >>= \t2' -> mrProveEqSimple eqf var_map t1 t2' +mrProveEqH var_map tp t1 (asEVarApp var_map -> Just (_, args, Just f)) = + mrApplyAll f args >>= \t2' -> mrProveEqH var_map tp t1 t2' -- If t2 is an uninstantiated evar, instantiate it with t1 -mrProveEqSimple _ var_map t1 t2@(asEVarApp var_map -> - Just (evar, args, Nothing)) = +mrProveEqH var_map _tp t1 (asEVarApp var_map -> Just (evar, args, Nothing)) = do t1' <- mrSubstEVars t1 success <- mrTrySetAppliedEVar evar args t1' - if success then return () else throwError (TermsNotEq t1 t2) - --- Otherwise, try to prove both sides are equal. The use of mrSubstEVars instead --- of mrSubstEVarsStrict means that we allow evars in the terms we send to the --- SMT solver, but we treat them as uvars. -mrProveEqSimple eqf _ t1 t2 = - do t1' <- mrSubstEVars t1 - t2' <- mrSubstEVars t2 - prop <- eqf t1' t2' - success <- mrProvable prop - if success then return () else - throwError (TermsNotEq t1 t2) - - --- | Prove that two terms are equal, instantiating evars if necessary, or --- throwing an error if this is not possible -mrProveEq :: Term -> Term -> MRM () -mrProveEq t1_top t2_top = - (do mrDebugPPPrefixSep 1 "mrProveEq" t1_top "==" t2_top - tp <- mrTypeOf t1_top - varmap <- mrVars <$> get - proveEq varmap tp t1_top t2_top) - where - proveEq :: Map MRVar MRVarInfo -> Term -> Term -> Term -> MRM () - proveEq var_map (asDataType -> Just (pn, [])) t1 t2 - | primName pn == "Prelude.Nat" = - mrProveEqSimple (liftSC2 scEqualNat) var_map t1 t2 - proveEq var_map (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = - -- FIXME: make a better solver for bitvector equalities - mrProveEqSimple (liftSC3 scBvEq n) var_map t1 t2 - proveEq var_map (asBoolType -> Just _) t1 t2 = - mrProveEqSimple (liftSC2 scBoolEq) var_map t1 t2 - proveEq var_map (asIntegerType -> Just _) t1 t2 = - mrProveEqSimple (liftSC2 scIntEq) var_map t1 t2 - proveEq _ _ t1 t2 = - -- As a fallback, for types we can't handle, just check convertibility - mrConvertible t1 t2 >>= \case - True -> return () - False -> throwError (TermsNotEq t1 t2) + TermInCtx [] <$> liftSC1 scBool success + +-- For the nat, bitvector, Boolean, and integer types, call mrProveEqSimple +mrProveEqH _ (asDataType -> Just (pn, [])) t1 t2 + | primName pn == "Prelude.Nat" = + mrProveEqSimple (liftSC2 scEqualNat) t1 t2 +mrProveEqH _ (asVectorType -> Just (n, asBoolType -> Just ())) t1 t2 = + -- FIXME: make a better solver for bitvector equalities + mrProveEqSimple (liftSC3 scBvEq n) t1 t2 +mrProveEqH _ (asBoolType -> Just _) t1 t2 = + mrProveEqSimple (liftSC2 scBoolEq) t1 t2 +mrProveEqH _ (asIntegerType -> Just _) t1 t2 = + mrProveEqSimple (liftSC2 scIntEq) t1 t2 + +-- For pair types, prove both the left and right projections are equal +mrProveEqH var_map (asPairType -> Just (tpL, tpR)) t1 t2 = + do t1L <- liftSC1 scPairLeft t1 + t2L <- liftSC1 scPairLeft t2 + t1R <- liftSC1 scPairRight t1 + t2R <- liftSC1 scPairRight t2 + condL <- mrProveEqH var_map tpL t1L t2L + condR <- mrProveEqH var_map tpR t1R t2R + andTermInCtx condL condR + +-- For non-bitvector vector types, prove all projections are equal by +-- quantifying over a universal index variable and proving equality at that +-- index +mrProveEqH _ (asBVVecType -> Just (n, len, tp)) t1 t2 = + liftSC0 scBoolType >>= \bool_tp -> + liftSC2 scVecType n bool_tp >>= \ix_tp -> + withUVarLift "eq_ix" (Type ix_tp) (n,(len,(tp,(t1,t2)))) $ + \ix' (n',(len',(tp',(t1',t2')))) -> + liftSC2 scGlobalApply "Prelude.is_bvult" [n', ix', len'] >>= \pf_tp -> + withUVarLift "eq_pf" (Type pf_tp) (n',(len',(tp',(ix',(t1',t2'))))) $ + \pf'' (n'',(len'',(tp'',(ix'',(t1'',t2''))))) -> + do t1_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'', + t1'', ix'', pf''] + t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'', + t2'', ix'', pf''] + var_map <- mrVars <$> get + extTermInCtx [("eq_ix",ix_tp),("eq_pf",pf_tp)] <$> + mrProveEqH var_map tp'' t1_prj t2_prj + +-- As a fallback, for types we can't handle, just check convertibility +mrProveEqH _ _ t1 t2 = + do success <- mrConvertible t1 t2 + TermInCtx [] <$> liftSC1 scBool success ---------------------------------------------------------------------- @@ -1121,10 +1417,11 @@ asCompM (asApp -> Just (isGlobalDef "Prelude.CompM" -> Just (), tp)) = return tp asCompM _ = fail "not a CompM type!" --- | Test if a type is a monadic function type of 0 or more arguments -isCompFunType :: Term -> Bool -isCompFunType (asPiList -> (_, asCompM -> Just _)) = True -isCompFunType _ = False +-- | Test if a type normalizes to a monadic function type of 0 or more arguments +isCompFunType :: SharedContext -> Term -> IO Bool +isCompFunType sc t = scWhnf sc t >>= \case + (asPiList -> (_, asCompM -> Just _)) -> return True + _ -> return False -- | Pattern-match on a @LetRecTypes@ list in normal form and return a list of -- the types it specifies, each in normal form and with uvars abstracted out @@ -1144,6 +1441,48 @@ asNestedPairs (asPairValue -> Just (x, asNestedPairs -> Just xs)) = Just (x:xs) asNestedPairs (asFTermF -> Just UnitValue) = Just [] asNestedPairs _ = Nothing +-- | Syntactically project then @i@th element of the body of a lambda. That is, +-- assuming the input 'Term' has the form +-- +-- > \ (x1:T1) ... (xn:Tn) -> (e1, (e2, ... (en, ()))) +-- +-- return the bindings @x1:T1,...,xn:Tn@ and @ei@ +synProjFunBody :: Int -> Term -> Maybe ([(LocalName, Term)], Term) +synProjFunBody i (asLambdaList -> (vars, asTupleValue -> Just es)) = + -- NOTE: we are doing 1-based indexing instead of 0-based, thus the -1 + Just $ (vars, es !! (i-1)) +synProjFunBody _ _ = Nothing + +-- | Bind fresh function variables for a @letRecM@ or @multiFixM@ with the given +-- @LetRecTypes@ and definitions for the function bodies as a lambda +mrFreshLetRecVars :: Term -> Term -> MRM [Term] +mrFreshLetRecVars lrts defs_f = + do + -- First, make fresh function constants for all the bound functions, using + -- the names bound by defs_f and just "F" if those run out + let fun_var_names = + map fst (fst $ asLambdaList defs_f) ++ repeat "F" + fun_tps <- asLRTList lrts + funs <- zipWithM mrFreshVar fun_var_names fun_tps + fun_tms <- mapM mrVarTerm funs + + -- Next, apply the definition function defs_f to our function vars, yielding + -- the definitions of the individual letrec-bound functions in terms of the + -- new function constants + defs_tm <- mrApplyAll defs_f fun_tms + defs <- case asNestedPairs defs_tm of + Just defs -> return defs + Nothing -> throwError (MalformedDefsFun defs_f) + + -- Remember the body associated with each fresh function constant + zipWithM_ (\f body -> + lambdaUVarsM body >>= \cl_body -> + mrSetVarInfo f (FunVarInfo cl_body)) funs defs + + -- Finally, return the terms for the fresh function variables + return fun_tms + + -- | Normalize a 'Term' of monadic type to monadic normal form normCompTerm :: Term -> MRM NormComp normCompTerm = normComp . CompTerm @@ -1170,6 +1509,8 @@ normComp (CompTerm t) = return $ Ite cond (CompTerm then_tm) (CompTerm else_tm) (isGlobalDef "Prelude.either" -> Just (), [_, _, _, f, g, eith]) -> return $ Either (CompFunTerm f) (CompFunTerm g) eith + (isGlobalDef "Prelude.maybe" -> Just (), [tp, _, m, f, mayb]) -> + return $ MaybeElim (Type tp) (CompTerm m) (CompFunTerm f) mayb (isGlobalDef "Prelude.orM" -> Just (), [_, m1, m2]) -> return $ OrM (CompTerm m1) (CompTerm m2) (isGlobalDef "Prelude.existsM" -> Just (), [tp, _, body_tm]) -> @@ -1178,28 +1519,9 @@ normComp (CompTerm t) = return $ ForallM (Type tp) (CompFunTerm body_tm) (isGlobalDef "Prelude.letRecM" -> Just (), [lrts, _, defs_f, body_f]) -> do - -- First, make fresh function constants for all the bound functions, - -- using the names bound by body_f and just "F" if those run out - let fun_var_names = - map fst (fst $ asLambdaList body_f) ++ repeat "F" - fun_tps <- asLRTList lrts - funs <- zipWithM mrFreshVar fun_var_names fun_tps - fun_tms <- mapM mrVarTerm funs - - -- Next, apply the definition function defs_f to our function vars, - -- yielding the definitions of the individual letrec-bound functions in - -- terms of the new function constants - defs_tm <- mrApplyAll defs_f fun_tms - defs <- case asNestedPairs defs_tm of - Just defs -> return defs - Nothing -> throwError (MalformedDefsFun defs_f) - - -- Remember the body associated with each fresh function constant - zipWithM_ (\f body -> - lambdaUVarsM body >>= \cl_body -> - mrSetVarInfo f (FunVarInfo cl_body)) funs defs - - -- Finally, apply the body function to our function vars and recursively + -- Bind fresh function vars for the letrec-bound functions + fun_tms <- mrFreshLetRecVars lrts defs_f + -- Apply the body function to our function vars and recursively -- normalize the resulting computation body_tm <- mrApplyAll body_f fun_tms normComp (CompTerm body_tm) @@ -1214,14 +1536,31 @@ normComp (CompTerm t) = mrApplyAll body args >>= normCompTerm -} + -- Recognize (multiFixM lrts (\ f1 ... fn -> (body1, ..., bodyn))).i args + (asTupleSelector -> + Just (asApplyAll -> (isGlobalDef "Prelude.multiFixM" -> Just (), + [lrts, defs_f]), + i), args) + -- Extract out the function \f1 ... fn -> bodyi + | Just (vars, body_i) <- synProjFunBody i defs_f -> + do + -- Bind fresh function variables for the functions f1 ... fn + fun_tms <- mrFreshLetRecVars lrts defs_f + -- Re-abstract the body + body_f <- liftSC2 scLambdaList vars body_i + -- Apply body_f to f1 ... fn and the top-level arguments + body_tm <- mrApplyAll body_f (fun_tms ++ args) + normComp (CompTerm body_tm) + + -- For an ExtCns, we have to check what sort of variable it is -- FIXME: substitute for evars if they have been instantiated ((asExtCns -> Just ec), args) -> do fun_name <- extCnsToFunName ec return $ FunBind fun_name args CompFunReturn - ((asTypedGlobalDef -> Just gdef), args) -> - return $ FunBind (GlobalName gdef) args CompFunReturn + ((asGlobalFunName -> Just f), args) -> + return $ FunBind f args CompFunReturn _ -> throwError (MalformedComp t) @@ -1234,6 +1573,8 @@ normBind (Ite cond comp1 comp2) k = return $ Ite cond (CompBind comp1 k) (CompBind comp2 k) normBind (Either f g t) k = return $ Either (compFunComp f k) (compFunComp g k) t +normBind (MaybeElim tp m f t) k = + return $ MaybeElim tp (CompBind m k) (compFunComp f k) t normBind (OrM comp1 comp2) k = return $ OrM (CompBind comp1 k) (CompBind comp2 k) normBind (ExistsM tp f) k = return $ ExistsM tp (compFunComp f k) @@ -1320,6 +1661,26 @@ mrRefines' (ReturnM e1) (ReturnM e2) = mrProveEq e1 e2 mrRefines' (ErrorM _) (ErrorM _) = return () mrRefines' (ReturnM e) (ErrorM _) = throwError (ReturnNotError e) mrRefines' (ErrorM _) (ReturnM e) = throwError (ReturnNotError e) +mrRefines' (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m1 f1 _) m2 = + do cond <- mrEq' tp e1 e2 + not_cond <- liftSC1 scNot cond + cond_pf <- + liftSC1 scEqTrue cond >>= piUVarsM >>= mrFreshVar "pf" >>= mrVarTerm + m1' <- applyNormCompFun f1 cond_pf + cond_holds <- mrProvable cond + if cond_holds then mrRefines m1' m2 else + withAssumption cond (mrRefines m1' m2) >> + withAssumption not_cond (mrRefines m1 m2) +mrRefines' m1 (MaybeElim (Type (asEq -> Just (tp,e1,e2))) m2 f2 _) = + do cond <- mrEq' tp e1 e2 + not_cond <- liftSC1 scNot cond + cond_pf <- + liftSC1 scEqTrue cond >>= piUVarsM >>= mrFreshVar "pf" >>= mrVarTerm + m2' <- applyNormCompFun f2 cond_pf + cond_holds <- mrProvable cond + if cond_holds then mrRefines m1 m2' else + withAssumption cond (mrRefines m1 m2') >> + withAssumption not_cond (mrRefines m1 m2) mrRefines' (Ite cond1 m1 m1') m2_all@(Ite cond2 m2 m2') = liftSC1 scNot cond1 >>= \not_cond1 -> (mrEq cond1 cond2 >>= mrProvable) >>= \case @@ -1526,14 +1887,13 @@ mrRefinesFun _ _ = error "mrRefinesFun: unreachable!" askMRSolver :: SharedContext -> Int {- ^ The debug level -} -> - SBV.SMTConfig {- ^ SBV configuration -} -> Maybe Integer {- ^ Timeout in milliseconds for each SMT call -} -> Term -> Term -> IO (Maybe MRFailure) -askMRSolver sc dlvl smt_conf timeout t1 t2 = - do tp1 <- scTypeOf sc t1 - tp2 <- scTypeOf sc t2 - init_st <- mkMRState sc Map.empty smt_conf timeout dlvl +askMRSolver sc dlvl timeout t1 t2 = + do tp1 <- scTypeOf sc t1 >>= scWhnf sc + tp2 <- scTypeOf sc t2 >>= scWhnf sc + init_st <- mkMRState sc Map.empty timeout dlvl case asPiList tp1 of (uvar_ctx, asCompM -> Just _) -> fmap (either Just (const Nothing)) $ runMRM init_st $