diff --git a/liquid-fixpoint.cabal b/liquid-fixpoint.cabal index d42136211..c03821a09 100644 --- a/liquid-fixpoint.cabal +++ b/liquid-fixpoint.cabal @@ -100,6 +100,7 @@ library Language.Fixpoint.Solver.UniqifyKVars Language.Fixpoint.Solver.Worklist Language.Fixpoint.SortCheck + Language.Fixpoint.Union Language.Fixpoint.Types Language.Fixpoint.Types.Config Language.Fixpoint.Types.Constraints diff --git a/src/Language/Fixpoint/Horn/Transformations.hs b/src/Language/Fixpoint/Horn/Transformations.hs index 0a71c20f8..7b79d3231 100644 --- a/src/Language/Fixpoint/Horn/Transformations.hs +++ b/src/Language/Fixpoint/Horn/Transformations.hs @@ -543,7 +543,7 @@ elimKs' (k:ks) (noside, side) = elimKs' (trace ("solved kvar " <> F.showpp k <> -- exists in the positive positions (which will stay exists when we go to -- prenex) may give us a lot of trouble during _quantifier elimination_ -- tx :: F.Symbol -> [[Bind]] -> Pred -> Pred --- tx k bss = trans (defaultVisitor { txExpr = existentialPackage, ctxExpr = ctxKV }) M.empty () +-- tx k bss = trans (defaultFolder { txExpr = existentialPackage, ctxExpr = ctxKV }) M.empty () -- where -- splitBinds xs = unzip $ (\(Bind x t p) -> ((x,t),p)) <$> xs -- cubeSol su (Bind _ _ (Reft eqs):xs) @@ -564,16 +564,16 @@ elimKs' (k:ks) (noside, side) = elimKs' (trace ("solved kvar " <> F.showpp k <> -- ctxKV m _ = m -- Visitor only visit Exprs in Pred! -instance V.Visitable Pred where - visit v c (PAnd ps) = PAnd <$> mapM (visit v c) ps - visit v c (Reft e) = Reft <$> visit v c e - visit _ _ var = pure var - -instance V.Visitable (Cstr a) where - visit v c (CAnd cs) = CAnd <$> mapM (visit v c) cs - visit v c (Head p a) = Head <$> visit v c p <*> pure a - visit v ctx (All (Bind x t p l) c) = All <$> (Bind x t <$> visit v ctx p <*> pure l) <*> visit v ctx c - visit v ctx (Any (Bind x t p l) c) = All <$> (Bind x t <$> visit v ctx p <*> pure l) <*> visit v ctx c +instance V.Foldable Pred where + foldE v c (PAnd ps) = PAnd <$> mapM (foldE v c) ps + foldE v c (Reft e) = Reft <$> foldE v c e + foldE _ _ var = pure var + +instance V.Foldable (Cstr a) where + foldE v c (CAnd cs) = CAnd <$> mapM (foldE v c) cs + foldE v c (Head p a) = Head <$> foldE v c p <*> pure a + foldE v ctx (All (Bind x t p l) c) = All <$> (Bind x t <$> foldE v ctx p <*> pure l) <*> foldE v ctx c + foldE v ctx (Any (Bind x t p l) c) = All <$> (Bind x t <$> foldE v ctx p <*> pure l) <*> foldE v ctx c ------------------------------------------------------------------------------ -- | Quantifier elimination for use with implicit solver diff --git a/src/Language/Fixpoint/Solver/Sanitize.hs b/src/Language/Fixpoint/Solver/Sanitize.hs index 15ac71063..1360591c2 100644 --- a/src/Language/Fixpoint/Solver/Sanitize.hs +++ b/src/Language/Fixpoint/Solver/Sanitize.hs @@ -51,8 +51,6 @@ sanitize cfg = banIrregularData >=> banConstraintFreeVars cfg >=> Misc.fM addLiterals >=> Misc.fM (eliminateEta cfg) - >=> Misc.fM cancelCoercion - -------------------------------------------------------------------------------- -- | 'dropAdtMeasures' removes all the measure definitions that correspond to @@ -83,16 +81,6 @@ addLiterals si = si { F.dLits = F.unionSEnv (F.dLits si) lits' where lits' = M.fromList [ (F.symbol x, F.strSort) | x <- symConsts si ] - - -cancelCoercion :: F.SInfo a -> F.SInfo a -cancelCoercion = mapExpr (trans (defaultVisitor { txExpr = go }) () ()) - where - go _ (F.ECoerc t1 t2 (F.ECoerc t2' t1' e)) - | t1 == t1' && t2 == t2' - = e - go _ e = e - -------------------------------------------------------------------------------- -- | `eliminateEta` converts equations of the form f x = g x into f = g -------------------------------------------------------------------------------- diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 42da843eb..d66740b50 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -75,7 +75,7 @@ import qualified Data.HashMap.Strict as M import qualified Data.HashSet as S import Data.IORef import qualified Data.List as L -import Data.Maybe (mapMaybe, fromMaybe, catMaybes, isJust) +import Data.Maybe (mapMaybe, fromMaybe, isJust) import Language.Fixpoint.Types.PrettyPrint import Language.Fixpoint.Misc @@ -88,6 +88,8 @@ import Text.Printf import GHC.Stack import qualified Language.Fixpoint.Types as F import System.IO.Unsafe (unsafePerformIO) +import qualified Language.Fixpoint.Union as Union +import Language.Fixpoint.Union (UF) --import Debug.Trace as Debug @@ -273,11 +275,15 @@ elabExpr msg env e = case elabExprE msg env e of elabExprE :: Located String -> SymEnv -> Expr -> Either Error Expr elabExprE msg env e = - case runCM0 (srcSpan msg) (elab (env, envLookup) e) of + case runCM0 (srcSpan msg) $ do + (!e', _) <- elab (env, envLookup) e + fufRef <- asks ufM + finalUF <- liftIO $ readIORef fufRef + return (applyExprUF finalUF e') of Left (ChError f') -> let e' = f' () in Left $ err (srcSpan e') (d (val e')) - Right s -> Right (fst s) + Right s -> Right s where sEnv = seSort env envLookup = (`lookupSEnvWithDistance` sEnv) @@ -371,7 +377,7 @@ instance Show ChError where show (ChError f) = show (f ()) instance Exception ChError where -data ChState = ChS { chCount :: IORef Int, chSpan :: SrcSpan } +data ChState = ChS {chCount :: IORef Int, chSpan :: SrcSpan, ufM :: IORef Union.UF} type Env = Symbol -> SESearch Sort type ElabEnv = (SymEnv, Env) @@ -406,7 +412,8 @@ varCounterRef = unsafePerformIO $ newIORef 42 -- value of counter. runCM0 :: SrcSpan -> CheckM a -> Either ChError a runCM0 sp act = unsafePerformIO $ do - try (runReaderT act (ChS varCounterRef sp)) + ufR <- newIORef Union.new + try (runReaderT act (ChS varCounterRef sp ufR)) fresh :: CheckM Int fresh = do @@ -524,143 +531,151 @@ addEnv f bs x Just s -> Found s Nothing -> f x +-------------------------------------------------------------------------------- -------------------------------------------------------------------------------- -- | Elaborate expressions with types to make polymorphic instantiation explicit. -------------------------------------------------------------------------------- {-# SCC elab #-} elab :: ElabEnv -> Expr -> CheckM (Expr, Sort) -------------------------------------------------------------------------------- -elab f@(_, g) e@(EBin o e1 e2) = do - (e1', s1) <- elab f e1 - (e2', s2) <- elab f e2 - s <- checkOpTy g e s1 s2 - return (EBin o (eCst e1' s1) (eCst e2' s2), s) - -elab f (EApp e1@(EApp _ _) e2) = do - (e1', _, e2', s2, s) <- notracepp "ELAB-EAPP" <$> elabEApp f e1 e2 - let e = eAppC s e1' (eCst e2' s2) - let θ = unifyExpr (snd f) e - return (applyExpr θ e, maybe s (`apply` s) θ) - -elab f (EApp e1 e2) = do - (e1', s1, e2', s2, s) <- elabEApp f e1 e2 - let e = eAppC s (eCst e1' s1) (eCst e2' s2) - let θ = unifyExpr (snd f) e - return (applyExpr θ e, maybe s (`apply` s) θ) - -elab _ e@(ESym _) = +elab f@(!_, !g) e@(EBin !o !e1 !e2) = do + (!e1', !s1) <- elab f e1 + (!e2', !s2) <- elab f e2 + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyUF g uf (Just e) s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + let !result = EBin o (eCst e1' s1) (eCst e2' s2) + return (result, s2) + + +elab !f (EApp !e1 !e2) = do + (!e1', !s1, !e2', !s2, !s) <- elabEApp f e1 e2 + let !e = eAppC s (eCst e1' s1) (eCst e2' s2) + return (e, s) + + +elab !_ e@(ESym _) = return (e, strSort) -elab _ e@(ECon (I _)) = +elab !_ e@(ECon (I _)) = return (e, FInt) -elab _ e@(ECon (R _)) = +elab !_ e@(ECon (R _)) = return (e, FReal) -elab _ e@(ECon (L _ s)) = +elab !_ e@(ECon (L _ !s)) = return (e, s) -elab _ e@(PKVar _ _) = +elab !_ e@(PKVar _ _) = return (e, boolSort) -elab f (PGrad k su i e) = - (, boolSort) . PGrad k su i . fst <$> elab f e -elab (_, f) e@(EVar x) = do - cs <- checkSym f x - pure (e, cs) +elab !f (PGrad !k !su !i !e) = do + (!e', !_) <- elab f e + return (PGrad k su i e', boolSort) -elab f (ENeg e) = do - (e', s) <- elab f e +elab (!_, !f) e@(EVar !x) = do + !cs <- checkSym f x + return (e, cs) + +elab !f (ENeg !e) = do + (!e', !s) <- elab f e return (ENeg e', s) -elab f@(_,g) (ECst (EIte p e1 e2) t) = do - (p', _) <- elab f p - (e1', s1) <- elab f (eCst e1 t) - (e2', s2) <- elab f (eCst e2 t) - s <- checkIteTy g p e1' e2' s1 s2 - return (EIte p' (eCst e1' s) (eCst e2' s), t) - -elab f@(_,g) (EIte p e1 e2) = do - t <- getIte g e1 e2 - (p', _) <- elab f p - (e1', s1) <- elab f (eCst e1 t) - (e2', s2) <- elab f (eCst e2 t) - s <- checkIteTy g p e1' e2' s1 s2 - return (EIte p' (eCst e1' s) (eCst e2' s), s) - -elab f (ECst e t) = do - (e', _) <- elab f e +elab f@(!_,!g) (ECst (EIte !p !e1 !e2) !t) = do + ufRef <- asks ufM + (!p', !_) <- elab f p + (!e1', !s1) <- elab f (eCst e1 t) + (!e2', !s2) <- elab f (eCst e2 t) + uf <- liftIO $ readIORef ufRef + (!_, !uf') <- checkIteTyUF g uf p e1' e2' s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return (EIte p' (eCst e1' s1) (eCst e2' s2), t) + +elab f@(!_,!g) (EIte !p !e1 !e2) = do + !t <- getIteUF g e1 e2 + (!p', !_) <- elab f p + (!e1', !s1) <- elab f (eCst e1 t) + (!e2', !s2) <- elab f (eCst e2 t) + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + (_, uf''') <- checkIteTyUF g uf p e1' e2' s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf''', ()) + return (EIte p' (eCst e1' s1) (eCst e2' s2), s2) + + +elab !f (ECst !e !t) = do + (!e', !_) <- elab f e return (eCst e' t, t) -elab f (PNot p) = do - (e', _) <- elab f p +elab !f (PNot !p) = do + (!e', !_) <- elab f p return (PNot e', boolSort) -elab f (PImp p1 p2) = do - (p1', _) <- elab f p1 - (p2', _) <- elab f p2 +elab !f (PImp !p1 !p2) = do + (!p1', !_) <- elab f p1 + (!p2', !_) <- elab f p2 return (PImp p1' p2', boolSort) -elab f (PIff p1 p2) = do - (p1', _) <- elab f p1 - (p2', _) <- elab f p2 +elab !f (PIff !p1 !p2) = do + (!p1', !_) <- elab f p1 + (!p2', !_) <- elab f p2 return (PIff p1' p2', boolSort) -elab f (PAnd ps) = do - ps' <- mapM (elab f) ps +elab !f (PAnd !ps) = do + !ps' <- mapM (elab f) ps return (PAnd (fst <$> ps'), boolSort) -elab f (POr ps) = do - ps' <- mapM (elab f) ps +elab !f (POr !ps) = do + !ps' <- mapM (elab f) ps return (POr (fst <$> ps'), boolSort) -elab f@(_,g) e@(PAtom eq e1 e2) | eq == Eq || eq == Ne = do - t1 <- checkExpr g e1 - t2 <- checkExpr g e2 - (t1',t2') <- unite g e t1 t2 `withError` errElabExpr e - e1' <- elabAs f t1' e1 - e2' <- elabAs f t2' e2 - e1'' <- eCstAtom f e1' t1' - e2'' <- eCstAtom f e2' t2' - return (PAtom eq e1'' e2'' , boolSort) - -elab f (PAtom r e1 e2) +elab f@(!_,!g) e@(PAtom !eq !e1 !e2) | eq == Eq || eq == Ne = do + !t1 <- checkExpr g e1 + !t2 <- checkExpr g e2 + (!t1',!t2') <- unite g e t1 t2 `withError` errElabExpr e + !e1' <- elabAs f t1' e1 + !e2' <- elabAs f t2' e2 + !e1'' <- eCstAtom f e1' t1' + !e2'' <- eCstAtom f e2' t2' + return (PAtom eq e1'' e2'', boolSort) + +elab !f (PAtom !r !e1 !e2) | r == Ueq || r == Une = do - (e1', _) <- elab f e1 - (e2', _) <- elab f e2 + (!e1', !_) <- elab f e1 + (!e2', !_) <- elab f e2 return (PAtom r e1' e2', boolSort) -elab f@(env,_) (PAtom r e1 e2) = do - e1' <- uncurry (toInt env) <$> elab f e1 - e2' <- uncurry (toInt env) <$> elab f e2 +elab f@(!env,!_) (PAtom !r !e1 !e2) = do + !e1' <- uncurry (toInt env) <$> elab f e1 + !e2' <- uncurry (toInt env) <$> elab f e2 return (PAtom r e1' e2', boolSort) -elab f (PExist bs e) = do - (e', s) <- elab (elabAddEnv f bs) e - let bs' = elaborate "PExist Args" mempty bs +elab !f (PExist !bs !e) = do + (!e', !s) <- elab (elabAddEnv f bs) e + let !bs' = elaborate "PExist Args" mempty bs return (PExist bs' e', s) -elab f (PAll bs e) = do - (e', s) <- elab (elabAddEnv f bs) e - let bs' = elaborate "PAll Args" mempty bs +elab !f (PAll !bs !e) = do + (!e', !s) <- elab (elabAddEnv f bs) e + let !bs' = elaborate "PAll Args" mempty bs return (PAll bs' e', s) -elab f (ELam (x,t) e) = do - (e', s) <- elab (elabAddEnv f [(x, t)]) e - let t' = elaborate "ELam Arg" mempty t +elab !f (ELam (!x,!t) !e) = do + (!e', !s) <- elab (elabAddEnv f [(x, t)]) e + let !t' = elaborate "ELam Arg" mempty t return (ELam (x, t') (eCst e' s), FFunc t s) -elab f (ECoerc s t e) = do - (e', _) <- elab f e - return (ECoerc s t e', t) +elab !f (ECoerc !s !t !e) = do + (!e', !_) <- elab f e + return (ECoerc s t e', t) -elab _ (ETApp _ _) = +elab !_ (ETApp _ _) = error "SortCheck.elab: TODO: implement ETApp" -elab _ (ETAbs _ _) = +elab !_ (ETAbs _ _) = error "SortCheck.elab: TODO: implement ETAbs" - -- | 'eCstAtom' is to support tests like `tests/pos/undef00.fq` eCstAtom :: ElabEnv -> Expr -> Sort -> CheckM Expr eCstAtom f@(sym,g) (EVar x) t @@ -688,14 +703,15 @@ elabAs f t e = notracepp _msg <$> go e -- DUPLICATION with `checkApp'` elabAppAs :: ElabEnv -> Sort -> Expr -> Expr -> CheckM Expr elabAppAs env@(_, f) t g e = do - gT <- checkExpr f g - eT <- checkExpr f e - (iT, oT, isu) <- checkFunSort gT + tg <- checkExpr f g + te <- checkExpr f e + (iT, oT) <- checkFunSort tg let ge = Just (EApp g e) - su <- unifyMany f ge isu [oT, iT] [t, eT] - let tg = apply su gT + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyManyUF f ge uf [oT, iT] [t, te] + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) g' <- elabAs env tg g - let te = apply su eT e' <- elabAs env te e pure $ EApp (ECst g' tg) (ECst e' te) @@ -709,9 +725,12 @@ elabEApp f@(_, g) e1 e2 = do elabAppSort :: Env -> Expr -> Expr -> Sort -> Sort -> CheckM (Expr, Expr, Sort, Sort, Sort) elabAppSort f e1 e2 s1 s2 = do let e = Just (EApp e1 e2) - (sIn, sOut, su) <- checkFunSort s1 - su' <- unify1 f e su sIn s2 - return (applyExpr (Just su') e1 , applyExpr (Just su') e2, apply su' s1, apply su' s2, apply su' sOut) + (sIn, sOut) <- checkFunSort s1 + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unify1UF f e uf sIn s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return (e1 , e2, s1, s2, sOut) -------------------------------------------------------------------------------- @@ -906,12 +925,12 @@ which, I imagine is what happens _somewhere_ inside GHC too? -} -------------------------------------------------------------------------------- -applySorts :: Vis.Visitable t => t -> [Sort] +applySorts :: Vis.Foldable t => t -> [Sort] -------------------------------------------------------------------------------- applySorts = {- notracepp "applySorts" . -} (defs ++) . Vis.fold vis () [] where defs = [FFunc t1 t2 | t1 <- basicSorts, t2 <- basicSorts] - vis = (Vis.defaultVisitor :: Vis.Visitor [KVar] t) { Vis.accExpr = go } + vis = (Vis.defaultFolder :: Vis.Folder [KVar] t) { Vis.accExpr = go } go _ (EApp (ECst (EVar f) t) _) -- get types needed for [NOTE:apply-monomorphism] | f == applyName = [t] @@ -934,7 +953,9 @@ exprSortMaybe = go go (ELam (_, sx) e) = FFunc sx <$> go e go (EApp e ex) | Just (FFunc sx s) <- genSort <$> go e - = maybe s (`apply` s) . (`unifySorts` sx) <$> go ex + = + let !_ = unsafePerformIO $ print ("About to apply " ++ show sx) in + maybe s (`apply` s) . (`unifySorts` sx) <$> go ex go _ = Nothing genSort :: Sort -> Sort @@ -943,8 +964,11 @@ genSort t = t unite :: Env -> Expr -> Sort -> Sort -> CheckM (Sort, Sort) unite f e t1 t2 = do - su <- unifys f (Just e) [t1] [t2] - return (apply su t1, apply su t2) + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t1] [t2] + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return (t1, t2) throwErrorAt :: String -> CheckM a throwErrorAt ~err' = do -- Lazy pattern needed because we use LANGUAGE Strict in this module @@ -970,7 +994,7 @@ refreshNegativeTyVars s = do let negativeSorts = negSort s freshVars <- mapM pair $ S.toList negativeSorts pure $ foldr (uncurry subst) s freshVars - where + where pair i = do f <- fresh pure (i, FVar f) @@ -987,18 +1011,27 @@ checkIte f p e1 e2 = do checkPred f p t1 <- checkExpr f e1 t2 <- checkExpr f e2 - checkIteTy f p e1 e2 t1 t2 - -getIte :: Env -> Expr -> Expr -> CheckM Sort -getIte f e1 e2 = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + (s, uf') <- checkIteTyUF f uf p e1 e2 t1 t2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return s + +getIteUF :: Env -> Expr -> Expr -> CheckM Sort +getIteUF f e1 e2 = do t1 <- checkExpr f e1 t2 <- checkExpr f e2 - (`apply` t1) <$> unifys f Nothing [t1] [t2] - -checkIteTy :: Env -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM Sort -checkIteTy f p e1 e2 t1 t2 = - ((`apply` t1) <$> unifys f e' [t1] [t2]) `withError` errIte e1 e2 t1 t2 - where + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyUF f uf Nothing t1 t2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return t1 + +checkIteTyUF :: Env -> UF -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM (Sort, UF) +checkIteTyUF f uf p e1 e2 t1 t2 = do + uf' <- unifysUF f e' uf [t1] [t2] `withError` errIte e1 e2 t1 t2 + return (t1, uf') + where e' = Just (EIte p e1 e2) -- | Helper for checking cast expressions @@ -1006,21 +1039,29 @@ checkCst :: Env -> Sort -> Expr -> CheckM Sort checkCst f t (EApp g e) = checkApp f (Just t) g e checkCst f t e - = do t' <- checkExpr f e - su <- unifys f (Just e) [t] [t'] `withError` errCast e t' t - pure (apply su t) + = do + t' <- checkExpr f e + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t] [t'] `withError` errCast e t' t + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + pure t checkApp :: Env -> Maybe Sort -> Expr -> Expr -> CheckM Sort checkApp f to g es - = snd <$> checkApp' f to g es + = checkApp' f to g es checkExprAs :: Env -> Sort -> Expr -> CheckM Sort checkExprAs f t (EApp g e) = checkApp f (Just t) g e checkExprAs f t e - = do t' <- checkExpr f e - θ <- unifys f (Just e) [t'] [t] - pure $ apply θ t + = do + t' <- checkExpr f e + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t'] [t] + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return t -- | Helper for checking uninterpreted function applications -- | Checking function application should be curried, e.g. @@ -1028,20 +1069,23 @@ checkExprAs f t e -- RJ: The above comment makes no sense to me :( -- DUPLICATION with 'elabAppAs' -checkApp' :: Env -> Maybe Sort -> Expr -> Expr -> CheckM (TVSubst, Sort) +checkApp' :: Env -> Maybe Sort -> Expr -> Expr -> CheckM Sort checkApp' f to g e = do gt <- checkExpr f g et <- checkExpr f e - (it, ot, isu) <- checkFunSort gt + (it, ot) <- checkFunSort gt let ge = Just (EApp g e) - su <- unifyMany f ge isu [it] [et] - let t = apply su ot + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyManyUF f ge uf [it] [et] + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) case to of - Nothing -> return (su, t) - Just t' -> do θ' <- unifyMany f ge su [t] [t'] - let ti = apply θ' et - _ <- checkExprAs f ti e - return (θ', apply θ' t) + Nothing -> return ot + Just t' -> do + uf'' <- unifyManyUF f ge uf' [ot] [t'] + liftIO $ atomicModifyIORef' ufRef $ const (uf'', ()) + _ <- checkExprAs f et e + return ot -- | Helper for checking binary (numeric) operations @@ -1056,7 +1100,6 @@ checkOp f e1 o e2 t2 <- checkExpr f e2 checkOpTy f (EBin o e1 e2) t1 t2 - checkOpTy :: Env -> Expr -> Sort -> Sort -> CheckM Sort checkOpTy _ _ FInt FInt = return FInt @@ -1070,12 +1113,14 @@ checkOpTy _ _ FInt FReal checkOpTy _ _ FReal FInt = return FReal -checkOpTy f e t t' - | Just s <- unify f (Just e) t t' - = checkNumeric f (apply s t) >> return (apply s t) - -checkOpTy _ e t t' - = throwErrorAt (errOp e t t') +checkOpTy f e t t' = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyUF f uf (Just e) t t' + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + -- probably need to unify this with a numeric sort for this to work properly + return t + -- checkNumeric f t >> return t checkFractional :: Env -> Sort -> CheckM () checkFractional f s@(FObj l) @@ -1088,6 +1133,11 @@ checkNumeric :: Env -> Sort -> CheckM () checkNumeric f s@(FObj l) = do t <- checkSym f l unless (t `elem` [FNum, FFrac, intSort, FInt]) (throwErrorAt $ errNonNumeric s) +checkNumeric _ s@(FVar i) = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + unless (isNumeric (Union.find uf i)) (throwErrorAt $ errNonNumeric s) + checkNumeric _ s = unless (isNumeric s) (throwErrorAt $ errNonNumeric s) @@ -1100,6 +1150,15 @@ checkEqConstr f e θ a t = Found tA -> unify1 f e θ tA t _ -> throwErrorAt $ errUnifyMsg (Just "ceq2") e (FObj a) t +checkEqConstrUF :: Env -> Maybe Expr -> UF -> Symbol -> Sort -> CheckM UF +checkEqConstrUF _ _ uf a (FObj b) + | a == b + = return uf +checkEqConstrUF f e uf a t = + case f a of + Found tA -> unify1UF f e uf tA t + _ -> throwErrorAt $ errUnifyMsg (Just "ceq2") e (FObj a) t + -------------------------------------------------------------------------------- -- | Checking Predicates ------------------------------------------------------- -------------------------------------------------------------------------------- @@ -1107,6 +1166,12 @@ checkPred :: Env -> Expr -> CheckM () checkPred f e = checkExpr f e >>= checkBoolSort e checkBoolSort :: Expr -> Sort -> CheckM () + +checkBoolSort e (FVar i) = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + let s = Union.find uf i + if s == boolSort then return () else throwErrorAt (errBoolSort e s) checkBoolSort e s | s == boolSort = return () | otherwise = throwErrorAt (errBoolSort e s) @@ -1116,9 +1181,12 @@ checkRel :: HasCallStack => Env -> Brel -> Expr -> Expr -> CheckM () checkRel f Eq e1 e2 = do t1 <- checkExpr f e1 t2 <- checkExpr f e2 - su <- unifys f (Just e) [t1] [t2] `withError` errRel e t1 t2 - _ <- checkExprAs f (apply su t1) e1 - _ <- checkExprAs f (apply su t2) e2 + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t1] [t2] `withError` errRel e t1 t2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + _ <- checkExprAs f t1 e1 + _ <- checkExprAs f t2 e2 checkRelTy f e Eq t1 t2 where e = PAtom Eq e1 e2 @@ -1137,13 +1205,16 @@ checkRelTy f _ _ s1@(FObj l) s2@(FObj l') | l /= l' checkRelTy _ _ _ FReal FReal = return () checkRelTy _ _ _ FInt FReal = return () checkRelTy _ _ _ FReal FInt = return () -checkRelTy f _ _ FInt s2 = checkNumeric f s2 `withError` errNonNumeric s2 +checkRelTy f _ _ FInt s2 = checkNumeric f s2 `withError` errNonNumeric s2 checkRelTy f _ _ s1 FInt = checkNumeric f s1 `withError` errNonNumeric s1 checkRelTy f _ _ FReal s2 = checkFractional f s2 `withError` errNonFractional s2 checkRelTy f _ _ s1 FReal = checkFractional f s1 `withError` errNonFractional s1 -checkRelTy f e Eq t1 t2 = void (unifys f (Just e) [t1] [t2] `withError` errRel e t1 t2) -checkRelTy f e Ne t1 t2 = void (unifys f (Just e) [t1] [t2] `withError` errRel e t1 t2) -checkRelTy _ e _ t1 t2 = unless (t1 == t2) (throwErrorAt $ errRel e t1 t2) +checkRelTy f e _ t1 t2 = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t1] [t2] `withError` errRel e t1 t2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return () checkURel :: Expr -> Sort -> Sort -> CheckM () checkURel e s1 s2 = unless (b1 == b2) (throwErrorAt $ errRel e s1 s2) @@ -1151,30 +1222,87 @@ checkURel e s1 s2 = unless (b1 == b2) (throwErrorAt $ errRel e s1 s2) b1 = s1 == boolSort b2 = s2 == boolSort --------------------------------------------------------------------------------- +------------------------------------------------------------------- -- | Sort Unification on Expressions -------------------------------------------------------------------------------- +unifyUF :: Env -> UF -> Maybe Expr -> Sort -> Sort -> CheckM UF +-------------------------------------------------------------------------------- +unifyUF f uf e t1 t2 + = unify1UF f e uf t1 t2 + +-- -------------------------------------------------------------------------------- +-- unifyTo1UF :: Env -> UF -> [Sort] -> CheckM UF +-- -------------------------------------------------------------------------------- +-- unifyTo1UF f uf ts +-- = unifyTo1MUF f uf ts + + +-- -------------------------------------------------------------------------------- +-- unifyTo1MUF :: Env -> UF -> [Sort] -> CheckM UF +-- -------------------------------------------------------------------------------- +-- unifyTo1MUF _ _ [] = panic "unifyTo1: empty list" +-- unifyTo1MUF f uf (t0:ts) = fst <$> foldM step (uf, t0) ts +-- where +-- step :: (UF, Sort) -> Sort -> CheckM (UF, Sort) +-- step (ufm, t) t' = do +-- ufm' <- unify1UF f Nothing ufm t t' +-- return (ufm', t) -{-# SCC unifyExpr #-} -unifyExpr :: Env -> Expr -> Maybe TVSubst -unifyExpr f (EApp e1 e2) = Just $ mconcat $ catMaybes [θ1, θ2, θ] - where - θ1 = unifyExpr f e1 - θ2 = unifyExpr f e2 - θ = unifyExprApp f e1 e2 -unifyExpr f (ECst e _) - = unifyExpr f e -unifyExpr _ _ - = Nothing - -unifyExprApp :: Env -> Expr -> Expr -> Maybe TVSubst -unifyExprApp f e1 e2 = do - t1 <- getArg $ exprSortMaybe e1 - t2 <- exprSortMaybe e2 - unify f (Just $ EApp e1 e2) t1 t2 - where - getArg (Just (FFunc t1 _)) = Just t1 - getArg _ = Nothing +-------------------------------------------------------------------------------- +unifysUF :: HasCallStack => Env -> Maybe Expr -> UF -> [Sort] -> [Sort] -> CheckM UF +-------------------------------------------------------------------------------- +unifysUF f e uf = unifyManyUF f e uf + +unifyManyUF :: HasCallStack => Env -> Maybe Expr -> UF -> [Sort] -> [Sort] -> CheckM UF +unifyManyUF f e uf ts ts' + | length ts == length ts' = foldM (uncurry . unify1UF f e) uf $ zip ts ts' + | otherwise = throwErrorAt (errUnifyMany ts ts') + +unify1UF :: Env -> Maybe Expr -> UF -> Sort -> Sort -> CheckM UF +unify1UF f e !uf (FVar !i) !t + = unifyVarUF f e uf i t +unify1UF f e !uf !t (FVar !i) + = unifyVarUF f e uf i t +unify1UF f e !uf (FApp !t1 !t2) (FApp !t1' !t2') + = unifyManyUF f e uf [t1, t2] [t1', t2'] +unify1UF _ _ !uf (FTC !l1) (FTC !l2) + | isListTC l1 && isListTC l2 + = return uf +unify1UF f e !uf t1@(FAbs _ _) !t2 = do + !t1' <- instantiate t1 + unifyManyUF f e uf [t1'] [t2] +unify1UF f e !uf !t1 t2@(FAbs _ _) = do + !t2' <- instantiate t2 + unifyManyUF f e uf [t1] [t2'] +unify1UF _ _ !uf !s1 !s2 + | isString s1, isString s2 + = return uf +unify1UF _ _ !uf FInt FReal = return uf + +unify1UF _ _ !uf FReal FInt = return uf + +unify1UF f e !uf !t FInt = do + checkNumeric f t `withError` errUnify e t FInt + return uf + +unify1UF f e !uf FInt !t = do + checkNumeric f t `withError` errUnify e FInt t + return uf + +unify1UF f e !uf (FFunc !t1 !t2) (FFunc !t1' !t2') = + unifyManyUF f e uf [t1, t2] [t1', t2'] + +unify1UF f e uf (FObj a) !t = + checkEqConstrUF f e uf a t + +unify1UF f e uf !t (FObj a) = + checkEqConstrUF f e uf a t + +unify1UF _ e uf !t1 !t2 + | t1 == t2 + = return uf + | otherwise + = throwErrorAt (errUnify e t1 t2) -------------------------------------------------------------------------------- @@ -1355,22 +1483,67 @@ unifyVar f e θ !i !t Just !t' -> if t == t' then return θ else unify1 f e θ t t' Nothing -> return (updateVar i t θ) +unifyVarUF :: Env -> Maybe Expr -> UF -> Int -> Sort -> CheckM UF +unifyVarUF _ _ uf !i (FVar !j) = + if i == j then return uf else return (Union.union uf j (FVar i)) + +unifyVarUF _ _ uf !i !t + = return (Union.union uf i t) + +-------------------------------------------------------------------------------- +-- | Update global subst to be applied to expressions +-------------------------------------------------------------------------------- + +-- updateTVSubst :: TVSubst -> CheckM () +-- updateTVSubst theta = do +-- refTheta <- asks chTVSubst +-- liftIO $ atomicModifyIORef' refTheta $ const (Just theta, ()) + +-- -- local (\s -> s {chTVSubst = theta}) (return ()) + +-- mergeTVSubst :: TVSubst -> Maybe TVSubst -> TVSubst +-- mergeTVSubst (Th m1) Nothing = Th m1 +-- mergeTVSubst (Th m1) (Just (Th m2)) = Th m1 <> Th m2 + +-- composeTVSubst :: Maybe TVSubst -> CheckM () +-- composeTVSubst Nothing = return () +-- composeTVSubst (Just theta1) = do +-- refTheta <- asks chTVSubst +-- theta <- liftIO $ readIORef refTheta +-- updateTVSubst (mergeTVSubst theta1 theta) + -------------------------------------------------------------------------------- -- | Applying a Type Substitution ---------------------------------------------- -------------------------------------------------------------------------------- apply :: TVSubst -> Sort -> Sort -------------------------------------------------------------------------------- -apply θ = Vis.mapSort f +apply !θ = Vis.mapSort f where - f t@(FVar i) = fromMaybe t (lookupVar i θ) - f t = t + f t@(FVar !i) = fromMaybe t (lookupVar i θ) + f !t = t + +-- applyExpr :: Maybe TVSubst -> Expr -> Expr +-- applyExpr Nothing e = e +-- applyExpr (Just θ) e = Vis.mapExprOnExpr f e +-- where +-- f (ECst !e' !s) = ECst e' (apply θ s) +-- f !e' = e' -applyExpr :: Maybe TVSubst -> Expr -> Expr -applyExpr Nothing e = e -applyExpr (Just θ) e = Vis.mapExprOnExpr f e +-------------------------------------------------------------------------------- +-- | Applying the result of union find +-------------------------------------------------------------------------------- +applyUF :: UF -> Sort -> Sort +applyUF uf = Vis.mapSort f where - f (ECst e' s) = ECst e' (apply θ s) - f e' = e' + f (FVar !i) = Union.find uf i + f !t = t + +{-# SCC applyExprUF #-} +applyExprUF :: UF -> Expr -> Expr +applyExprUF uf e = Vis.mapExprOnExpr f e + where + f (ECst !e' !s) = ECst e' (applyUF uf s) + f !e' = e' -------------------------------------------------------------------------------- _applyCoercion :: Symbol -> Sort -> Sort -> Sort @@ -1385,12 +1558,15 @@ _applyCoercion a t = Vis.mapSort f -------------------------------------------------------------------------------- -- | Deconstruct a function-sort ----------------------------------------------- -------------------------------------------------------------------------------- -checkFunSort :: Sort -> CheckM (Sort, Sort, TVSubst) +checkFunSort :: Sort -> CheckM (Sort, Sort) checkFunSort (FAbs _ t) = checkFunSort t -checkFunSort (FFunc t1 t2) = return (t1, t2, emptySubst) -checkFunSort (FVar i) = do j <- fresh - k <- fresh - return (FVar j, FVar k, updateVar i (FFunc (FVar j) (FVar k)) emptySubst) +checkFunSort (FFunc t1 t2) = return (t1, t2) +checkFunSort (FVar i) = do + k <- fresh + j <- fresh + ufRef <- asks ufM + _ <- liftIO $ atomicModifyIORef' ufRef $ \uf -> (Union.union uf i (FFunc (FVar j) (FVar k)), ()) + return (FVar j, FVar k) checkFunSort t = throwErrorAt (errNonFunction 1 t) -------------------------------------------------------------------------------- @@ -1446,12 +1622,12 @@ errRel e t1 t2 = traced $ printf "Invalid Relation %s with operand types %s and %s" (showpp e) (showpp t1) (showpp t2) -errOp :: Expr -> Sort -> Sort -> String -errOp e t t' - | t == t' = printf "Operands have non-numeric types %s in %s" - (showpp t) (showpp e) - | otherwise = printf "Operands have different types %s and %s in %s" - (showpp t) (showpp t') (showpp e) +-- errOp :: Expr -> Sort -> Sort -> String +-- errOp e t t' +-- | t == t' = printf "Operands have non-numeric types %s in %s" +-- (showpp t) (showpp e) +-- | otherwise = printf "Operands have different types %s and %s in %s" +-- (showpp t) (showpp t') (showpp e) errIte :: Expr -> Expr -> Sort -> Sort -> String errIte e1 e2 t1 t2 = printf "Mismatched branches in Ite: then %s : %s, else %s : %s" @@ -1479,3 +1655,4 @@ errNonFractional l = printf "The sort %s is not fractional" (showpp l) errBoolSort :: Expr -> Sort -> String errBoolSort e s = printf "Expressions %s should have bool sort, but has %s" (showpp e) (showpp s) + diff --git a/src/Language/Fixpoint/Types/Visitor.hs b/src/Language/Fixpoint/Types/Visitor.hs index 163518e0f..c2f9d6f21 100644 --- a/src/Language/Fixpoint/Types/Visitor.hs +++ b/src/Language/Fixpoint/Types/Visitor.hs @@ -6,17 +6,18 @@ {-# LANGUAGE BangPatterns #-} {-# OPTIONS_GHC -Wno-name-shadowing #-} +{-# LANGUAGE InstanceSigs #-} module Language.Fixpoint.Types.Visitor ( -- * Visitor - Visitor (..) - , Visitable (..) + Folder (..) + , Foldable (..) -- * Extracting Symbolic Constants (String Literals) , SymConsts (..) -- * Default Visitor - , defaultVisitor + , defaultFolder -- * Transformers , trans @@ -57,11 +58,15 @@ import qualified Data.HashMap.Strict as M import qualified Data.List as L import Language.Fixpoint.Types hiding (mapSort) import qualified Language.Fixpoint.Misc as Misc +import Control.Monad.Reader +import GHC.IO (unsafePerformIO) +import Data.IORef (newIORef, readIORef, IORef, modifyIORef') +import Prelude hiding (Foldable) -data Visitor acc ctx = Visitor { +data Folder acc ctx = Visitor { -- | Context @ctx@ is built in a "top-down" fashion; not "across" siblings ctxExpr :: ctx -> Expr -> ctx @@ -73,9 +78,9 @@ data Visitor acc ctx = Visitor { } --------------------------------------------------------------------------------- -defaultVisitor :: (Monoid acc) => Visitor acc ctx +defaultFolder :: (Monoid acc) => Folder acc ctx --------------------------------------------------------------------------------- -defaultVisitor = Visitor +defaultFolder = Visitor { ctxExpr = const , txExpr = \_ x -> x , accExpr = \_ _ -> mempty @@ -83,81 +88,164 @@ defaultVisitor = Visitor ------------------------------------------------------------------------ -fold :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> a -fold v c a t = snd $ execVisitM v c a visit t +fold :: (Foldable t, Monoid a) => Folder a ctx -> ctx -> a -> t -> a +fold v c a t = snd $ execVisitM v c a foldE t -trans :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> t -trans v c _ z = fst $ execVisitM v c mempty visit z - -execVisitM :: Visitor a ctx -> ctx -> a -> (Visitor a ctx -> ctx -> t -> State a t) -> t -> (t, a) -execVisitM v c a f x = runState (f v c x) a - -type VisitM acc = State acc - -accum :: (Monoid a) => a -> VisitM a () -accum !z = modify (mappend z) - -- do - -- !cur <- get - -- put ((mappend $!! z) $!! cur) +-- trans is always passed () () for a and t so we don't need to use the visitor pattern +-- trans :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> t +-- trans !v !c !_ !z = fst $ execVisitM v c mempty visit z class Visitable t where - visit :: (Monoid a) => Visitor a c -> c -> t -> VisitM a t + transE :: (Expr -> Expr) -> t -> t + +trans :: Visitable t => (Expr -> Expr) -> t -> t +trans f t = transE f t instance Visitable Expr where - visit = visitExpr + transE f = vE + where + vE e = step e' where e' = f e + step e@(ESym _) = e + step e@(ECon _) = e + step e@(EVar _) = e + step (EApp e1 e2) = EApp (vE e1) (vE e2) + step (ENeg e) = ENeg (vE e) + step (EBin o e1 e2) = EBin o (vE e1) (vE e2) + step (EIte p e1 e2) = EIte (vE p) (vE e1) (vE e2) + step (ECst e t) = ECst (vE e) t + step (PAnd ps) = PAnd (map vE ps) + step (POr ps) = POr (map vE ps) + step (PNot p) = PNot (vE p) + step (PImp p1 p2) = PImp (vE p1) (vE p2) + step (PIff p1 p2) = PIff (vE p1) (vE p2) + step (PAtom r e1 e2) = PAtom r (vE e1) (vE e2) + step (PAll xts p) = PAll xts (vE p) + step (ELam (x,t) e) = ELam (x,t) (vE e) + step (ECoerc a t e) = ECoerc a t (vE e) + step (PExist xts p) = PExist xts (vE p) + step (ETApp e s) = ETApp (vE e) s + step (ETAbs e s) = ETAbs (vE e) s + step p@(PKVar _ _) = p + step (PGrad k su i e) = PGrad k su i (vE e) instance Visitable Reft where - visit v c (Reft (x, ra)) = Reft . (x, ) <$> visit v c ra + transE v (Reft (x, ra)) = Reft (x, transE v ra) instance Visitable SortedReft where - visit v c (RR t r) = RR t <$> visit v c r + transE v (RR t r) = RR t (transE v r) instance Visitable (Symbol, SortedReft, a) where - visit v c (sym, sr, a) = (sym, ,a) <$> visit v c sr + transE f (sym, sr, a) = (sym, transE f sr, a) instance Visitable (BindEnv a) where - visit v c = mapM (visit v c) + transE v be = be { beBinds = M.map (transE v) (beBinds be) } + +instance (Visitable (c a)) => Visitable (GInfo c a) where + transE f x = x { + cm = transE f <$> cm x + , bs = transE f (bs x) + , ae = transE f (ae x) + } + +instance Visitable (SimpC a) where + transE v x = x { + _crhs = transE v (_crhs x) + } + +instance Visitable (SubC a) where + transE v x = x { + slhs = transE v (slhs x), + srhs = transE v (srhs x) + } + +instance Visitable AxiomEnv where + transE v x = x { + aenvEqs = transE v <$> aenvEqs x, + aenvSimpl = transE v <$> aenvSimpl x + } + +instance Visitable Equation where + transE v eq = eq { + eqBody = transE v (eqBody eq) + } + +instance Visitable Rewrite where + transE v rw = rw { + smBody = transE v (smBody rw) + } + +execVisitM :: Folder a ctx -> ctx -> a -> (Folder a ctx -> ctx -> t -> FoldM a t) -> t -> (t, a) +execVisitM !v !c !a !f !x = unsafePerformIO $ do + rn <- newIORef a + result <- runReaderT (f v c x) rn + finalAcc <- readIORef rn + return (result, finalAcc) + +type FoldM acc = ReaderT (IORef acc) IO + +accum :: (Monoid a) => a -> FoldM a () +accum !z = do + ref <- ask + liftIO $ modifyIORef' ref (mappend z) + +class Foldable t where + foldE :: (Monoid a) => Folder a c -> c -> t -> FoldM a t + +instance Foldable Expr where + foldE = foldExpr + +instance Foldable Reft where + foldE v c (Reft (x, ra)) = Reft . (x, ) <$> foldE v c ra + +instance Foldable SortedReft where + foldE v c (RR t r) = RR t <$> foldE v c r + +instance Foldable (Symbol, SortedReft, a) where + foldE v c (sym, sr, a) = (sym, ,a) <$> foldE v c sr + +instance Foldable (BindEnv a) where + foldE v c = mapM (foldE v c) --------------------------------------------------------------------------------- -- WARNING: these instances were written for mapKVars over GInfos only; -- check that they behave as expected before using with other clients. -instance Visitable (SimpC a) where - visit v c x = do - rhs' <- visit v c (_crhs x) +instance Foldable (SimpC a) where + foldE v c x = do + rhs' <- foldE v c (_crhs x) return x { _crhs = rhs' } -instance Visitable (SubC a) where - visit v c x = do - lhs' <- visit v c (slhs x) - rhs' <- visit v c (srhs x) +instance Foldable (SubC a) where + foldE v c x = do + lhs' <- foldE v c (slhs x) + rhs' <- foldE v c (srhs x) return x { slhs = lhs', srhs = rhs' } -instance (Visitable (c a)) => Visitable (GInfo c a) where - visit v c x = do - cm' <- mapM (visit v c) (cm x) - bs' <- visit v c (bs x) - ae' <- visit v c (ae x) +instance (Foldable (c a)) => Foldable (GInfo c a) where + foldE v c x = do + cm' <- mapM (foldE v c) (cm x) + bs' <- foldE v c (bs x) + ae' <- foldE v c (ae x) return x { cm = cm', bs = bs', ae = ae' } -instance Visitable AxiomEnv where - visit v c x = do - eqs' <- mapM (visit v c) (aenvEqs x) - simpls' <- mapM (visit v c) (aenvSimpl x) +instance Foldable AxiomEnv where + foldE v c x = do + eqs' <- mapM (foldE v c) (aenvEqs x) + simpls' <- mapM (foldE v c) (aenvSimpl x) return x { aenvEqs = eqs' , aenvSimpl = simpls'} -instance Visitable Equation where - visit v c eq = do - body' <- visit v c (eqBody eq) +instance Foldable Equation where + foldE v c eq = do + body' <- foldE v c (eqBody eq) return eq { eqBody = body' } -instance Visitable Rewrite where - visit v c rw = do - body' <- visit v c (smBody rw) +instance Foldable Rewrite where + foldE v c rw = do + body' <- foldE v c (smBody rw) return rw { smBody = body' } --------------------------------------------------------------------------------- -visitExpr :: (Monoid a) => Visitor a ctx -> ctx -> Expr -> VisitM a Expr -visitExpr !v = vE +foldExpr :: (Monoid a) => Folder a ctx -> ctx -> Expr -> FoldM a Expr +foldExpr !v = vE where vE !c !e = do {- SCC "visitExpr.vE.accum" -} accum acc {- SCC "visitExpr.vE.step" -} step c' e' @@ -193,56 +281,124 @@ mapKVars f = mapKVars' f' f' (kv', _) = f kv' mapKVars' :: Visitable t => ((KVar, Subst) -> Maybe Expr) -> t -> t -mapKVars' f = trans kvVis () () +mapKVars' f = trans txK where - kvVis = defaultVisitor { txExpr = txK } - txK _ (PKVar k su) + txK (PKVar k su) | Just p' <- f (k, su) = subst su p' - txK _ (PGrad k su _ _) + txK (PGrad k su _ _) | Just p' <- f (k, su) = subst su p' - txK _ p = p + txK p = p mapGVars' :: Visitable t => ((KVar, Subst) -> Maybe Expr) -> t -> t -mapGVars' f = trans kvVis () () +mapGVars' f = trans txK where - kvVis = defaultVisitor { txExpr = txK } - txK _ (PGrad k su _ _) + txK (PGrad k su _ _) | Just p' <- f (k, su) = subst su p' - txK _ p = p + txK p = p mapExpr :: Visitable t => (Expr -> Expr) -> t -> t -mapExpr f = trans (defaultVisitor {txExpr = const f}) () () +mapExpr f = trans f -- | Specialized and faster version of mapExpr for expressions mapExprOnExpr :: (Expr -> Expr) -> Expr -> Expr mapExprOnExpr f = go where - go e0 = f $ case e0 of - EApp f e -> EApp (go f) (go e) - ENeg e -> ENeg (go e) - EBin o e1 e2 -> EBin o (go e1) (go e2) - EIte p e1 e2 -> EIte (go p) (go e1) (go e2) - ECst e t -> ECst (go e) t - PAnd ps -> PAnd (map go ps) - POr ps -> POr (map go ps) - PNot p -> PNot (go p) - PImp p1 p2 -> PImp (go p1) (go p2) - PIff p1 p2 -> PIff (go p1) (go p2) - PAtom r e1 e2 -> PAtom r (go e1) (go e2) - PAll xts p -> PAll xts (go p) - ELam (x,t) e -> ELam (x,t) (go e) - ECoerc a t e -> ECoerc a t (go e) - PExist xts p -> PExist xts (go p) - ETApp e s -> ETApp (go e) s - ETAbs e s -> ETAbs (go e) s - PGrad k su i e -> PGrad k su i (go e) + go !e0 = f $! case e0 of + EApp f e -> + let !f' = go f + !e' = go e + in EApp f' e' + ENeg e -> + let !e' = go e + in ENeg e' + EBin o e1 e2 -> + let !e1' = go e1 + !e2' = go e2 + in EBin o e1' e2' + EIte p e1 e2 -> + let !p' = go p + !e1' = go e1 + !e2' = go e2 + in EIte p' e1' e2' + ECst e t -> + let !e' = go e + in ECst e' t + PAnd ps -> + let !ps' = map go ps + in PAnd ps' + POr ps -> + let !ps' = map go ps + in POr ps' + PNot p -> + let !p' = go p + in PNot p' + PImp p1 p2 -> + let !p1' = go p1 + !p2' = go p2 + in PImp p1' p2' + PIff p1 p2 -> + let !p1' = go p1 + !p2' = go p2 + in PIff p1' p2' + PAtom r e1 e2 -> + let !e1' = go e1 + !e2' = go e2 + in PAtom r e1' e2' + PAll xts p -> + let !p' = go p + in PAll xts p' + ELam (x,t) e -> + let !e' = go e + in ELam (x,t) e' + ECoerc a t e -> + let !e' = go e + in ECoerc a t e' + PExist xts p -> + let !p' = go p + in PExist xts p' + ETApp e s -> + let !e' = go e + in ETApp e' s + ETAbs e s -> + let !e' = go e + in ETAbs e' s + PGrad k su i e -> + let !e' = go e + in PGrad k su i e' e@PKVar{} -> e e@EVar{} -> e e@ESym{} -> e e@ECon{} -> e +-- mapExprOnExpr :: (Expr -> Expr) -> Expr -> Expr +-- mapExprOnExpr f = go +-- where +-- go !e0 = f $! case e0 of +-- EApp f e -> EApp !(go f) !(go e) +-- ENeg e -> ENeg (go e) +-- EBin o e1 e2 -> EBin o (go e1) (go e2) +-- EIte p e1 e2 -> EIte (go p) (go e1) (go e2) +-- ECst e t -> ECst (go e) t +-- PAnd ps -> PAnd (map go ps) +-- POr ps -> POr (map go ps) +-- PNot p -> PNot (go p) +-- PImp p1 p2 -> PImp (go p1) (go p2) +-- PIff p1 p2 -> PIff (go p1) (go p2) +-- PAtom r e1 e2 -> PAtom r (go e1) (go e2) +-- PAll xts p -> PAll xts (go p) +-- ELam (x,t) e -> ELam (x,t) (go e) +-- ECoerc a t e -> ECoerc a t (go e) +-- PExist xts p -> PExist xts (go p) +-- ETApp e s -> ETApp (go e) s +-- ETAbs e s -> ETAbs (go e) s +-- PGrad k su i e -> PGrad k su i (go e) +-- e@PKVar{} -> e +-- e@EVar{} -> e +-- e@ESym{} -> e +-- e@ECon{} -> e + mapMExpr :: (Monad m) => (Expr -> m Expr) -> Expr -> m Expr mapMExpr f = go @@ -271,12 +427,11 @@ mapMExpr f = go go (POr ps) = f . POr =<< (go `traverse` ps) mapKVarSubsts :: Visitable t => (KVar -> Subst -> Subst) -> t -> t -mapKVarSubsts f = trans kvVis () () +mapKVarSubsts f = trans txK where - kvVis = defaultVisitor { txExpr = txK } - txK _ (PKVar k su) = PKVar k (f k su) - txK _ (PGrad k su i e) = PGrad k (f k su) i e - txK _ p = p + txK (PKVar k su) = PKVar k (f k su) + txK (PGrad k su i e) = PGrad k (f k su) i e + txK p = p newtype MInt = MInt Integer -- deriving (Eq, NFData) @@ -285,27 +440,28 @@ instance Semigroup MInt where instance Monoid MInt where mempty = MInt 0 + mappend :: MInt -> MInt -> MInt mappend = (<>) -size :: Visitable t => t -> Integer +size :: Foldable t => t -> Integer size t = n where MInt n = fold szV () mempty t - szV = (defaultVisitor :: Visitor MInt t) { accExpr = \ _ _ -> MInt 1 } + szV = (defaultFolder :: Folder MInt t) { accExpr = \ _ _ -> MInt 1 } -lamSize :: Visitable t => t -> Integer +lamSize :: Foldable t => t -> Integer lamSize t = n where MInt n = fold szV () mempty t - szV = (defaultVisitor :: Visitor MInt t) { accExpr = accum } + szV = (defaultFolder :: Folder MInt t) { accExpr = accum } accum _ (ELam _ _) = MInt 1 accum _ _ = MInt 0 -eapps :: Visitable t => t -> [Expr] +eapps :: Foldable t => t -> [Expr] eapps = fold eappVis () [] where - eappVis = (defaultVisitor :: Visitor [KVar] t) { accExpr = eapp' } + eappVis = (defaultFolder :: Folder [KVar] t) { accExpr = eapp' } eapp' _ e@(EApp _ _) = [e] eapp' _ _ = [] @@ -511,9 +667,9 @@ instance SymConsts Reft where instance SymConsts Expr where symConsts = getSymConsts -getSymConsts :: Visitable t => t -> [SymConst] +getSymConsts :: Foldable t => t -> [SymConst] getSymConsts = fold scVis () [] where - scVis = (defaultVisitor :: Visitor [SymConst] t) { accExpr = sc } + scVis = (defaultFolder :: Folder [SymConst] t) { accExpr = sc } sc _ (ESym c) = [c] sc _ _ = [] diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs new file mode 100644 index 000000000..12e4b9e9b --- /dev/null +++ b/src/Language/Fixpoint/Union.hs @@ -0,0 +1,59 @@ +{-# LANGUAGE BangPatterns #-} +module Language.Fixpoint.Union where +import Data.HashMap.Strict (lookup, insert, HashMap, empty) +import Prelude hiding (lookup) +import Language.Fixpoint.Types (Sort(..)) + +-------------------------------------------------------------------------------- +-- | union for sorts in union find +-------------------------------------------------------------------------------- +unionVals :: UF -> Sort -> Sort -> UF +-------------------------------------------------------------------------------- +unionVals uf s1 s2 + | isNumericSort s1 && isNumericSort s2 = uf + where + isNumericSort FReal = True + isNumericSort FNum = True + isNumericSort FFrac = True + isNumericSort FInt = True + isNumericSort _ = False + +unionVals uf (FObj x) (FObj y) + | x == y = uf +unionVals u@(MkUF uf) (FVar i) (FVar j) = if i == j then u else MkUF (insert i (FVar j) uf) +unionVals (MkUF uf) (FVar i) s = MkUF (insert i s uf) +unionVals (MkUF uf) s (FVar i) = MkUF (insert i s uf) +unionVals uf (FFunc s1 s2) (FFunc s1' s2') = unionVals uf' (getRep uf' s2) (getRep uf' s2') + where uf' = unionVals uf (getRep uf s1) (getRep uf s1') +unionVals uf (FApp s1 s2) (FApp s1' s2') = unionVals uf' (getRep uf' s2) (getRep uf' s2') + where uf' = unionVals uf (getRep uf s1) (getRep uf s1') +unionVals uf (FAbs _ s) (FAbs _ s') = unionVals uf (getRep uf s) (getRep uf s') +unionVals uf (FTC s1) (FTC s2) | s1 == s2 = uf +unionVals _ s1 s2 = error ("Cannot unify " ++ show s1 ++ " and " ++ show s2) + + +newtype UF = MkUF (HashMap Int Sort) deriving (Show) +new :: UF +new = MkUF empty + +union :: UF -> Int -> Sort -> UF +union !u !tyv !s = + let tyv_root = find u tyv + sort_root = getRep u s + in + if tyv_root == sort_root then u else unionVals u tyv_root sort_root + +getRep :: UF -> Sort -> Sort +getRep u s = + case s of + FVar i -> find u i + _ -> s + +find :: UF -> Int -> Sort +find (MkUF ufM) = f + where + f k = do + case lookup k ufM of + Nothing -> FVar k + Just (FVar i) -> f i + Just s -> s \ No newline at end of file