diff --git a/src/SAWScript/Prover/MRSolver/Monad.hs b/src/SAWScript/Prover/MRSolver/Monad.hs index b89e3ae1c1..fd4ce458c2 100644 --- a/src/SAWScript/Prover/MRSolver/Monad.hs +++ b/src/SAWScript/Prover/MRSolver/Monad.hs @@ -1,6 +1,7 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} @@ -65,6 +66,8 @@ data MRFailure | MalformedDefsFun Term | MalformedComp Term | NotCompFunType Term + | CoIndHypMismatchWidened FunName FunName CoIndHyp + | CoIndHypMismatchFailure (NormComp, NormComp) (NormComp, NormComp) -- | A local variable binding | MRFailureLocalVar LocalName MRFailure -- | Information about the context of the failure @@ -121,6 +124,13 @@ instance PrettyInCtx MRFailure where ppWithPrefix "Could not handle computation:" t prettyInCtx (NotCompFunType tp) = ppWithPrefix "Not a computation or computational function type:" tp + prettyInCtx (CoIndHypMismatchWidened nm1 nm2 _) = + ppWithPrefixSep "[Internal] Trying to widen co-inductive hypothesis on:" nm1 "," nm2 + prettyInCtx (CoIndHypMismatchFailure (tm1, tm2) (tm1', tm2')) = + do pp <- ppWithPrefixSep "" tm1 "|=" tm2 + pp' <- ppWithPrefixSep "" tm1' "|=" tm2' + return $ "Could not match co-inductive hypothesis:" <> pp' <> line <> + "with goal:" <> pp prettyInCtx (MRFailureLocalVar x err) = local (x:) $ prettyInCtx err prettyInCtx (MRFailureCtx ctx err) = @@ -163,6 +173,27 @@ asEVarApp var_map (asExtCnsApp -> Just (ec, args)) Just (MRVar ec, args, maybe_inst) asEVarApp _ _ = Nothing +-- | A co-inductive hypothesis of the form: +-- +-- > forall x1, ..., xn. F y1 ... ym |= G z1 ... zl +-- +-- for some universal context @x1:T1, ..., xn:Tn@ and some lists of argument +-- expressions @y1, ..., ym@ and @z1, ..., zl@ over the universal context. +data CoIndHyp = CoIndHyp { + -- | The uvars that were in scope when this assmption was created, in order + -- from outermost to innermost; that is, the uvars as "seen from outside their + -- scope", which is the reverse of the order of 'mrUVars', below + coIndHypCtx :: [(LocalName,Term)], + -- | The LHS argument expressions @y1, ..., ym@ over the 'coIndHypCtx' uvars + coIndHypLHS :: [Term], + -- | The RHS argument expressions @y1, ..., ym@ over the 'coIndHypCtx' uvars + coIndHypRHS :: [Term] +} deriving Show + +-- | A map from pairs of function names to co-inductive hypotheses over those +-- names +type CoIndHyps = Map (FunName, FunName) CoIndHyp + -- | An assumption that a named function refines some specificaiton. This has -- the form -- @@ -179,44 +210,48 @@ data FunAssump = FunAssump { -- | The argument expressions @e1, ..., en@ over the 'fassumpCtx' uvars fassumpArgs :: [Term], -- | The right-hand side upper bound @m@ over the 'fassumpCtx' uvars - fassumpRHS :: NormComp } + fassumpRHS :: NormComp +} --- | State maintained by MR. Solver -data MRState = MRState { +-- | A map from function names to function refinement assumptions over that +-- name +type FunAssumps = Map FunName FunAssump + +-- | Parameters and locals for MR. Solver +data MRInfo = MRInfo { -- | Global shared context for building terms, etc. - mrSC :: SharedContext, + mriSC :: SharedContext, -- | SMT timeout for SMT calls made by Mr. Solver - mrSMTTimeout :: Maybe Integer, - -- | The context of universal variables, which are free SAW core variables, in - -- order from innermost to outermost, i.e., where element @0@ corresponds to - -- deBruijn index @0@ - mrUVars :: [(LocalName,Type)], - -- | The existential and letrec-bound variables - mrVars :: MRVarMap, - -- | The current assumptions of function refinement - mrFunAssumps :: Map FunName FunAssump, + mriSMTTimeout :: Maybe Integer, + -- | The current context of universal variables, which are free SAW core + -- variables, in order from innermost to outermost, i.e., where element @0@ + -- corresponds to deBruijn index @0@ + mriUVars :: [(LocalName,Type)], + -- | The set of function refinements to be assumed by to Mr. Solver + mriFunAssumps :: FunAssumps, + -- | The current set of co-inductive hypotheses + mriCoIndHyps :: CoIndHyps, -- | The current assumptions, which are conjoined into a single Boolean term; -- note that these have the current UVars free - mrAssumptions :: Term, + mriAssumptions :: Term, -- | The debug level, which controls debug printing - mrDebugLevel :: Int + mriDebugLevel :: Int } --- | Build a default, empty state from SMT configuration parameters and a set of --- function refinement assumptions -mkMRState :: SharedContext -> Map FunName FunAssump -> - Maybe Integer -> Int -> IO MRState -mkMRState sc fun_assumps timeout dlvl = - scBool sc True >>= \true_tm -> - return $ MRState { mrSC = sc, - mrSMTTimeout = timeout, mrUVars = [], mrVars = Map.empty, - mrFunAssumps = fun_assumps, mrAssumptions = true_tm, - mrDebugLevel = dlvl } - --- | Mr. Monad, the monad used by MR. Solver, which is the state-exception monad -newtype MRM a = MRM { unMRM :: StateT MRState (ExceptT MRFailure IO) a } +-- | State maintained by MR. Solver +data MRState = MRState { + -- | The existential and letrec-bound variables + mrsVars :: MRVarMap +} + +-- | Mr. Monad, the monad used by MR. Solver, which has 'MRInfo' as as a +-- shared environment, 'MRState' as state, and 'MRFailure' as an exception +-- type, all over an 'IO' monad +newtype MRM a = MRM { unMRM :: ReaderT MRInfo (StateT MRState + (ExceptT MRFailure IO)) a } deriving (Functor, Applicative, Monad, MonadIO, - MonadState MRState, MonadError MRFailure) + MonadReader MRInfo, MonadState MRState, + MonadError MRFailure) instance MonadTerm MRM where mkTermF = liftSC1 scTermF @@ -224,9 +259,49 @@ instance MonadTerm MRM where whnfTerm = liftSC1 scWhnf substTerm = liftSC3 instantiateVarList +-- | Get the current value of 'mriSC' +mrSC :: MRM SharedContext +mrSC = mriSC <$> ask + +-- | Get the current value of 'mrSMTTimeout' +mrSMTTimeout :: MRM (Maybe Integer) +mrSMTTimeout = mriSMTTimeout <$> ask + +-- | Get the current value of 'mrUVars' +mrUVars :: MRM [(LocalName,Type)] +mrUVars = mriUVars <$> ask + +-- | Get the current value of 'mrFunAssumps' +mrFunAssumps :: MRM FunAssumps +mrFunAssumps = mriFunAssumps <$> ask + +-- | Get the current value of 'mrCoIndHyps' +mrCoIndHyps :: MRM CoIndHyps +mrCoIndHyps = mriCoIndHyps <$> ask + +-- | Get the current value of 'mrAssumptions' +mrAssumptions :: MRM Term +mrAssumptions = mriAssumptions <$> ask + +-- | Get the current value of 'mrDebugLevel' +mrDebugLevel :: MRM Int +mrDebugLevel = mriDebugLevel <$> ask + +-- | Get the current value of 'mrVars' +mrVars :: MRM MRVarMap +mrVars = mrsVars <$> get + -- | Run an 'MRM' computation and return a result or an error -runMRM :: MRState -> MRM a -> IO (Either MRFailure a) -runMRM init_st m = runExceptT $ flip evalStateT init_st $ unMRM m +runMRM :: SharedContext -> Maybe Integer -> Int -> FunAssumps -> + MRM a -> IO (Either MRFailure a) +runMRM sc timeout debug assumps m = + do true_tm <- scBool sc True + let init_info = MRInfo { mriSC = sc, mriSMTTimeout = timeout, + mriDebugLevel = debug, mriFunAssumps = assumps, + mriUVars = [], mriCoIndHyps = Map.empty, + mriAssumptions = true_tm } + let init_st = MRState { mrsVars = Map.empty } + runExceptT $ flip evalStateT init_st $ flip runReaderT init_info $ unMRM m -- | Apply a function to any failure thrown by an 'MRM' computation mapFailure :: (MRFailure -> MRFailure) -> MRM a -> MRM a @@ -255,29 +330,29 @@ catchErrorEither m = catchError (Right <$> m) (return . Left) -- | Lift a nullary SharedTerm computation into 'MRM' liftSC0 :: (SharedContext -> IO a) -> MRM a -liftSC0 f = (mrSC <$> get) >>= \sc -> liftIO (f sc) +liftSC0 f = mrSC >>= \sc -> liftIO (f sc) -- | Lift a unary SharedTerm computation into 'MRM' liftSC1 :: (SharedContext -> a -> IO b) -> a -> MRM b -liftSC1 f a = (mrSC <$> get) >>= \sc -> liftIO (f sc a) +liftSC1 f a = mrSC >>= \sc -> liftIO (f sc a) -- | Lift a binary SharedTerm computation into 'MRM' liftSC2 :: (SharedContext -> a -> b -> IO c) -> a -> b -> MRM c -liftSC2 f a b = (mrSC <$> get) >>= \sc -> liftIO (f sc a b) +liftSC2 f a b = mrSC >>= \sc -> liftIO (f sc a b) -- | Lift a ternary SharedTerm computation into 'MRM' liftSC3 :: (SharedContext -> a -> b -> c -> IO d) -> a -> b -> c -> MRM d -liftSC3 f a b c = (mrSC <$> get) >>= \sc -> liftIO (f sc a b c) +liftSC3 f a b c = mrSC >>= \sc -> liftIO (f sc a b c) -- | Lift a quaternary SharedTerm computation into 'MRM' liftSC4 :: (SharedContext -> a -> b -> c -> d -> IO e) -> a -> b -> c -> d -> MRM e -liftSC4 f a b c d = (mrSC <$> get) >>= \sc -> liftIO (f sc a b c d) +liftSC4 f a b c d = mrSC >>= \sc -> liftIO (f sc a b c d) -- | Lift a quinary SharedTerm computation into 'MRM' liftSC5 :: (SharedContext -> a -> b -> c -> d -> e -> IO f) -> a -> b -> c -> d -> e -> MRM f -liftSC5 f a b c d e = (mrSC <$> get) >>= \sc -> liftIO (f sc a b c d e) +liftSC5 f a b c d e = mrSC >>= \sc -> liftIO (f sc a b c d e) ---------------------------------------------------------------------- @@ -319,7 +394,7 @@ mrApplyAll f args = liftSC2 scApplyAll f args >>= liftSC1 betaNormalize -- types as SAW core 'Term's, with the least recently bound uvar first, i.e., in -- the order as seen "from the outside" mrUVarCtx :: MRM [(LocalName,Term)] -mrUVarCtx = reverse <$> map (\(nm,Type tp) -> (nm,tp)) <$> mrUVars <$> get +mrUVarCtx = reverse <$> map (\(nm,Type tp) -> (nm,tp)) <$> mrUVars -- | Get the type of a 'Term' in the current uvar context mrTypeOf :: Term -> MRM Term @@ -360,15 +435,11 @@ uniquifyName nm nms = -- assumptions made in the sub-computation will be lost when it completes. withUVar :: LocalName -> Type -> (Term -> MRM a) -> MRM a withUVar nm tp m = - do st <- get - let nm' = uniquifyName nm (map fst $ mrUVars st) - assumps' <- liftTerm 0 1 $ mrAssumptions st - put (st { mrUVars = (nm',tp) : mrUVars st, - mrAssumptions = assumps' }) - ret <- mapFailure (MRFailureLocalVar nm') (liftSC1 scLocalVar 0 >>= m) - modify (\st' -> st' { mrUVars = mrUVars st, - mrAssumptions = mrAssumptions st }) - return ret + do nm' <- uniquifyName nm <$> map fst <$> mrUVars + assumps' <- mrAssumptions >>= liftTerm 0 1 + local (\info -> info { mriUVars = (nm',tp) : mriUVars info, + mriAssumptions = assumps' }) $ + mapFailure (MRFailureLocalVar nm') (liftSC1 scLocalVar 0 >>= m) -- | Run a MR Solver computation in a context extended with a universal variable -- and pass it the lifting (in the sense of 'incVars') of an MR Solver term @@ -397,7 +468,7 @@ withUVars = helper [] where -- most recently bound getAllUVarTerms :: MRM [Term] getAllUVarTerms = - (length <$> mrUVars <$> get) >>= \len -> + (length <$> mrUVars) >>= \len -> mapM (liftSC1 scLocalVar) [len-1, len-2 .. 0] -- | Lambda-abstract all the current uvars out of a 'Term', with the least @@ -436,7 +507,7 @@ mrVarTerm (MRVar ec) = -- | Get the 'VarInfo' associated with a 'MRVar' mrVarInfo :: MRVar -> MRM (Maybe MRVarInfo) -mrVarInfo var = Map.lookup var <$> mrVars <$> get +mrVarInfo var = Map.lookup var <$> mrVars -- | Convert an 'ExtCns' to a 'FunName' extCnsToFunName :: ExtCns Term -> MRM FunName @@ -503,11 +574,11 @@ mrFreshVar nm tp = MRVar <$> liftSC2 scFreshEC nm tp mrSetVarInfo :: MRVar -> MRVarInfo -> MRM () mrSetVarInfo var info = modify $ \st -> - st { mrVars = + st { mrsVars = Map.alter (\case Just _ -> error "mrSetVarInfo" Nothing -> Just info) - var (mrVars st) } + var (mrsVars st) } -- | Make a fresh existential variable of the given type, abstracting out all -- the current uvars and returning the new evar applied to all current uvars @@ -543,14 +614,14 @@ mrSetEVarClosed var val = -- FIXME: catch subtyping errors and report them as being evar failures liftSC3 scCheckSubtype Nothing (TypedTerm val val_tp) var_tp modify $ \st -> - st { mrVars = + st { mrsVars = Map.alter (\case Just (EVarInfo Nothing) -> Just $ EVarInfo (Just val) Just (EVarInfo (Just _)) -> error "Setting existential variable: variable already set!" _ -> error "Setting existential variable: not an evar!") - var (mrVars st) } + var (mrsVars st) } -- | Try to set the value of the application @X e1 .. en@ of evar @X@ to an @@ -596,7 +667,7 @@ mrTrySetAppliedEVar evar args t = -- | Replace all evars in a 'Term' with their instantiations when they have one mrSubstEVars :: Term -> MRM Term mrSubstEVars = memoFixTermFun $ \recurse t -> - do var_map <- mrVars <$> get + do var_map <- mrVars case t of -- If t is an instantiated evar, recurse on its instantiation (asEVarApp var_map -> Just (_, args, Just t')) -> @@ -609,7 +680,7 @@ mrSubstEVars = memoFixTermFun $ \recurse t -> mrSubstEVarsStrict :: Term -> MRM (Maybe Term) mrSubstEVarsStrict top_t = runMaybeT $ flip memoFixTermFun top_t $ \recurse t -> - do var_map <- mrVars <$> get + do var_map <- lift mrVars case t of -- If t is an instantiated evar, recurse on its instantiation (asEVarApp var_map -> Just (_, args, Just t')) -> @@ -624,9 +695,49 @@ mrSubstEVarsStrict top_t = _mrSubstEVarsStrict :: Term -> MRM (Maybe Term) _mrSubstEVarsStrict = mrSubstEVarsStrict +-- | Get the 'CoIndHyp' for a pair of 'FunName's, if there is one +mrGetCoIndHyp :: FunName -> FunName -> MRM (Maybe CoIndHyp) +mrGetCoIndHyp nm1 nm2 = Map.lookup (nm1, nm2) <$> mrCoIndHyps + +-- | Run a compuation under the additional co-inductive assumption that +-- @forall x1, ..., xn. F y1 ... ym |= G z1 ... zl@, where @F@ and @G@ are +-- the given 'FunName's, @y1, ..., ym@ and @z1, ..., zl@ are the given +-- argument lists, and @x1, ..., xn@ is the current context of uvars. If +-- while running the given computation a 'CoIndHypMismatchWidened' error is +-- reached with the given names, the state is restored and the computation is +-- re-run with the widened hypothesis. This is done recursively, meaning this +-- function will only return once no 'CoIndHypMismatchWidened' errors are +-- raised with the given names. +withCoIndHyp :: FunName -> [Term] -> FunName -> [Term] -> MRM a -> MRM a +withCoIndHyp nm1 args1 nm2 args2 m = + do ctx <- mrUVarCtx + withCoIndHyp' (nm1, nm2) (CoIndHyp ctx args1 args2) m + +-- | The main loop of 'withCoIndHyp' +withCoIndHyp' :: (FunName, FunName) -> CoIndHyp -> MRM a -> MRM a +withCoIndHyp' (nm1, nm2) hyp@(CoIndHyp _ args1 args2) m = + do mrDebugPPPrefixSep 1 "withCoIndHyp" (FunBind nm1 args1 CompFunReturn) + "|=" (FunBind nm2 args2 CompFunReturn) + st <- get + hyps' <- Map.insert (nm1, nm2) hyp <$> mrCoIndHyps + (local (\info -> info { mriCoIndHyps = hyps' }) m) `catchError` \case + CoIndHypMismatchWidened nm1' nm2' hyp' | nm1 == nm1' && nm2 == nm2' + -> -- FIXME: Could restoring the state here cause any problems? + put st >> withCoIndHyp' (nm1, nm2) hyp' m + e -> throwError e + +-- | Generate fresh evars for the context of a 'CoIndHyp' and +-- substitute them into its arguments and right-hand side +instantiateCoIndHyp :: CoIndHyp -> MRM ([Term], [Term]) +instantiateCoIndHyp (CoIndHyp {..}) = + do evars <- mrFreshEVars coIndHypCtx + lhs <- substTermLike 0 evars coIndHypLHS + rhs <- substTermLike 0 evars coIndHypRHS + return (lhs, rhs) + -- | Look up the 'FunAssump' for a 'FunName', if there is one mrGetFunAssump :: FunName -> MRM (Maybe FunAssump) -mrGetFunAssump nm = Map.lookup nm <$> mrFunAssumps <$> get +mrGetFunAssump nm = Map.lookup nm <$> mrFunAssumps -- | Run a computation under the additional assumption that a named function -- applied to a list of arguments refines a given right-hand side, all of which @@ -636,12 +747,9 @@ withFunAssump fname args rhs m = do mrDebugPPPrefixSep 1 "withFunAssump" (FunBind fname args CompFunReturn) "|=" rhs ctx <- mrUVarCtx - assumps <- mrFunAssumps <$> get + assumps <- mrFunAssumps let assumps' = Map.insert fname (FunAssump ctx args rhs) assumps - modify (\s -> s { mrFunAssumps = assumps' }) - ret <- m - modify (\s -> s { mrFunAssumps = assumps }) - return ret + local (\info -> info { mriFunAssumps = assumps' }) m -- | Generate fresh evars for the context of a 'FunAssump' and substitute them -- into its arguments and right-hand side @@ -656,17 +764,14 @@ instantiateFunAssump fassump = -- executing a sub-computation withAssumption :: Term -> MRM a -> MRM a withAssumption phi m = - do assumps <- mrAssumptions <$> get + do assumps <- mrAssumptions assumps' <- liftSC2 scAnd phi assumps - modify (\s -> s { mrAssumptions = assumps' }) - ret <- m - modify (\s -> s { mrAssumptions = assumps }) - return ret + local (\info -> info { mriAssumptions = assumps' }) m -- | Print a 'String' if the debug level is at least the supplied 'Int' debugPrint :: Int -> String -> MRM () debugPrint i str = - (mrDebugLevel <$> get) >>= \lvl -> + mrDebugLevel >>= \lvl -> if lvl >= i then liftIO (hPutStrLn stderr str) else return () -- | Print a document if the debug level is at least the supplied 'Int' @@ -677,19 +782,19 @@ debugPretty i pp = debugPrint i $ renderSawDoc defaultPPOpts pp -- at least the supplied 'Int' debugPrettyInCtx :: PrettyInCtx a => Int -> a -> MRM () debugPrettyInCtx i a = - (mrUVars <$> get) >>= \ctx -> debugPrint i (showInCtx (map fst ctx) a) + mrUVars >>= \ctx -> debugPrint i (showInCtx (map fst ctx) a) -- | Pretty-print an object relative to the current context mrPPInCtx :: PrettyInCtx a => a -> MRM SawDoc mrPPInCtx a = - runReader (prettyInCtx a) <$> map fst <$> mrUVars <$> get + runReader (prettyInCtx a) <$> map fst <$> mrUVars -- | Pretty-print the result of 'ppWithPrefixSep' relative to the current uvar -- context to 'stderr' if the debug level is at least the 'Int' provided mrDebugPPPrefixSep :: PrettyInCtx a => Int -> String -> a -> String -> a -> MRM () mrDebugPPPrefixSep i pre a1 sp a2 = - (mrUVars <$> get) >>= \ctx -> + mrUVars >>= \ctx -> debugPretty i $ flip runReader (map fst ctx) (group <$> nest 2 <$> ppWithPrefixSep pre a1 sp a2) diff --git a/src/SAWScript/Prover/MRSolver/SMT.hs b/src/SAWScript/Prover/MRSolver/SMT.hs index ed318a1455..f597b35756 100644 --- a/src/SAWScript/Prover/MRSolver/SMT.hs +++ b/src/SAWScript/Prover/MRSolver/SMT.hs @@ -17,8 +17,6 @@ namely 'mrProvable' and 'mrProveEq'. module SAWScript.Prover.MRSolver.SMT where import qualified Data.Vector as V -import Control.Monad.Reader -import Control.Monad.State import Control.Monad.Except import Data.Map (Map) @@ -162,7 +160,7 @@ normSMTProp t = -- FIXME: use the timeout! mrProvableRaw :: Term -> MRM Bool mrProvableRaw prop_term = - do sc <- mrSC <$> get + do sc <- mrSC prop <- liftSC1 termToProp prop_term unints <- Set.map ecVarIndex <$> getAllExtSet <$> liftSC1 propToTerm prop debugPrint 2 ("Calling SMT solver with proposition: " ++ @@ -180,7 +178,7 @@ mrProvableRaw prop_term = -- assumptions mrProvable :: Term -> MRM Bool mrProvable bool_tm = - do assumps <- mrAssumptions <$> get + do assumps <- mrAssumptions prop <- liftSC2 scImplies assumps bool_tm >>= liftSC1 scEqTrue prop_inst <- flip instantiateUVarsM prop $ \nm tp -> liftSC1 scWhnf tp >>= \case @@ -262,16 +260,21 @@ mrProveEqSimple eqf t1 t2 = t2' <- mrSubstEVars t2 TermInCtx [] <$> eqf t1' t2' - --- | Prove that two terms are equal, instantiating evars if necessary, or --- throwing an error if this is not possible -mrProveEq :: Term -> Term -> MRM () +-- | Prove that two terms are equal, instantiating evars if necessary, +-- returning true on success +mrProveEq :: Term -> Term -> MRM Bool mrProveEq t1 t2 = do mrDebugPPPrefixSep 1 "mrProveEq" t1 "==" t2 tp <- mrTypeOf t1 - varmap <- mrVars <$> get + varmap <- mrVars cond_in_ctx <- mrProveEqH varmap tp t1 t2 - success <- withTermInCtx cond_in_ctx mrProvable + withTermInCtx cond_in_ctx mrProvable + +-- | Prove that two terms are equal, instantiating evars if necessary, or +-- throwing an error if this is not possible +mrAssertProveEq :: Term -> Term -> MRM () +mrAssertProveEq t1 t2 = + do success <- mrProveEq t1 t2 if success then return () else throwError (TermsNotEq t1 t2) @@ -343,7 +346,7 @@ mrProveEqH _ (asBVVecType -> Just (n, len, tp)) t1 t2 = t1'', ix'', pf''] t2_prj <- liftSC2 scGlobalApply "Prelude.atBVVec" [n'', len'', tp'', t2'', ix'', pf''] - var_map <- mrVars <$> get + var_map <- mrVars extTermInCtx [("eq_ix",ix_tp),("eq_pf",pf_tp)] <$> mrProveEqH var_map tp'' t1_prj t2_prj diff --git a/src/SAWScript/Prover/MRSolver/Solver.hs b/src/SAWScript/Prover/MRSolver/Solver.hs index 2f55d31c02..49452fe4d6 100644 --- a/src/SAWScript/Prover/MRSolver/Solver.hs +++ b/src/SAWScript/Prover/MRSolver/Solver.hs @@ -114,7 +114,6 @@ C |- F e1 ... en >>= k |= F' e1' ... em' >>= k': module SAWScript.Prover.MRSolver.Solver where -import Control.Monad.Reader import Control.Monad.Except import qualified Data.Map as Map @@ -366,7 +365,7 @@ mrRefines t1 t2 = -- | The main implementation of 'mrRefines' mrRefines' :: NormComp -> NormComp -> MRM () -mrRefines' (ReturnM e1) (ReturnM e2) = mrProveEq e1 e2 +mrRefines' (ReturnM e1) (ReturnM e2) = mrAssertProveEq e1 e2 mrRefines' (ErrorM _) (ErrorM _) = return () mrRefines' (ReturnM e) (ErrorM _) = throwError (ReturnNotError e) mrRefines' (ErrorM _) (ReturnM e) = throwError (ReturnNotError e) @@ -445,7 +444,7 @@ mrRefines' (FunBind (EVarFunName evar) args CompFunReturn) m2 = mrRefines' (FunBind (LetRecName f) args1 k1) (FunBind (LetRecName f') args2 k2) | f == f' && length args1 == length args2 = - zipWithM_ mrProveEq args1 args2 >> + zipWithM_ mrAssertProveEq args1 args2 >> mrRefinesFun k1 k2 mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = @@ -454,13 +453,31 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = mrConvertible tp1 tp2 >>= \tps_eq -> mrFunBodyRecInfo f1 args1 >>= \maybe_f1_body -> mrFunBodyRecInfo f2 args2 >>= \maybe_f2_body -> - mrGetFunAssump f1 >>= \case + mrGetCoIndHyp f1 f2 >>= \maybe_coIndHyp -> + mrGetFunAssump f1 >>= \maybe_fassump -> + case (maybe_coIndHyp, maybe_fassump) of + + -- If we have a co-inductive assumption that f1 args1' |= f2 args2': + -- * If it is convertible to our goal, continue and prove that k1 |= k2 + -- * If it can be widened with our goal, restart the current proof branch + -- with the widened hypothesis (done by throwing a + -- 'CoIndHypMismatchWidened' error for 'withCoIndHyp' to catch) + -- * Otherwise, throw a 'CoIndHypMismatchFailure' error. + (Just hyp, _) -> + do (args1', args2') <- instantiateCoIndHyp hyp + mrWidenCoIndHyp f1 f2 args1 args2 args1' args2' >>= \case + Convertible -> mrRefinesFun k1 k2 + Widened hyp' -> throwError (CoIndHypMismatchWidened f1 f2 hyp') + CouldNotWiden -> + let m1' = FunBind f1 args1' CompFunReturn + m2' = FunBind f2 args2' CompFunReturn + in throwError (CoIndHypMismatchFailure (m1, m2) (m1', m2')) -- If we have an assumption that f1 args' refines some rhs, then prove that -- args1 = args' and then that rhs refines m2 - Just fassump -> + (_, Just fassump) -> do (assump_args, assump_rhs) <- instantiateFunAssump fassump - zipWithM_ mrProveEq assump_args args1 + zipWithM_ mrAssertProveEq assump_args args1 m1' <- normBind assump_rhs k1 mrRefines m1' m2 @@ -472,16 +489,15 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = _ | Just (f2_body, False) <- maybe_f2_body -> normBindTerm f2_body k2 >>= \m2' -> mrRefines m1 m2' - -- If we do not already have an assumption that f1 refines some specification, - -- and both f1 and f2 are recursive but have the same return type, then try to - -- coinductively prove that f1 args1 |= f2 args2 under the assumption that f1 - -- args1 |= f2 args2, and then try to prove that k1 |= k2 - Nothing - | tps_eq + -- If we don't have a co-inducitve hypothesis for f1 and f2, don't have an + -- assumption that f1 refines some specification, and both f1 and f2 are + -- recursive and have the same return type, then try to coinductively prove + -- that f1 args1 |= f2 args2 under the assumption that f1 args1 |= f2 args2, + -- and then try to prove that k1 |= k2 + _ | tps_eq , Just (f1_body, _) <- maybe_f1_body , Just (f2_body, _) <- maybe_f2_body -> - do withFunAssump f1 args1 (FunBind f2 args2 CompFunReturn) $ - mrRefines f1_body f2_body + do withCoIndHyp f1 args1 f2 args2 $ mrRefines f1_body f2_body mrRefinesFun k1 k2 -- If we cannot line up f1 and f2, then making progress here would require us @@ -489,7 +505,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2@(FunBind f2 args2 k2) = -- related to the function call on the other side and k' is related to the -- continuation on the other side, but we don't know how to do that, so give -- up - Nothing -> + _ -> throwError (CompsDoNotRefine m1 m2) {- FIXME: handle FunBind on just one side @@ -499,7 +515,7 @@ mrRefines' m1@(FunBind f@(GlobalName _) args k1) m2 = -- If we have an assumption that f args' refines some rhs, then prove that -- args = args' and then that rhs refines m2 do (assump_args, assump_rhs) <- instantiateFunAssump fassump - zipWithM_ mrProveEq assump_args args + zipWithM_ mrAssertProveEq assump_args args m1' <- normBind assump_rhs k1 mrRefines m1' m2 Nothing -> @@ -516,7 +532,7 @@ mrRefines' m1@(FunBind f1 args1 k1) m2 = -- args1 = args' and then that rhs refines m2 Just fassump -> do (assump_args, assump_rhs) <- instantiateFunAssump fassump - zipWithM_ mrProveEq assump_args args1 + zipWithM_ mrAssertProveEq assump_args args1 m1' <- normBind assump_rhs k1 mrRefines m1' m2 @@ -588,6 +604,26 @@ mrRefinesFun f1 f2 mrRefinesFun _ _ = error "mrRefinesFun: unreachable!" +-- | The result type of 'mrWidenCoIndHyp' +data WidenCoIndHypResult = Convertible | Widened CoIndHyp | CouldNotWiden + +-- | Given a goal and a co-inductive hypothesis over the same pair of function +-- names, try to widen them into a more general co-inductive hypothesis which +-- implies both the given goal and the given co-inductive hypothesis. Returns +-- 'Convertible' if the goal and co-inductive hypothesis are convertible (and +-- therefore no widening needs to be done), 'Widened' if widening was +-- successful, and 'CouldNotWiden' if the terms are neither convertible nor +-- able to be widened. +-- FIXME: Finish implementing this function! +mrWidenCoIndHyp :: FunName -> FunName -> + [Term] -> [Term] -> [Term] -> [Term] -> + MRM WidenCoIndHypResult +mrWidenCoIndHyp _f1 _f2 args1 args2 args1' args2' = + do eq1 <- and <$> zipWithM mrProveEq args1' args1 + eq2 <- and <$> zipWithM mrProveEq args2' args2 + return $ if eq1 && eq2 then Convertible else CouldNotWiden + + ---------------------------------------------------------------------- -- * External Entrypoints ---------------------------------------------------------------------- @@ -602,10 +638,10 @@ askMRSolver :: askMRSolver sc dlvl timeout t1 t2 = do tp1 <- scTypeOf sc t1 >>= scWhnf sc tp2 <- scTypeOf sc t2 >>= scWhnf sc - init_st <- mkMRState sc Map.empty timeout dlvl case asPiList tp1 of (uvar_ctx, asCompM -> Just _) -> - fmap (either Just (const Nothing)) $ runMRM init_st $ + fmap (either Just (const Nothing)) $ + runMRM sc timeout dlvl Map.empty $ withUVars uvar_ctx $ \vars -> do tps_are_eq <- mrConvertible tp1 tp2 if tps_are_eq then return () else