From 92a96c32f7dec1bbee7a214cb1c4791b48e17136 Mon Sep 17 00:00:00 2001 From: Rob Dockins Date: Mon, 26 Apr 2021 12:53:39 -0700 Subject: [PATCH] Add annotations to rewrite rules and simpsets. This allows users to add metadata their rewrite rules, and have rewriting steps record and return the metadata for rules what were actually used in rewriting. The intention is to use these annotations to link rewrite rules back to theorems/axioms that generated them so proof steps can determine the dependencies of a proof. --- cryptol-saw-core/src/Verifier/SAW/Cryptol.hs | 6 +- .../src/Verifier/SAW/Cryptol/Simpset.hs | 2 +- saw-core/src/Verifier/SAW/Constant.hs | 2 +- saw-core/src/Verifier/SAW/Rewriter.hs | 262 +++++++++++------- saw-core/src/Verifier/SAW/SCTypeCheck.hs | 2 +- 5 files changed, 161 insertions(+), 113 deletions(-) diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs index 2e293779..c8824fb2 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol.hs @@ -1611,9 +1611,9 @@ scCryptolType sc t = scCryptolEq :: SharedContext -> Term -> Term -> IO Term scCryptolEq sc x y = do rules <- concat <$> traverse defRewrites defs - let ss = addConvs natConversions (addRules rules emptySimpset) - tx <- scTypeOf sc x >>= rewriteSharedTerm sc ss >>= scCryptolType sc - ty <- scTypeOf sc y >>= rewriteSharedTerm sc ss >>= scCryptolType sc + let ss = addConvs natConversions (addRules rules emptySimpset :: Simpset ()) + tx <- scTypeOf sc x >>= rewriteSharedTerm sc ss >>= scCryptolType sc . snd + ty <- scTypeOf sc y >>= rewriteSharedTerm sc ss >>= scCryptolType sc . snd unless (tx == ty) $ panic "scCryptolEq" [ "scCryptolEq: type mismatch between" diff --git a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Simpset.hs b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Simpset.hs index c6a03265..0a7e3578 100644 --- a/cryptol-saw-core/src/Verifier/SAW/Cryptol/Simpset.hs +++ b/cryptol-saw-core/src/Verifier/SAW/Cryptol/Simpset.hs @@ -18,7 +18,7 @@ import Verifier.SAW.Rewriter import Verifier.SAW.SharedTerm import Verifier.SAW.Term.Functor -mkCryptolSimpset :: SharedContext -> IO Simpset +mkCryptolSimpset :: SharedContext -> IO (Simpset a) mkCryptolSimpset sc = do m <- scFindModule sc cryptolModuleName scSimpset sc (cryptolDefs m) [] [] diff --git a/saw-core/src/Verifier/SAW/Constant.hs b/saw-core/src/Verifier/SAW/Constant.hs index 8b124445..53efde1f 100644 --- a/saw-core/src/Verifier/SAW/Constant.hs +++ b/saw-core/src/Verifier/SAW/Constant.hs @@ -15,5 +15,5 @@ import Verifier.SAW.Conversion scConst :: SharedContext -> String -> Term -> IO Term scConst sc name t = do ty <- scTypeOf sc t - ty' <- rewriteSharedTerm sc (addConvs natConversions emptySimpset) ty + (_,ty') <- rewriteSharedTerm sc (addConvs natConversions emptySimpset :: Simpset ()) ty scConstant sc name t ty' diff --git a/saw-core/src/Verifier/SAW/Rewriter.hs b/saw-core/src/Verifier/SAW/Rewriter.hs index 18d18a4f..7591aa7b 100644 --- a/saw-core/src/Verifier/SAW/Rewriter.hs +++ b/saw-core/src/Verifier/SAW/Rewriter.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveFoldable #-} @@ -25,6 +26,7 @@ module Verifier.SAW.Rewriter , ctxtRewriteRule , lhsRewriteRule , rhsRewriteRule + , annRewriteRule , ruleOfTerm , ruleOfTerms , ruleOfProp @@ -62,6 +64,7 @@ import Data.Foldable (Foldable) import Control.Monad.Identity import Control.Monad.State import Control.Monad.Trans.Maybe +import Data.IORef import qualified Data.Foldable as Foldable import Data.Map (Map) import qualified Data.List as List @@ -81,23 +84,31 @@ import Verifier.SAW.Term.Functor import Verifier.SAW.TypedAST import qualified Verifier.SAW.TermNet as Net -data RewriteRule - = RewriteRule { ctxt :: [Term], lhs :: Term, rhs :: Term, permutative :: Bool } - deriving (Eq, Show) +data RewriteRule a + = RewriteRule { ctxt :: [Term], lhs :: Term, rhs :: Term, permutative :: Bool, annotation :: Maybe a } + deriving (Show) -- ^ Invariant: The set of loose variables in @lhs@ must be exactly -- @[0 .. length ctxt - 1]@. The @rhs@ may contain a subset of these. -ctxtRewriteRule :: RewriteRule -> [Term] +-- NB, exclude the annotation from equality tests +instance Eq (RewriteRule a) where + RewriteRule c1 l1 r1 p1 _a1 == RewriteRule c2 l2 r2 p2 _a2 = + c1 == c2 && l1 == l2 && r1 == r2 && p1 == p2 + +ctxtRewriteRule :: RewriteRule a -> [Term] ctxtRewriteRule = ctxt -lhsRewriteRule :: RewriteRule -> Term +lhsRewriteRule :: RewriteRule a -> Term lhsRewriteRule = lhs -rhsRewriteRule :: RewriteRule -> Term +rhsRewriteRule :: RewriteRule a -> Term rhsRewriteRule = rhs -instance Net.Pattern RewriteRule where - toPat (RewriteRule _ lhs _ _) = Net.toPat lhs +annRewriteRule :: RewriteRule a -> Maybe a +annRewriteRule = annotation + +instance Net.Pattern (RewriteRule a) where + toPat (RewriteRule _ lhs _ _ _) = Net.toPat lhs ---------------------------------------------------------------------- -- Matching @@ -284,15 +295,15 @@ equalNatIdent = mkIdent (mkModuleName ["Prelude"]) "equalNat" -- | Converts a universally quantified equality proposition from a -- Term representation to a RewriteRule. -ruleOfTerm :: Term -> RewriteRule -ruleOfTerm t = +ruleOfTerm :: Term -> Maybe a -> RewriteRule a +ruleOfTerm t ann = case unwrapTermF t of -- NOTE: this assumes the Coq-style equality type Eq X x y, where both X -- (the type of x and y) and x are parameters, and y is an index FTermF (DataTypeApp ident [_, x] [y]) - | ident == eqIdent -> mkRewriteRule [] x y + | ident == eqIdent -> mkRewriteRule [] x y ann Pi _ ty body -> rule { ctxt = ty : ctxt rule } - where rule = ruleOfTerm body + where rule = ruleOfTerm body ann _ -> error "ruleOfSharedTerm: Illegal argument" -- Test whether a rewrite rule is permutative @@ -306,46 +317,46 @@ rulePermutes lhs rhs = Nothing -> False -- but here we have a looping rule, not good! Just _ -> True -mkRewriteRule :: [Term] -> Term -> Term -> RewriteRule -mkRewriteRule c l r = - RewriteRule {ctxt = c, lhs = l, rhs = r , permutative = rulePermutes l r} +mkRewriteRule :: [Term] -> Term -> Term -> Maybe a -> RewriteRule a +mkRewriteRule c l r ann = + RewriteRule {ctxt = c, lhs = l, rhs = r , permutative = rulePermutes l r, annotation = ann} -- | Converts a universally quantified equality proposition between the -- two given terms to a RewriteRule. -ruleOfTerms :: Term -> Term -> RewriteRule -ruleOfTerms l r = mkRewriteRule [] l r +ruleOfTerms :: Term -> Term -> RewriteRule a +ruleOfTerms l r = mkRewriteRule [] l r Nothing -- | Converts a parameterized equality predicate to a RewriteRule, -- returning 'Nothing' if the predicate is not an equation. -ruleOfProp :: Term -> Maybe RewriteRule -ruleOfProp (R.asPi -> Just (_, ty, body)) = - do rule <- ruleOfProp body +ruleOfProp :: Term -> Maybe a -> Maybe (RewriteRule a) +ruleOfProp (R.asPi -> Just (_, ty, body)) ann = + do rule <- ruleOfProp body ann Just rule { ctxt = ty : ctxt rule } -ruleOfProp (R.asLambda -> Just (_, ty, body)) = - do rule <- ruleOfProp body +ruleOfProp (R.asLambda -> Just (_, ty, body)) ann = + do rule <- ruleOfProp body ann Just rule { ctxt = ty : ctxt rule } -ruleOfProp (R.asApplyAll -> (R.isGlobalDef ecEqIdent -> Just (), [_, _, x, y])) = - Just $ mkRewriteRule [] x y -ruleOfProp (R.asApplyAll -> (R.isGlobalDef bvEqIdent -> Just (), [_, x, y])) = - Just $ mkRewriteRule [] x y -ruleOfProp (R.asApplyAll -> (R.isGlobalDef equalNatIdent -> Just (), [x, y])) = - Just $ mkRewriteRule [] x y -ruleOfProp (R.asApplyAll -> (R.isGlobalDef boolEqIdent -> Just (), [x, y])) = - Just $ mkRewriteRule [] x y -ruleOfProp (R.asApplyAll -> (R.isGlobalDef vecEqIdent -> Just (), [_, _, _, x, y])) = - Just $ mkRewriteRule [] x y -ruleOfProp (unwrapTermF -> Constant _ body) = ruleOfProp body -ruleOfProp (R.asEq -> Just (_, x, y)) = - Just $ mkRewriteRule [] x y -ruleOfProp (R.asEqTrue -> Just body) = ruleOfProp body -ruleOfProp _ = Nothing +ruleOfProp (R.asApplyAll -> (R.isGlobalDef ecEqIdent -> Just (), [_, _, x, y])) ann = + Just $ mkRewriteRule [] x y ann +ruleOfProp (R.asApplyAll -> (R.isGlobalDef bvEqIdent -> Just (), [_, x, y])) ann = + Just $ mkRewriteRule [] x y ann +ruleOfProp (R.asApplyAll -> (R.isGlobalDef equalNatIdent -> Just (), [x, y])) ann = + Just $ mkRewriteRule [] x y ann +ruleOfProp (R.asApplyAll -> (R.isGlobalDef boolEqIdent -> Just (), [x, y])) ann = + Just $ mkRewriteRule [] x y ann +ruleOfProp (R.asApplyAll -> (R.isGlobalDef vecEqIdent -> Just (), [_, _, _, x, y])) ann = + Just $ mkRewriteRule [] x y ann +ruleOfProp (unwrapTermF -> Constant _ body) ann = ruleOfProp body ann +ruleOfProp (R.asEq -> Just (_, x, y)) ann = + Just $ mkRewriteRule [] x y ann +ruleOfProp (R.asEqTrue -> Just body) ann = ruleOfProp body ann +ruleOfProp _ _ = Nothing -- | Generate a rewrite rule from the type of an identifier, using 'ruleOfTerm' -scEqRewriteRule :: SharedContext -> Ident -> IO RewriteRule -scEqRewriteRule sc i = ruleOfTerm <$> scTypeOfGlobal sc i +scEqRewriteRule :: SharedContext -> Ident -> IO (RewriteRule a) +scEqRewriteRule sc i = ruleOfTerm <$> scTypeOfGlobal sc i <*> pure Nothing -- | Collects rewrite rules from named constants, whose types must be equations. -scEqsRewriteRules :: SharedContext -> [Ident] -> IO [RewriteRule] +scEqsRewriteRules :: SharedContext -> [Ident] -> IO [RewriteRule a] scEqsRewriteRules sc = mapM (scEqRewriteRule sc) -- | Transform the given rewrite rule to a set of one or more @@ -354,19 +365,19 @@ scEqsRewriteRules sc = mapM (scEqRewriteRule sc) -- * If the rhs is a lambda, then add an argument to the lhs. -- * If the rhs is a recursor, then split into a separate rule for each constructor. -- * If the rhs is a record, then split into a separate rule for each accessor. -scExpandRewriteRule :: SharedContext -> RewriteRule -> IO (Maybe [RewriteRule]) -scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _) = +scExpandRewriteRule :: SharedContext -> RewriteRule a -> IO (Maybe [RewriteRule a]) +scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _ ann) = case rhs of (R.asLambda -> Just (_, ty, body)) -> do let ctxt' = ctxt ++ [ty] lhs1 <- incVars sc 0 1 lhs var0 <- scLocalVar sc 0 lhs' <- scApply sc lhs1 var0 - return $ Just [mkRewriteRule ctxt' lhs' body] + return $ Just [mkRewriteRule ctxt' lhs' body ann] (R.asRecordValue -> Just m) -> do let mkRule (k, x) = do l <- scRecordSelect sc lhs k - return (mkRewriteRule ctxt l x) + return (mkRewriteRule ctxt l x ann) Just <$> traverse mkRule (Map.assocs m) (R.asApplyAll -> (R.asRecursorApp -> Just (d, params, p_ret, cs_fs, _ixs, R.asLocalVar -> Just i), @@ -407,9 +418,9 @@ scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _) = rhs2 <- scApplyAll sc rhs1 more' rhs3 <- betaReduce rhs2 -- re-fold recursive occurrences of the original rhs - let ss = addRule (mkRewriteRule ctxt rhs lhs) emptySimpset - rhs' <- rewriteSharedTerm sc ss rhs3 - return (mkRewriteRule ctxt' lhs' rhs') + let ss = addRule (mkRewriteRule ctxt rhs lhs Nothing) emptySimpset + (_,rhs') <- rewriteSharedTerm sc (ss :: Simpset ()) rhs3 + return (mkRewriteRule ctxt' lhs' rhs' ann) dt <- scRequireDataType sc d rules <- traverse ctorRule (dtCtors dt) return (Just rules) @@ -433,7 +444,7 @@ scExpandRewriteRule sc (RewriteRule ctxt lhs rhs _) = Just (_, _, body) -> instantiateVar sc 0 arg body -- | Repeatedly apply the rule transformations in 'scExpandRewriteRule'. -scExpandRewriteRules :: SharedContext -> [RewriteRule] -> IO [RewriteRule] +scExpandRewriteRules :: SharedContext -> [RewriteRule a] -> IO [RewriteRule a] scExpandRewriteRules sc rs = case rs of [] -> return [] @@ -445,12 +456,12 @@ scExpandRewriteRules sc rs = -- | Create a rewrite rule for a definition that expands the definition, if it -- has a body to expand to, otherwise return the empty list -scDefRewriteRules :: SharedContext -> Def -> IO [RewriteRule] +scDefRewriteRules :: SharedContext -> Def -> IO [RewriteRule a] scDefRewriteRules _ (Def { defBody = Nothing }) = return [] scDefRewriteRules sc (Def { defIdent = ident, defBody = Just body }) = do lhs <- scGlobalDef sc ident rhs <- scSharedTerm sc body - scExpandRewriteRules sc [mkRewriteRule [] lhs rhs] + scExpandRewriteRules sc [mkRewriteRule [] lhs rhs Nothing] ---------------------------------------------------------------------- @@ -458,40 +469,40 @@ scDefRewriteRules sc (Def { defIdent = ident, defBody = Just body }) = -- | Invariant: 'Simpset's should not contain reflexive rules. We avoid -- adding them in 'addRule' below. -type Simpset = Net.Net (Either RewriteRule Conversion) +type Simpset a = Net.Net (Either (RewriteRule a) Conversion) -emptySimpset :: Simpset +emptySimpset :: Simpset a emptySimpset = Net.empty -addRule :: RewriteRule -> Simpset -> Simpset +addRule :: RewriteRule a -> Simpset a -> Simpset a addRule rule | lhs rule /= rhs rule = Net.insert_term (lhs rule, Left rule) | otherwise = id -delRule :: RewriteRule -> Simpset -> Simpset +delRule :: RewriteRule a -> Simpset a -> Simpset a delRule rule = Net.delete_term (lhs rule, Left rule) -addRules :: [RewriteRule] -> Simpset -> Simpset +addRules :: [RewriteRule a] -> Simpset a -> Simpset a addRules rules ss = foldr addRule ss rules -addSimp :: Term -> Simpset -> Simpset -addSimp prop = addRule (ruleOfTerm prop) +addSimp :: Term -> Maybe a -> Simpset a -> Simpset a +addSimp prop ann = addRule (ruleOfTerm prop ann) -delSimp :: Term -> Simpset -> Simpset -delSimp prop = delRule (ruleOfTerm prop) +delSimp :: Term -> Simpset a -> Simpset a +delSimp prop = delRule (ruleOfTerm prop Nothing) -addConv :: Conversion -> Simpset -> Simpset +addConv :: Conversion -> Simpset a -> Simpset a addConv conv = Net.insert_term (conv, Right conv) -addConvs :: [Conversion] -> Simpset -> Simpset +addConvs :: [Conversion] -> Simpset a -> Simpset a addConvs convs ss = foldr addConv ss convs -scSimpset :: SharedContext -> [Def] -> [Ident] -> [Conversion] -> IO Simpset +scSimpset :: SharedContext -> [Def] -> [Ident] -> [Conversion] -> IO (Simpset a) scSimpset sc defs eqIdents convs = do defRules <- concat <$> traverse (scDefRewriteRules sc) defs eqRules <- mapM (scEqRewriteRule sc) eqIdents return $ addRules defRules $ addRules eqRules $ addConvs convs $ emptySimpset -listRules :: Simpset -> [RewriteRule] +listRules :: Simpset a -> [RewriteRule a] listRules ss = [ r | Left r <- Net.content ss ] ---------------------------------------------------------------------- @@ -558,7 +569,7 @@ appCollectedArgs t = step0 (unshared t) [] -- step2: analyse an arg. look inside tuples, sequences (TBD), more calls to f step2 :: TermF Term -> TermF Term -> [Term] step2 f (FTermF (PairValue x y)) = (step2 f $ unshared x) ++ (step2 f $ unshared y) - step2 f (s@(App g a)) = possibly_curried_args s f (unshared g) (step2 f $ unshared a) + step2 f (s@(App g a)) = possibly_curried_args s f (unshared g) (step2 f $ unshared a) step2 _ a = [Unshared a] -- possibly_curried_args :: TermF Term -> TermF Term -> TermF Term -> [Term] -> [Term] @@ -580,29 +591,43 @@ reduceSharedTerm sc (asIotaRedex -> Just (d, params, p_ret, cs_fs, c, args)) = Just $ scReduceRecursor sc d params p_ret cs_fs c args reduceSharedTerm _ _ = Nothing --- | Rewriter for shared terms -rewriteSharedTerm :: SharedContext -> Simpset -> Term -> IO Term +-- | Rewriter for shared terms. The annotations of any used rules are collected +-- and returned in the result set. +rewriteSharedTerm :: forall a. Ord a => SharedContext -> Simpset a -> Term -> IO (Set a, Term) rewriteSharedTerm sc ss t0 = do cache <- newCache - let ?cache = cache in rewriteAll t0 + let ?cache = cache + setRef <- newIORef mempty + let ?annSet = setRef + t <- rewriteAll t0 + anns <- readIORef setRef + pure (anns, t) + where - rewriteAll :: (?cache :: Cache IO TermIndex Term) => Term -> IO Term + rewriteAll :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Term -> IO Term rewriteAll (Unshared tf) = traverseTF rewriteAll tf >>= scTermF sc >>= rewriteTop rewriteAll STApp{ stAppIndex = tidx, stAppTermF = tf } = useCache ?cache tidx (traverseTF rewriteAll tf >>= scTermF sc >>= rewriteTop) - traverseTF :: (a -> IO a) -> TermF a -> IO (TermF a) + + traverseTF :: forall b. (b -> IO b) -> TermF b -> IO (TermF b) traverseTF _ tf@(Constant {}) = pure tf traverseTF f tf = traverse f tf - rewriteTop :: (?cache :: Cache IO TermIndex Term) => Term -> IO Term + + rewriteTop :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Term -> IO Term rewriteTop t = case reduceSharedTerm sc t of Nothing -> apply (Net.unify_term ss t) t Just io -> rewriteAll =<< io - apply :: (?cache :: Cache IO TermIndex Term) => - [Either RewriteRule Conversion] -> Term -> IO Term + + recordAnn :: (?annSet :: IORef (Set a)) => Maybe a -> IO () + recordAnn Nothing = return () + recordAnn (Just a) = modifyIORef' ?annSet (Set.insert a) + + apply :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => + [Either (RewriteRule a) Conversion] -> Term -> IO Term apply [] t = return t - apply (Left (RewriteRule {ctxt, lhs, rhs, permutative}) : rules) t = do + apply (Left (RewriteRule {ctxt, lhs, rhs, permutative, annotation}) : rules) t = do result <- scMatch sc lhs t case result of Nothing -> apply rules t @@ -621,11 +646,12 @@ rewriteSharedTerm sc ss t0 = do t' <- instantiateVarList sc 0 (Map.elems inst) rhs case termWeightLt t' t of - True -> rewriteAll t' -- keep the result only if it is "smaller" + True -> recordAnn annotation >> rewriteAll t' -- keep the result only if it is "smaller" False -> apply rules t | otherwise -> do -- putStrLn "REWRITING:" -- print lhs + recordAnn annotation rewriteAll =<< instantiateVarList sc 0 (Map.elems inst) rhs apply (Right conv : rules) t = do -- putStrLn "REWRITING:" @@ -635,20 +661,27 @@ rewriteSharedTerm sc ss t0 = Just tb -> rewriteAll =<< runTermBuilder tb (scGlobalDef sc) (scTermF sc) -- | Type-safe rewriter for shared terms -rewriteSharedTermTypeSafe - :: SharedContext -> Simpset -> Term -> IO Term +rewriteSharedTermTypeSafe :: forall a. Ord a => + SharedContext -> Simpset a -> Term -> IO (Set a, Term) rewriteSharedTermTypeSafe sc ss t0 = do cache <- newCache - let ?cache = cache in rewriteAll t0 + let ?cache = cache + annRef <- newIORef mempty + let ?annSet = annRef + t <- rewriteAll t0 + anns <- readIORef annRef + return (anns, t) + where - rewriteAll :: (?cache :: Cache IO TermIndex Term) => + rewriteAll :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Term -> IO Term rewriteAll (Unshared tf) = rewriteTermF tf >>= scTermF sc >>= rewriteTop rewriteAll STApp{ stAppIndex = tidx, stAppTermF = tf } = -- putStrLn "Rewriting term:" >> print t >> useCache ?cache tidx (rewriteTermF tf >>= scTermF sc >>= rewriteTop) - rewriteTermF :: (?cache :: Cache IO TermIndex Term) => + + rewriteTermF :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => TermF Term -> IO (TermF Term) rewriteTermF tf = case tf of @@ -663,7 +696,8 @@ rewriteSharedTermTypeSafe sc ss t0 = Lambda pat t e -> Lambda pat t <$> rewriteAll e Constant{} -> return tf _ -> return tf -- traverse rewriteAll tf - rewriteFTermF :: (?cache :: Cache IO TermIndex Term) => + + rewriteFTermF :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => FlatTermF Term -> IO (FlatTermF Term) rewriteFTermF ftf = case ftf of @@ -690,24 +724,33 @@ rewriteSharedTermTypeSafe sc ss t0 = Primitive{} -> return ftf StringLit{} -> return ftf ExtCns{} -> return ftf - rewriteTop :: (?cache :: Cache IO TermIndex Term) => + + rewriteTop :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => Term -> IO Term rewriteTop t = apply (Net.match_term ss t) t - apply :: (?cache :: Cache IO TermIndex Term) => - [Either RewriteRule Conversion] -> + + recordAnn :: (?annSet :: IORef (Set a)) => Maybe a -> IO () + recordAnn Nothing = return () + recordAnn (Just a) = modifyIORef' ?annSet (Set.insert a) + + apply :: (?cache :: Cache IO TermIndex Term, ?annSet :: IORef (Set a)) => + [Either (RewriteRule a) Conversion] -> Term -> IO Term apply [] t = return t apply (Left rule : rules) t = case first_order_match (lhs rule) t of Nothing -> apply rules t - Just inst -> rewriteAll =<< instantiateVarList sc 0 (Map.elems inst) (rhs rule) + Just inst -> + do recordAnn (annotation rule) + rewriteAll =<< instantiateVarList sc 0 (Map.elems inst) (rhs rule) apply (Right conv : rules) t = case runConversion conv t of Nothing -> apply rules t Just tb -> rewriteAll =<< runTermBuilder tb (scGlobalDef sc) (scTermF sc) -- | Generate a new SharedContext that normalizes terms as it builds them. -rewritingSharedContext :: SharedContext -> Simpset -> SharedContext +-- Rule annotations are ignored. +rewritingSharedContext :: SharedContext -> Simpset a -> SharedContext rewritingSharedContext sc ss = sc' where sc' = sc { scTermF = rewriteTop } @@ -722,11 +765,11 @@ rewritingSharedContext sc ss = sc' Nothing -> apply (Net.match_term ss t) t where t = Unshared tf - apply :: [Either RewriteRule Conversion] -> + apply :: [Either (RewriteRule a) Conversion] -> Term -> IO Term apply [] (Unshared tf) = scTermF sc tf apply [] STApp{ stAppTermF = tf } = scTermF sc tf - apply (Left (RewriteRule _ l r _) : rules) t = + apply (Left (RewriteRule _ l r _ _ann) : rules) t = case first_order_match l t of Nothing -> apply rules t Just inst @@ -742,11 +785,12 @@ rewritingSharedContext sc ss = sc' -- FIXME: is there some way to have sensable term replacement in the presence of loose variables -- and/or under binders? -replaceTerm :: SharedContext - -> Simpset -- ^ A simpset of rewrite rules to apply along with the replacement - -> (Term, Term) -- ^ (pat,repl) is a tuple of a pattern term to replace and a replacement term - -> Term -- ^ the term in which to perform the replacement - -> IO Term +replaceTerm :: Ord a => + SharedContext -> + Simpset a {- ^ A simpset of rewrite rules to apply along with the replacement -} -> + (Term, Term) {- ^ (pat,repl) is a tuple of a pattern term to replace and a replacement term -} -> + Term {- ^ the term in which to perform the replacement -} -> + IO (Set a, Term) replaceTerm sc ss (pat, repl) t = do let fvs = looseVars pat unless (fvs == emptyBitSet) $ fail $ unwords @@ -783,7 +827,7 @@ hoistIfs sc t = do `app` (scLocalVar sc 3) - rules <- map ruleOfTerm <$> mapM (scTypeOfGlobal sc) + rules <- map (\rt -> ruleOfTerm rt Nothing) <$> mapM (scTypeOfGlobal sc) [ "Prelude.ite_true" , "Prelude.ite_false" , "Prelude.ite_not" @@ -807,38 +851,42 @@ hoistIfs sc t = do , "Prelude.not_or" , "Prelude.not_and" ] - let ss = addRules rules emptySimpset + let ss :: Simpset () = addRules rules emptySimpset - (t', conds) <- doHoistIfs sc ss cache itePat =<< rewriteSharedTerm sc ss t + (t', conds) <- doHoistIfs sc ss cache itePat . snd =<< rewriteSharedTerm sc ss t splitConds sc ss (map fst conds) t' -splitConds :: SharedContext -> Simpset -> [Term] -> Term -> IO Term -splitConds _ _ [] = return -splitConds sc ss (c:cs) = splitCond sc ss c >=> splitConds sc ss cs +splitConds :: Ord a => SharedContext -> Simpset a -> [Term] -> Term -> IO Term +splitConds sc ss = go + where + go [] t = return t + go (c:cs) t = go cs =<< splitCond sc ss c t -splitCond :: SharedContext -> Simpset -> Term -> Term -> IO Term +splitCond :: Ord a => SharedContext -> Simpset a -> Term -> Term -> IO Term splitCond sc ss c t = do ty <- scTypeOf sc t trueTerm <- scBool sc True falseTerm <- scBool sc False - then_branch <- replaceTerm sc ss (c, trueTerm) t - else_branch <- replaceTerm sc ss (c, falseTerm) t + (_,then_branch) <- replaceTerm sc ss (c, trueTerm) t + (_,else_branch) <- replaceTerm sc ss (c, falseTerm) t scGlobalApply sc "Prelude.ite" [ty, c, then_branch, else_branch] + type HoistIfs s = (Term, [(Term, Set (ExtCns Term))]) orderTerms :: SharedContext -> [Term] -> IO [Term] orderTerms _sc xs = return $ List.sort xs -doHoistIfs :: SharedContext - -> Simpset - -> Cache IO TermIndex (HoistIfs s) - -> Term - -> Term - -> IO (HoistIfs s) +doHoistIfs :: Ord a => + SharedContext -> + Simpset a -> + Cache IO TermIndex (HoistIfs s) -> + Term -> + Term -> + IO (HoistIfs s) doHoistIfs sc ss hoistCache itePat = go where go :: Term -> IO (HoistIfs s) diff --git a/saw-core/src/Verifier/SAW/SCTypeCheck.hs b/saw-core/src/Verifier/SAW/SCTypeCheck.hs index 029e5eca..c67fa6a1 100644 --- a/saw-core/src/Verifier/SAW/SCTypeCheck.hs +++ b/saw-core/src/Verifier/SAW/SCTypeCheck.hs @@ -533,7 +533,7 @@ typeCheckWHNF = liftTCM scTypeCheckWHNF -- | The 'IO' version of 'typeCheckWHNF' scTypeCheckWHNF :: SharedContext -> Term -> IO Term scTypeCheckWHNF sc t = - do t' <- rewriteSharedTerm sc (addConvs natConversions emptySimpset) t + do (_, t') <- rewriteSharedTerm sc (addConvs natConversions emptySimpset :: Simpset ()) t scWhnf sc t' -- | Check that one type is a subtype of another, assuming both arguments are