From 58a951a2b2757f50d76b08ce7439b736f4cf67fd Mon Sep 17 00:00:00 2001 From: Ranjit Jhala Date: Wed, 15 Jan 2025 13:55:58 -0800 Subject: [PATCH 01/33] elab-strict --- src/Language/Fixpoint/SortCheck.hs | 171 +++++++++++++------------ src/Language/Fixpoint/Types/Visitor.hs | 108 +++++++++++++--- 2 files changed, 177 insertions(+), 102 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 42da843eb..095f9d8c6 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -530,137 +530,142 @@ addEnv f bs x {-# SCC elab #-} elab :: ElabEnv -> Expr -> CheckM (Expr, Sort) -------------------------------------------------------------------------------- -elab f@(_, g) e@(EBin o e1 e2) = do - (e1', s1) <- elab f e1 - (e2', s2) <- elab f e2 - s <- checkOpTy g e s1 s2 - return (EBin o (eCst e1' s1) (eCst e2' s2), s) - -elab f (EApp e1@(EApp _ _) e2) = do - (e1', _, e2', s2, s) <- notracepp "ELAB-EAPP" <$> elabEApp f e1 e2 - let e = eAppC s e1' (eCst e2' s2) - let θ = unifyExpr (snd f) e +elab f@(!_, !g) e@(EBin !o !e1 !e2) = do + (!e1', !s1) <- elab f e1 + (!e2', !s2) <- elab f e2 + !s <- checkOpTy g e s1 s2 + let !result = EBin o (eCst e1' s1) (eCst e2' s2) + return (result, s) + + +elab !f (EApp e1@(EApp !_ !_) !e2) = do + (!e1', !_, !e2', !s2, !s) <- notracepp "ELAB-EAPP" <$> elabEApp f e1 e2 + let !e = eAppC s e1' (eCst e2' s2) + let !θ = unifyExpr (snd f) e return (applyExpr θ e, maybe s (`apply` s) θ) -elab f (EApp e1 e2) = do - (e1', s1, e2', s2, s) <- elabEApp f e1 e2 - let e = eAppC s (eCst e1' s1) (eCst e2' s2) - let θ = unifyExpr (snd f) e +elab !f (EApp !e1 !e2) = do + (!e1', !s1, !e2', !s2, !s) <- elabEApp f e1 e2 + let !e = eAppC s (eCst e1' s1) (eCst e2' s2) + let !θ = unifyExpr (snd f) e return (applyExpr θ e, maybe s (`apply` s) θ) -elab _ e@(ESym _) = + +elab !_ e@(ESym _) = return (e, strSort) -elab _ e@(ECon (I _)) = +elab !_ e@(ECon (I _)) = return (e, FInt) -elab _ e@(ECon (R _)) = +elab !_ e@(ECon (R _)) = return (e, FReal) -elab _ e@(ECon (L _ s)) = +elab !_ e@(ECon (L _ !s)) = return (e, s) -elab _ e@(PKVar _ _) = +elab !_ e@(PKVar _ _) = return (e, boolSort) -elab f (PGrad k su i e) = - (, boolSort) . PGrad k su i . fst <$> elab f e -elab (_, f) e@(EVar x) = do - cs <- checkSym f x - pure (e, cs) +elab !f (PGrad !k !su !i !e) = do + (!e', !_) <- elab f e + return (PGrad k su i e', boolSort) + +elab (!_, !f) e@(EVar !x) = do + !cs <- checkSym f x + return (e, cs) -elab f (ENeg e) = do - (e', s) <- elab f e +elab !f (ENeg !e) = do + (!e', !s) <- elab f e return (ENeg e', s) -elab f@(_,g) (ECst (EIte p e1 e2) t) = do - (p', _) <- elab f p - (e1', s1) <- elab f (eCst e1 t) - (e2', s2) <- elab f (eCst e2 t) - s <- checkIteTy g p e1' e2' s1 s2 +elab f@(!_,!g) (ECst (EIte !p !e1 !e2) !t) = do + (!p', !_) <- elab f p + (!e1', !s1) <- elab f (eCst e1 t) + (!e2', !s2) <- elab f (eCst e2 t) + !s <- checkIteTy g p e1' e2' s1 s2 return (EIte p' (eCst e1' s) (eCst e2' s), t) -elab f@(_,g) (EIte p e1 e2) = do - t <- getIte g e1 e2 - (p', _) <- elab f p - (e1', s1) <- elab f (eCst e1 t) - (e2', s2) <- elab f (eCst e2 t) - s <- checkIteTy g p e1' e2' s1 s2 +elab f@(!_,!g) (EIte !p !e1 !e2) = do + !t <- getIte g e1 e2 + (!p', !_) <- elab f p + (!e1', !s1) <- elab f (eCst e1 t) + (!e2', !s2) <- elab f (eCst e2 t) + !s <- checkIteTy g p e1' e2' s1 s2 return (EIte p' (eCst e1' s) (eCst e2' s), s) -elab f (ECst e t) = do - (e', _) <- elab f e + +elab !f (ECst !e !t) = do + (!e', !_) <- elab f e return (eCst e' t, t) -elab f (PNot p) = do - (e', _) <- elab f p +elab !f (PNot !p) = do + (!e', !_) <- elab f p return (PNot e', boolSort) -elab f (PImp p1 p2) = do - (p1', _) <- elab f p1 - (p2', _) <- elab f p2 +elab !f (PImp !p1 !p2) = do + (!p1', !_) <- elab f p1 + (!p2', !_) <- elab f p2 return (PImp p1' p2', boolSort) -elab f (PIff p1 p2) = do - (p1', _) <- elab f p1 - (p2', _) <- elab f p2 +elab !f (PIff !p1 !p2) = do + (!p1', !_) <- elab f p1 + (!p2', !_) <- elab f p2 return (PIff p1' p2', boolSort) -elab f (PAnd ps) = do - ps' <- mapM (elab f) ps +elab !f (PAnd !ps) = do + !ps' <- mapM (elab f) ps return (PAnd (fst <$> ps'), boolSort) -elab f (POr ps) = do - ps' <- mapM (elab f) ps +elab !f (POr !ps) = do + !ps' <- mapM (elab f) ps return (POr (fst <$> ps'), boolSort) -elab f@(_,g) e@(PAtom eq e1 e2) | eq == Eq || eq == Ne = do - t1 <- checkExpr g e1 - t2 <- checkExpr g e2 - (t1',t2') <- unite g e t1 t2 `withError` errElabExpr e - e1' <- elabAs f t1' e1 - e2' <- elabAs f t2' e2 - e1'' <- eCstAtom f e1' t1' - e2'' <- eCstAtom f e2' t2' - return (PAtom eq e1'' e2'' , boolSort) - -elab f (PAtom r e1 e2) +elab f@(!_,!g) e@(PAtom !eq !e1 !e2) | eq == Eq || eq == Ne = do + !t1 <- checkExpr g e1 + !t2 <- checkExpr g e2 + (!t1',!t2') <- unite g e t1 t2 `withError` errElabExpr e + !e1' <- elabAs f t1' e1 + !e2' <- elabAs f t2' e2 + !e1'' <- eCstAtom f e1' t1' + !e2'' <- eCstAtom f e2' t2' + return (PAtom eq e1'' e2'', boolSort) + +elab !f (PAtom !r !e1 !e2) | r == Ueq || r == Une = do - (e1', _) <- elab f e1 - (e2', _) <- elab f e2 + (!e1', !_) <- elab f e1 + (!e2', !_) <- elab f e2 return (PAtom r e1' e2', boolSort) -elab f@(env,_) (PAtom r e1 e2) = do - e1' <- uncurry (toInt env) <$> elab f e1 - e2' <- uncurry (toInt env) <$> elab f e2 +elab f@(!env,!_) (PAtom !r !e1 !e2) = do + !e1' <- uncurry (toInt env) <$> elab f e1 + !e2' <- uncurry (toInt env) <$> elab f e2 return (PAtom r e1' e2', boolSort) -elab f (PExist bs e) = do - (e', s) <- elab (elabAddEnv f bs) e - let bs' = elaborate "PExist Args" mempty bs +elab !f (PExist !bs !e) = do + (!e', !s) <- elab (elabAddEnv f bs) e + let !bs' = elaborate "PExist Args" mempty bs return (PExist bs' e', s) -elab f (PAll bs e) = do - (e', s) <- elab (elabAddEnv f bs) e - let bs' = elaborate "PAll Args" mempty bs +elab !f (PAll !bs !e) = do + (!e', !s) <- elab (elabAddEnv f bs) e + let !bs' = elaborate "PAll Args" mempty bs return (PAll bs' e', s) -elab f (ELam (x,t) e) = do - (e', s) <- elab (elabAddEnv f [(x, t)]) e - let t' = elaborate "ELam Arg" mempty t +elab !f (ELam (!x,!t) !e) = do + (!e', !s) <- elab (elabAddEnv f [(x, t)]) e + let !t' = elaborate "ELam Arg" mempty t return (ELam (x, t') (eCst e' s), FFunc t s) -elab f (ECoerc s t e) = do - (e', _) <- elab f e - return (ECoerc s t e', t) +elab !f (ECoerc !s !t !e) = do + (!e', !_) <- elab f e + return (ECoerc s t e', t) -elab _ (ETApp _ _) = +elab !_ (ETApp _ _) = error "SortCheck.elab: TODO: implement ETApp" -elab _ (ETAbs _ _) = +elab !_ (ETAbs _ _) = error "SortCheck.elab: TODO: implement ETAbs" - -- | 'eCstAtom' is to support tests like `tests/pos/undef00.fq` eCstAtom :: ElabEnv -> Expr -> Sort -> CheckM Expr eCstAtom f@(sym,g) (EVar x) t @@ -970,7 +975,7 @@ refreshNegativeTyVars s = do let negativeSorts = negSort s freshVars <- mapM pair $ S.toList negativeSorts pure $ foldr (uncurry subst) s freshVars - where + where pair i = do f <- fresh pure (i, FVar f) diff --git a/src/Language/Fixpoint/Types/Visitor.hs b/src/Language/Fixpoint/Types/Visitor.hs index 163518e0f..2afc55770 100644 --- a/src/Language/Fixpoint/Types/Visitor.hs +++ b/src/Language/Fixpoint/Types/Visitor.hs @@ -219,30 +219,100 @@ mapExpr f = trans (defaultVisitor {txExpr = const f}) () () mapExprOnExpr :: (Expr -> Expr) -> Expr -> Expr mapExprOnExpr f = go where - go e0 = f $ case e0 of - EApp f e -> EApp (go f) (go e) - ENeg e -> ENeg (go e) - EBin o e1 e2 -> EBin o (go e1) (go e2) - EIte p e1 e2 -> EIte (go p) (go e1) (go e2) - ECst e t -> ECst (go e) t - PAnd ps -> PAnd (map go ps) - POr ps -> POr (map go ps) - PNot p -> PNot (go p) - PImp p1 p2 -> PImp (go p1) (go p2) - PIff p1 p2 -> PIff (go p1) (go p2) - PAtom r e1 e2 -> PAtom r (go e1) (go e2) - PAll xts p -> PAll xts (go p) - ELam (x,t) e -> ELam (x,t) (go e) - ECoerc a t e -> ECoerc a t (go e) - PExist xts p -> PExist xts (go p) - ETApp e s -> ETApp (go e) s - ETAbs e s -> ETAbs (go e) s - PGrad k su i e -> PGrad k su i (go e) + go !e0 = f $! case e0 of + EApp f e -> + let !f' = go f + !e' = go e + in EApp f' e' + ENeg e -> + let !e' = go e + in ENeg e' + EBin o e1 e2 -> + let !e1' = go e1 + !e2' = go e2 + in EBin o e1' e2' + EIte p e1 e2 -> + let !p' = go p + !e1' = go e1 + !e2' = go e2 + in EIte p' e1' e2' + ECst e t -> + let !e' = go e + in ECst e' t + PAnd ps -> + let !ps' = map go ps + in PAnd ps' + POr ps -> + let !ps' = map go ps + in POr ps' + PNot p -> + let !p' = go p + in PNot p' + PImp p1 p2 -> + let !p1' = go p1 + !p2' = go p2 + in PImp p1' p2' + PIff p1 p2 -> + let !p1' = go p1 + !p2' = go p2 + in PIff p1' p2' + PAtom r e1 e2 -> + let !e1' = go e1 + !e2' = go e2 + in PAtom r e1' e2' + PAll xts p -> + let !p' = go p + in PAll xts p' + ELam (x,t) e -> + let !e' = go e + in ELam (x,t) e' + ECoerc a t e -> + let !e' = go e + in ECoerc a t e' + PExist xts p -> + let !p' = go p + in PExist xts p' + ETApp e s -> + let !e' = go e + in ETApp e' s + ETAbs e s -> + let !e' = go e + in ETAbs e' s + PGrad k su i e -> + let !e' = go e + in PGrad k su i e' e@PKVar{} -> e e@EVar{} -> e e@ESym{} -> e e@ECon{} -> e +-- mapExprOnExpr :: (Expr -> Expr) -> Expr -> Expr +-- mapExprOnExpr f = go +-- where +-- go !e0 = f $! case e0 of +-- EApp f e -> EApp !(go f) !(go e) +-- ENeg e -> ENeg (go e) +-- EBin o e1 e2 -> EBin o (go e1) (go e2) +-- EIte p e1 e2 -> EIte (go p) (go e1) (go e2) +-- ECst e t -> ECst (go e) t +-- PAnd ps -> PAnd (map go ps) +-- POr ps -> POr (map go ps) +-- PNot p -> PNot (go p) +-- PImp p1 p2 -> PImp (go p1) (go p2) +-- PIff p1 p2 -> PIff (go p1) (go p2) +-- PAtom r e1 e2 -> PAtom r (go e1) (go e2) +-- PAll xts p -> PAll xts (go p) +-- ELam (x,t) e -> ELam (x,t) (go e) +-- ECoerc a t e -> ECoerc a t (go e) +-- PExist xts p -> PExist xts (go p) +-- ETApp e s -> ETApp (go e) s +-- ETAbs e s -> ETAbs (go e) s +-- PGrad k su i e -> PGrad k su i (go e) +-- e@PKVar{} -> e +-- e@EVar{} -> e +-- e@ESym{} -> e +-- e@ECon{} -> e + mapMExpr :: (Monad m) => (Expr -> m Expr) -> Expr -> m Expr mapMExpr f = go From c3a8a08359895c11be0658079731523efbe6b12e Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Thu, 16 Jan 2025 18:35:52 -0800 Subject: [PATCH 02/33] delay applyExpr until elaboration is finished --- src/Language/Fixpoint/SortCheck.hs | 49 ++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 095f9d8c6..7963771e1 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -273,11 +273,15 @@ elabExpr msg env e = case elabExprE msg env e of elabExprE :: Located String -> SymEnv -> Expr -> Either Error Expr elabExprE msg env e = - case runCM0 (srcSpan msg) (elab (env, envLookup) e) of + case runCM0 (srcSpan msg) $ do + (!e', _) <- elab (env, envLookup) e + finalThetaRef <- asks chTVSubst + finalTheta <- liftIO $ readIORef finalThetaRef + return (applyExpr finalTheta e') of Left (ChError f') -> let e' = f' () in Left $ err (srcSpan e') (d (val e')) - Right s -> Right (fst s) + Right s -> Right s where sEnv = seSort env envLookup = (`lookupSEnvWithDistance` sEnv) @@ -371,7 +375,7 @@ instance Show ChError where show (ChError f) = show (f ()) instance Exception ChError where -data ChState = ChS { chCount :: IORef Int, chSpan :: SrcSpan } +data ChState = ChS {chCount :: IORef Int, chSpan :: SrcSpan, chTVSubst :: IORef (Maybe TVSubst)} type Env = Symbol -> SESearch Sort type ElabEnv = (SymEnv, Env) @@ -406,7 +410,8 @@ varCounterRef = unsafePerformIO $ newIORef 42 -- value of counter. runCM0 :: SrcSpan -> CheckM a -> Either ChError a runCM0 sp act = unsafePerformIO $ do - try (runReaderT act (ChS varCounterRef sp)) + ref <- newIORef Nothing + try (runReaderT act (ChS varCounterRef sp ref)) fresh :: CheckM Int fresh = do @@ -538,17 +543,12 @@ elab f@(!_, !g) e@(EBin !o !e1 !e2) = do return (result, s) -elab !f (EApp e1@(EApp !_ !_) !e2) = do - (!e1', !_, !e2', !s2, !s) <- notracepp "ELAB-EAPP" <$> elabEApp f e1 e2 - let !e = eAppC s e1' (eCst e2' s2) - let !θ = unifyExpr (snd f) e - return (applyExpr θ e, maybe s (`apply` s) θ) - elab !f (EApp !e1 !e2) = do (!e1', !s1, !e2', !s2, !s) <- elabEApp f e1 e2 let !e = eAppC s (eCst e1' s1) (eCst e2' s2) let !θ = unifyExpr (snd f) e - return (applyExpr θ e, maybe s (`apply` s) θ) + composeTVSubst θ + return (e, maybe s (`apply` s) θ) elab !_ e@(ESym _) = @@ -716,7 +716,9 @@ elabAppSort f e1 e2 s1 s2 = do let e = Just (EApp e1 e2) (sIn, sOut, su) <- checkFunSort s1 su' <- unify1 f e su sIn s2 - return (applyExpr (Just su') e1 , applyExpr (Just su') e2, apply su' s1, apply su' s2, apply su' sOut) + composeTVSubst (Just su) + composeTVSubst (Just su') + return (e1 , e2, apply su' s1, apply su' s2, apply su' sOut) -------------------------------------------------------------------------------- @@ -1360,6 +1362,29 @@ unifyVar f e θ !i !t Just !t' -> if t == t' then return θ else unify1 f e θ t t' Nothing -> return (updateVar i t θ) + +-------------------------------------------------------------------------------- +-- | Update global subst to be applied to expressions +-------------------------------------------------------------------------------- + +updateTVSubst :: TVSubst -> CheckM () +updateTVSubst theta = do + refTheta <- asks chTVSubst + liftIO $ atomicModifyIORef' refTheta $ const (Just theta, ()) + +-- local (\s -> s {chTVSubst = theta}) (return ()) + +mergeTVSubst :: TVSubst -> Maybe TVSubst -> TVSubst +mergeTVSubst (Th m1) Nothing = Th m1 +mergeTVSubst (Th m1) (Just (Th m2)) = Th m1 <> Th m2 + +composeTVSubst :: Maybe TVSubst -> CheckM () +composeTVSubst Nothing = return () +composeTVSubst (Just theta1) = do + refTheta <- asks chTVSubst + theta <- liftIO $ readIORef refTheta + updateTVSubst (mergeTVSubst theta1 theta) + -------------------------------------------------------------------------------- -- | Applying a Type Substitution ---------------------------------------------- -------------------------------------------------------------------------------- From ac79d5ca46c02b6861d54537f0f0cc2762edce92 Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Thu, 16 Jan 2025 22:41:06 -0800 Subject: [PATCH 03/33] More strict mode that makes a seemingly large difference --- src/Language/Fixpoint/SortCheck.hs | 10 +++++----- src/Language/Fixpoint/Types/Visitor.hs | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 7963771e1..862d19a38 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -1390,17 +1390,17 @@ composeTVSubst (Just theta1) = do -------------------------------------------------------------------------------- apply :: TVSubst -> Sort -> Sort -------------------------------------------------------------------------------- -apply θ = Vis.mapSort f +apply !θ = Vis.mapSort f where - f t@(FVar i) = fromMaybe t (lookupVar i θ) - f t = t + f t@(FVar !i) = fromMaybe t (lookupVar i θ) + f !t = t applyExpr :: Maybe TVSubst -> Expr -> Expr applyExpr Nothing e = e applyExpr (Just θ) e = Vis.mapExprOnExpr f e where - f (ECst e' s) = ECst e' (apply θ s) - f e' = e' + f (ECst !e' !s) = ECst e' (apply θ s) + f !e' = e' -------------------------------------------------------------------------------- _applyCoercion :: Symbol -> Sort -> Sort -> Sort diff --git a/src/Language/Fixpoint/Types/Visitor.hs b/src/Language/Fixpoint/Types/Visitor.hs index 2afc55770..7576ae183 100644 --- a/src/Language/Fixpoint/Types/Visitor.hs +++ b/src/Language/Fixpoint/Types/Visitor.hs @@ -87,10 +87,10 @@ fold :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> a fold v c a t = snd $ execVisitM v c a visit t trans :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> t -trans v c _ z = fst $ execVisitM v c mempty visit z +trans !v !c !_ !z = fst $ execVisitM v c mempty visit z execVisitM :: Visitor a ctx -> ctx -> a -> (Visitor a ctx -> ctx -> t -> State a t) -> t -> (t, a) -execVisitM v c a f x = runState (f v c x) a +execVisitM !v !c !a !f !x = runState (f v c x) a type VisitM acc = State acc From 9f55cca463b7154163fc73f772d3fa38c21bdcca Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Fri, 17 Jan 2025 20:50:04 -0800 Subject: [PATCH 04/33] Use ReaderT instead of State monad in visitor --- src/Language/Fixpoint/Types/Visitor.hs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/Language/Fixpoint/Types/Visitor.hs b/src/Language/Fixpoint/Types/Visitor.hs index 7576ae183..f90843d1b 100644 --- a/src/Language/Fixpoint/Types/Visitor.hs +++ b/src/Language/Fixpoint/Types/Visitor.hs @@ -57,6 +57,9 @@ import qualified Data.HashMap.Strict as M import qualified Data.List as L import Language.Fixpoint.Types hiding (mapSort) import qualified Language.Fixpoint.Misc as Misc +import Control.Monad.Reader +import GHC.IO (unsafePerformIO) +import Data.IORef (newIORef, readIORef, IORef, modifyIORef') @@ -89,16 +92,19 @@ fold v c a t = snd $ execVisitM v c a visit t trans :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> t trans !v !c !_ !z = fst $ execVisitM v c mempty visit z -execVisitM :: Visitor a ctx -> ctx -> a -> (Visitor a ctx -> ctx -> t -> State a t) -> t -> (t, a) -execVisitM !v !c !a !f !x = runState (f v c x) a +execVisitM :: Visitor a ctx -> ctx -> a -> (Visitor a ctx -> ctx -> t -> VisitM a t) -> t -> (t, a) +execVisitM !v !c !a !f !x = unsafePerformIO $ do + rn <- newIORef a + result <- runReaderT (f v c x) rn + finalAcc <- readIORef rn + return (result, finalAcc) -type VisitM acc = State acc +type VisitM acc = ReaderT (IORef acc) IO accum :: (Monoid a) => a -> VisitM a () -accum !z = modify (mappend z) - -- do - -- !cur <- get - -- put ((mappend $!! z) $!! cur) +accum !z = do + ref <- ask + liftIO $ modifyIORef' ref (mappend z) class Visitable t where visit :: (Monoid a) => Visitor a c -> c -> t -> VisitM a t From 1a12686a22724321a04e32ca767683b9fed5f059 Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Sat, 18 Jan 2025 15:36:30 -0800 Subject: [PATCH 05/33] specialized visitor for trans since acc is unit --- src/Language/Fixpoint/Solver/Sanitize.hs | 12 --- src/Language/Fixpoint/Types/Visitor.hs | 123 +++++++++++++++++++---- 2 files changed, 101 insertions(+), 34 deletions(-) diff --git a/src/Language/Fixpoint/Solver/Sanitize.hs b/src/Language/Fixpoint/Solver/Sanitize.hs index 15ac71063..1360591c2 100644 --- a/src/Language/Fixpoint/Solver/Sanitize.hs +++ b/src/Language/Fixpoint/Solver/Sanitize.hs @@ -51,8 +51,6 @@ sanitize cfg = banIrregularData >=> banConstraintFreeVars cfg >=> Misc.fM addLiterals >=> Misc.fM (eliminateEta cfg) - >=> Misc.fM cancelCoercion - -------------------------------------------------------------------------------- -- | 'dropAdtMeasures' removes all the measure definitions that correspond to @@ -83,16 +81,6 @@ addLiterals si = si { F.dLits = F.unionSEnv (F.dLits si) lits' where lits' = M.fromList [ (F.symbol x, F.strSort) | x <- symConsts si ] - - -cancelCoercion :: F.SInfo a -> F.SInfo a -cancelCoercion = mapExpr (trans (defaultVisitor { txExpr = go }) () ()) - where - go _ (F.ECoerc t1 t2 (F.ECoerc t2' t1' e)) - | t1 == t1' && t2 == t2' - = e - go _ e = e - -------------------------------------------------------------------------------- -- | `eliminateEta` converts equations of the form f x = g x into f = g -------------------------------------------------------------------------------- diff --git a/src/Language/Fixpoint/Types/Visitor.hs b/src/Language/Fixpoint/Types/Visitor.hs index f90843d1b..805b1ebc3 100644 --- a/src/Language/Fixpoint/Types/Visitor.hs +++ b/src/Language/Fixpoint/Types/Visitor.hs @@ -6,6 +6,7 @@ {-# LANGUAGE BangPatterns #-} {-# OPTIONS_GHC -Wno-name-shadowing #-} +{-# LANGUAGE InstanceSigs #-} module Language.Fixpoint.Types.Visitor ( -- * Visitor @@ -89,8 +90,88 @@ defaultVisitor = Visitor fold :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> a fold v c a t = snd $ execVisitM v c a visit t -trans :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> t -trans !v !c !_ !z = fst $ execVisitM v c mempty visit z +-- trans is always passed () () for a and t so we don't need to use the visitor pattern +-- trans :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> t +-- trans !v !c !_ !z = fst $ execVisitM v c mempty visit z + +class VisitableSpecialized t where + visitSpecialized :: (Expr -> Expr) -> t -> t + +trans :: VisitableSpecialized t => (Expr -> Expr) -> t -> t +trans f t = visitSpecialized f t + +instance VisitableSpecialized Expr where + visitSpecialized f = vE + where + vE e = step e + step e@(ESym _) = f e + step e@(ECon _) = f e + step e@(EVar _) = f e + step (EApp f e) = EApp (vE f) (vE e) + step (ENeg e) = ENeg (vE e) + step (EBin o e1 e2) = EBin o (vE e1) (vE e2) + step (EIte p e1 e2) = EIte (vE p) (vE e1) (vE e2) + step (ECst e t) = ECst (vE e) t + step (PAnd ps) = PAnd (map vE ps) + step (POr ps) = POr (map vE ps) + step (PNot p) = PNot (vE p) + step (PImp p1 p2) = PImp (vE p1) (vE p2) + step (PIff p1 p2) = PIff (vE p1) (vE p2) + step (PAtom r e1 e2) = PAtom r (vE e1) (vE e2) + step (PAll xts p) = PAll xts (vE p) + step (ELam (x,t) e) = ELam (x,t) (vE e) + step (ECoerc a t e) = ECoerc a t (vE e) + step (PExist xts p) = PExist xts (vE p) + step (ETApp e s) = ETApp (vE e) s + step (ETAbs e s) = ETAbs (vE e) s + step p@(PKVar _ _) = p + step (PGrad k su i e) = PGrad k su i (vE e) + +instance VisitableSpecialized Reft where + visitSpecialized v (Reft (x, ra)) = Reft (x, visitSpecialized v ra) + +instance VisitableSpecialized SortedReft where + visitSpecialized v (RR t r) = RR t (visitSpecialized v r) + +instance VisitableSpecialized (Symbol, SortedReft, a) where + visitSpecialized f (sym, sr, a) = (sym, visitSpecialized f sr, a) + +instance VisitableSpecialized (BindEnv a) where + visitSpecialized v be = be { beBinds = M.map (visitSpecialized v) (beBinds be) } + +instance (VisitableSpecialized (c a)) => VisitableSpecialized (GInfo c a) where + visitSpecialized f x = x { + cm = visitSpecialized f <$> cm x + , bs = visitSpecialized f (bs x) + , ae = visitSpecialized f (ae x) + } + +instance VisitableSpecialized (SimpC a) where + visitSpecialized v x = x { + _crhs = visitSpecialized v (_crhs x) + } + +instance VisitableSpecialized (SubC a) where + visitSpecialized v x = x { + slhs = visitSpecialized v (slhs x), + srhs = visitSpecialized v (srhs x) + } + +instance VisitableSpecialized AxiomEnv where + visitSpecialized v x = x { + aenvEqs = visitSpecialized v <$> aenvEqs x, + aenvSimpl = visitSpecialized v <$> aenvSimpl x + } + +instance VisitableSpecialized Equation where + visitSpecialized v eq = eq { + eqBody = visitSpecialized v (eqBody eq) + } + +instance VisitableSpecialized Rewrite where + visitSpecialized v rw = rw { + smBody = visitSpecialized v (smBody rw) + } execVisitM :: Visitor a ctx -> ctx -> a -> (Visitor a ctx -> ctx -> t -> VisitM a t) -> t -> (t, a) execVisitM !v !c !a !f !x = unsafePerformIO $ do @@ -193,33 +274,31 @@ visitExpr !v = vE step _ p@(PKVar _ _) = return p step !c (PGrad k su i e) = PGrad k su i <$> vE c e -mapKVars :: Visitable t => (KVar -> Maybe Expr) -> t -> t +mapKVars :: VisitableSpecialized t => (KVar -> Maybe Expr) -> t -> t mapKVars f = mapKVars' f' where f' (kv', _) = f kv' -mapKVars' :: Visitable t => ((KVar, Subst) -> Maybe Expr) -> t -> t -mapKVars' f = trans kvVis () () +mapKVars' :: VisitableSpecialized t => ((KVar, Subst) -> Maybe Expr) -> t -> t +mapKVars' f = trans txK where - kvVis = defaultVisitor { txExpr = txK } - txK _ (PKVar k su) + txK (PKVar k su) | Just p' <- f (k, su) = subst su p' - txK _ (PGrad k su _ _) + txK (PGrad k su _ _) | Just p' <- f (k, su) = subst su p' - txK _ p = p + txK p = p -mapGVars' :: Visitable t => ((KVar, Subst) -> Maybe Expr) -> t -> t -mapGVars' f = trans kvVis () () +mapGVars' :: VisitableSpecialized t => ((KVar, Subst) -> Maybe Expr) -> t -> t +mapGVars' f = trans txK where - kvVis = defaultVisitor { txExpr = txK } - txK _ (PGrad k su _ _) + txK (PGrad k su _ _) | Just p' <- f (k, su) = subst su p' - txK _ p = p + txK p = p -mapExpr :: Visitable t => (Expr -> Expr) -> t -> t -mapExpr f = trans (defaultVisitor {txExpr = const f}) () () +mapExpr :: VisitableSpecialized t => (Expr -> Expr) -> t -> t +mapExpr f = trans f -- | Specialized and faster version of mapExpr for expressions mapExprOnExpr :: (Expr -> Expr) -> Expr -> Expr @@ -346,13 +425,12 @@ mapMExpr f = go go (PAnd ps) = f . PAnd =<< (go `traverse` ps) go (POr ps) = f . POr =<< (go `traverse` ps) -mapKVarSubsts :: Visitable t => (KVar -> Subst -> Subst) -> t -> t -mapKVarSubsts f = trans kvVis () () +mapKVarSubsts :: VisitableSpecialized t => (KVar -> Subst -> Subst) -> t -> t +mapKVarSubsts f = trans txK where - kvVis = defaultVisitor { txExpr = txK } - txK _ (PKVar k su) = PKVar k (f k su) - txK _ (PGrad k su i e) = PGrad k (f k su) i e - txK _ p = p + txK (PKVar k su) = PKVar k (f k su) + txK (PGrad k su i e) = PGrad k (f k su) i e + txK p = p newtype MInt = MInt Integer -- deriving (Eq, NFData) @@ -361,6 +439,7 @@ instance Semigroup MInt where instance Monoid MInt where mempty = MInt 0 + mappend :: MInt -> MInt -> MInt mappend = (<>) size :: Visitable t => t -> Integer From 36ae1ca5cdf1beb6835ccb157212647beb1b33c4 Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Mon, 20 Jan 2025 10:42:24 -0800 Subject: [PATCH 06/33] Fix traversal order for transE on expressions --- src/Language/Fixpoint/Types/Visitor.hs | 92 +++++++++++++------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/src/Language/Fixpoint/Types/Visitor.hs b/src/Language/Fixpoint/Types/Visitor.hs index 805b1ebc3..2da542051 100644 --- a/src/Language/Fixpoint/Types/Visitor.hs +++ b/src/Language/Fixpoint/Types/Visitor.hs @@ -94,20 +94,20 @@ fold v c a t = snd $ execVisitM v c a visit t -- trans :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> t -- trans !v !c !_ !z = fst $ execVisitM v c mempty visit z -class VisitableSpecialized t where - visitSpecialized :: (Expr -> Expr) -> t -> t +class Translatable t where + transE :: (Expr -> Expr) -> t -> t -trans :: VisitableSpecialized t => (Expr -> Expr) -> t -> t -trans f t = visitSpecialized f t +trans :: Translatable t => (Expr -> Expr) -> t -> t +trans f t = transE f t -instance VisitableSpecialized Expr where - visitSpecialized f = vE +instance Translatable Expr where + transE f = vE where - vE e = step e - step e@(ESym _) = f e - step e@(ECon _) = f e - step e@(EVar _) = f e - step (EApp f e) = EApp (vE f) (vE e) + vE e = step e' where e' = f e + step e@(ESym _) = e + step e@(ECon _) = e + step e@(EVar _) = e + step (EApp e1 e2) = EApp (vE e1) (vE e2) step (ENeg e) = ENeg (vE e) step (EBin o e1 e2) = EBin o (vE e1) (vE e2) step (EIte p e1 e2) = EIte (vE p) (vE e1) (vE e2) @@ -127,50 +127,50 @@ instance VisitableSpecialized Expr where step p@(PKVar _ _) = p step (PGrad k su i e) = PGrad k su i (vE e) -instance VisitableSpecialized Reft where - visitSpecialized v (Reft (x, ra)) = Reft (x, visitSpecialized v ra) +instance Translatable Reft where + transE v (Reft (x, ra)) = Reft (x, transE v ra) -instance VisitableSpecialized SortedReft where - visitSpecialized v (RR t r) = RR t (visitSpecialized v r) +instance Translatable SortedReft where + transE v (RR t r) = RR t (transE v r) -instance VisitableSpecialized (Symbol, SortedReft, a) where - visitSpecialized f (sym, sr, a) = (sym, visitSpecialized f sr, a) +instance Translatable (Symbol, SortedReft, a) where + transE f (sym, sr, a) = (sym, transE f sr, a) -instance VisitableSpecialized (BindEnv a) where - visitSpecialized v be = be { beBinds = M.map (visitSpecialized v) (beBinds be) } +instance Translatable (BindEnv a) where + transE v be = be { beBinds = M.map (transE v) (beBinds be) } -instance (VisitableSpecialized (c a)) => VisitableSpecialized (GInfo c a) where - visitSpecialized f x = x { - cm = visitSpecialized f <$> cm x - , bs = visitSpecialized f (bs x) - , ae = visitSpecialized f (ae x) +instance (Translatable (c a)) => Translatable (GInfo c a) where + transE f x = x { + cm = transE f <$> cm x + , bs = transE f (bs x) + , ae = transE f (ae x) } -instance VisitableSpecialized (SimpC a) where - visitSpecialized v x = x { - _crhs = visitSpecialized v (_crhs x) +instance Translatable (SimpC a) where + transE v x = x { + _crhs = transE v (_crhs x) } -instance VisitableSpecialized (SubC a) where - visitSpecialized v x = x { - slhs = visitSpecialized v (slhs x), - srhs = visitSpecialized v (srhs x) +instance Translatable (SubC a) where + transE v x = x { + slhs = transE v (slhs x), + srhs = transE v (srhs x) } -instance VisitableSpecialized AxiomEnv where - visitSpecialized v x = x { - aenvEqs = visitSpecialized v <$> aenvEqs x, - aenvSimpl = visitSpecialized v <$> aenvSimpl x +instance Translatable AxiomEnv where + transE v x = x { + aenvEqs = transE v <$> aenvEqs x, + aenvSimpl = transE v <$> aenvSimpl x } -instance VisitableSpecialized Equation where - visitSpecialized v eq = eq { - eqBody = visitSpecialized v (eqBody eq) +instance Translatable Equation where + transE v eq = eq { + eqBody = transE v (eqBody eq) } -instance VisitableSpecialized Rewrite where - visitSpecialized v rw = rw { - smBody = visitSpecialized v (smBody rw) +instance Translatable Rewrite where + transE v rw = rw { + smBody = transE v (smBody rw) } execVisitM :: Visitor a ctx -> ctx -> a -> (Visitor a ctx -> ctx -> t -> VisitM a t) -> t -> (t, a) @@ -274,12 +274,12 @@ visitExpr !v = vE step _ p@(PKVar _ _) = return p step !c (PGrad k su i e) = PGrad k su i <$> vE c e -mapKVars :: VisitableSpecialized t => (KVar -> Maybe Expr) -> t -> t +mapKVars :: Translatable t => (KVar -> Maybe Expr) -> t -> t mapKVars f = mapKVars' f' where f' (kv', _) = f kv' -mapKVars' :: VisitableSpecialized t => ((KVar, Subst) -> Maybe Expr) -> t -> t +mapKVars' :: Translatable t => ((KVar, Subst) -> Maybe Expr) -> t -> t mapKVars' f = trans txK where txK (PKVar k su) @@ -290,14 +290,14 @@ mapKVars' f = trans txK -mapGVars' :: VisitableSpecialized t => ((KVar, Subst) -> Maybe Expr) -> t -> t +mapGVars' :: Translatable t => ((KVar, Subst) -> Maybe Expr) -> t -> t mapGVars' f = trans txK where txK (PGrad k su _ _) | Just p' <- f (k, su) = subst su p' txK p = p -mapExpr :: VisitableSpecialized t => (Expr -> Expr) -> t -> t +mapExpr :: Translatable t => (Expr -> Expr) -> t -> t mapExpr f = trans f -- | Specialized and faster version of mapExpr for expressions @@ -425,7 +425,7 @@ mapMExpr f = go go (PAnd ps) = f . PAnd =<< (go `traverse` ps) go (POr ps) = f . POr =<< (go `traverse` ps) -mapKVarSubsts :: VisitableSpecialized t => (KVar -> Subst -> Subst) -> t -> t +mapKVarSubsts :: Translatable t => (KVar -> Subst -> Subst) -> t -> t mapKVarSubsts f = trans txK where txK (PKVar k su) = PKVar k (f k su) From 140ae73812c9cdfcff4729d55fdf9c6779f1abb5 Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Mon, 20 Jan 2025 10:58:21 -0800 Subject: [PATCH 07/33] renaming visitable -> foldable and specialized visitor for trans to Visitable --- src/Language/Fixpoint/Horn/Transformations.hs | 22 +-- src/Language/Fixpoint/SortCheck.hs | 4 +- src/Language/Fixpoint/Types/Visitor.hs | 147 +++++++++--------- 3 files changed, 87 insertions(+), 86 deletions(-) diff --git a/src/Language/Fixpoint/Horn/Transformations.hs b/src/Language/Fixpoint/Horn/Transformations.hs index 0a71c20f8..7b79d3231 100644 --- a/src/Language/Fixpoint/Horn/Transformations.hs +++ b/src/Language/Fixpoint/Horn/Transformations.hs @@ -543,7 +543,7 @@ elimKs' (k:ks) (noside, side) = elimKs' (trace ("solved kvar " <> F.showpp k <> -- exists in the positive positions (which will stay exists when we go to -- prenex) may give us a lot of trouble during _quantifier elimination_ -- tx :: F.Symbol -> [[Bind]] -> Pred -> Pred --- tx k bss = trans (defaultVisitor { txExpr = existentialPackage, ctxExpr = ctxKV }) M.empty () +-- tx k bss = trans (defaultFolder { txExpr = existentialPackage, ctxExpr = ctxKV }) M.empty () -- where -- splitBinds xs = unzip $ (\(Bind x t p) -> ((x,t),p)) <$> xs -- cubeSol su (Bind _ _ (Reft eqs):xs) @@ -564,16 +564,16 @@ elimKs' (k:ks) (noside, side) = elimKs' (trace ("solved kvar " <> F.showpp k <> -- ctxKV m _ = m -- Visitor only visit Exprs in Pred! -instance V.Visitable Pred where - visit v c (PAnd ps) = PAnd <$> mapM (visit v c) ps - visit v c (Reft e) = Reft <$> visit v c e - visit _ _ var = pure var - -instance V.Visitable (Cstr a) where - visit v c (CAnd cs) = CAnd <$> mapM (visit v c) cs - visit v c (Head p a) = Head <$> visit v c p <*> pure a - visit v ctx (All (Bind x t p l) c) = All <$> (Bind x t <$> visit v ctx p <*> pure l) <*> visit v ctx c - visit v ctx (Any (Bind x t p l) c) = All <$> (Bind x t <$> visit v ctx p <*> pure l) <*> visit v ctx c +instance V.Foldable Pred where + foldE v c (PAnd ps) = PAnd <$> mapM (foldE v c) ps + foldE v c (Reft e) = Reft <$> foldE v c e + foldE _ _ var = pure var + +instance V.Foldable (Cstr a) where + foldE v c (CAnd cs) = CAnd <$> mapM (foldE v c) cs + foldE v c (Head p a) = Head <$> foldE v c p <*> pure a + foldE v ctx (All (Bind x t p l) c) = All <$> (Bind x t <$> foldE v ctx p <*> pure l) <*> foldE v ctx c + foldE v ctx (Any (Bind x t p l) c) = All <$> (Bind x t <$> foldE v ctx p <*> pure l) <*> foldE v ctx c ------------------------------------------------------------------------------ -- | Quantifier elimination for use with implicit solver diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 862d19a38..140aa79f7 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -913,12 +913,12 @@ which, I imagine is what happens _somewhere_ inside GHC too? -} -------------------------------------------------------------------------------- -applySorts :: Vis.Visitable t => t -> [Sort] +applySorts :: Vis.Foldable t => t -> [Sort] -------------------------------------------------------------------------------- applySorts = {- notracepp "applySorts" . -} (defs ++) . Vis.fold vis () [] where defs = [FFunc t1 t2 | t1 <- basicSorts, t2 <- basicSorts] - vis = (Vis.defaultVisitor :: Vis.Visitor [KVar] t) { Vis.accExpr = go } + vis = (Vis.defaultFolder :: Vis.Folder [KVar] t) { Vis.accExpr = go } go _ (EApp (ECst (EVar f) t) _) -- get types needed for [NOTE:apply-monomorphism] | f == applyName = [t] diff --git a/src/Language/Fixpoint/Types/Visitor.hs b/src/Language/Fixpoint/Types/Visitor.hs index 2da542051..c2f9d6f21 100644 --- a/src/Language/Fixpoint/Types/Visitor.hs +++ b/src/Language/Fixpoint/Types/Visitor.hs @@ -10,14 +10,14 @@ module Language.Fixpoint.Types.Visitor ( -- * Visitor - Visitor (..) - , Visitable (..) + Folder (..) + , Foldable (..) -- * Extracting Symbolic Constants (String Literals) , SymConsts (..) -- * Default Visitor - , defaultVisitor + , defaultFolder -- * Transformers , trans @@ -61,11 +61,12 @@ import qualified Language.Fixpoint.Misc as Misc import Control.Monad.Reader import GHC.IO (unsafePerformIO) import Data.IORef (newIORef, readIORef, IORef, modifyIORef') +import Prelude hiding (Foldable) -data Visitor acc ctx = Visitor { +data Folder acc ctx = Visitor { -- | Context @ctx@ is built in a "top-down" fashion; not "across" siblings ctxExpr :: ctx -> Expr -> ctx @@ -77,9 +78,9 @@ data Visitor acc ctx = Visitor { } --------------------------------------------------------------------------------- -defaultVisitor :: (Monoid acc) => Visitor acc ctx +defaultFolder :: (Monoid acc) => Folder acc ctx --------------------------------------------------------------------------------- -defaultVisitor = Visitor +defaultFolder = Visitor { ctxExpr = const , txExpr = \_ x -> x , accExpr = \_ _ -> mempty @@ -87,20 +88,20 @@ defaultVisitor = Visitor ------------------------------------------------------------------------ -fold :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> a -fold v c a t = snd $ execVisitM v c a visit t +fold :: (Foldable t, Monoid a) => Folder a ctx -> ctx -> a -> t -> a +fold v c a t = snd $ execVisitM v c a foldE t -- trans is always passed () () for a and t so we don't need to use the visitor pattern -- trans :: (Visitable t, Monoid a) => Visitor a ctx -> ctx -> a -> t -> t -- trans !v !c !_ !z = fst $ execVisitM v c mempty visit z -class Translatable t where +class Visitable t where transE :: (Expr -> Expr) -> t -> t -trans :: Translatable t => (Expr -> Expr) -> t -> t +trans :: Visitable t => (Expr -> Expr) -> t -> t trans f t = transE f t -instance Translatable Expr where +instance Visitable Expr where transE f = vE where vE e = step e' where e' = f e @@ -127,124 +128,124 @@ instance Translatable Expr where step p@(PKVar _ _) = p step (PGrad k su i e) = PGrad k su i (vE e) -instance Translatable Reft where +instance Visitable Reft where transE v (Reft (x, ra)) = Reft (x, transE v ra) -instance Translatable SortedReft where +instance Visitable SortedReft where transE v (RR t r) = RR t (transE v r) -instance Translatable (Symbol, SortedReft, a) where +instance Visitable (Symbol, SortedReft, a) where transE f (sym, sr, a) = (sym, transE f sr, a) -instance Translatable (BindEnv a) where +instance Visitable (BindEnv a) where transE v be = be { beBinds = M.map (transE v) (beBinds be) } -instance (Translatable (c a)) => Translatable (GInfo c a) where +instance (Visitable (c a)) => Visitable (GInfo c a) where transE f x = x { cm = transE f <$> cm x , bs = transE f (bs x) , ae = transE f (ae x) } -instance Translatable (SimpC a) where +instance Visitable (SimpC a) where transE v x = x { _crhs = transE v (_crhs x) } -instance Translatable (SubC a) where +instance Visitable (SubC a) where transE v x = x { slhs = transE v (slhs x), srhs = transE v (srhs x) } -instance Translatable AxiomEnv where +instance Visitable AxiomEnv where transE v x = x { aenvEqs = transE v <$> aenvEqs x, aenvSimpl = transE v <$> aenvSimpl x } -instance Translatable Equation where +instance Visitable Equation where transE v eq = eq { eqBody = transE v (eqBody eq) } -instance Translatable Rewrite where +instance Visitable Rewrite where transE v rw = rw { smBody = transE v (smBody rw) } -execVisitM :: Visitor a ctx -> ctx -> a -> (Visitor a ctx -> ctx -> t -> VisitM a t) -> t -> (t, a) +execVisitM :: Folder a ctx -> ctx -> a -> (Folder a ctx -> ctx -> t -> FoldM a t) -> t -> (t, a) execVisitM !v !c !a !f !x = unsafePerformIO $ do rn <- newIORef a result <- runReaderT (f v c x) rn finalAcc <- readIORef rn return (result, finalAcc) -type VisitM acc = ReaderT (IORef acc) IO +type FoldM acc = ReaderT (IORef acc) IO -accum :: (Monoid a) => a -> VisitM a () +accum :: (Monoid a) => a -> FoldM a () accum !z = do ref <- ask liftIO $ modifyIORef' ref (mappend z) -class Visitable t where - visit :: (Monoid a) => Visitor a c -> c -> t -> VisitM a t +class Foldable t where + foldE :: (Monoid a) => Folder a c -> c -> t -> FoldM a t -instance Visitable Expr where - visit = visitExpr +instance Foldable Expr where + foldE = foldExpr -instance Visitable Reft where - visit v c (Reft (x, ra)) = Reft . (x, ) <$> visit v c ra +instance Foldable Reft where + foldE v c (Reft (x, ra)) = Reft . (x, ) <$> foldE v c ra -instance Visitable SortedReft where - visit v c (RR t r) = RR t <$> visit v c r +instance Foldable SortedReft where + foldE v c (RR t r) = RR t <$> foldE v c r -instance Visitable (Symbol, SortedReft, a) where - visit v c (sym, sr, a) = (sym, ,a) <$> visit v c sr +instance Foldable (Symbol, SortedReft, a) where + foldE v c (sym, sr, a) = (sym, ,a) <$> foldE v c sr -instance Visitable (BindEnv a) where - visit v c = mapM (visit v c) +instance Foldable (BindEnv a) where + foldE v c = mapM (foldE v c) --------------------------------------------------------------------------------- -- WARNING: these instances were written for mapKVars over GInfos only; -- check that they behave as expected before using with other clients. -instance Visitable (SimpC a) where - visit v c x = do - rhs' <- visit v c (_crhs x) +instance Foldable (SimpC a) where + foldE v c x = do + rhs' <- foldE v c (_crhs x) return x { _crhs = rhs' } -instance Visitable (SubC a) where - visit v c x = do - lhs' <- visit v c (slhs x) - rhs' <- visit v c (srhs x) +instance Foldable (SubC a) where + foldE v c x = do + lhs' <- foldE v c (slhs x) + rhs' <- foldE v c (srhs x) return x { slhs = lhs', srhs = rhs' } -instance (Visitable (c a)) => Visitable (GInfo c a) where - visit v c x = do - cm' <- mapM (visit v c) (cm x) - bs' <- visit v c (bs x) - ae' <- visit v c (ae x) +instance (Foldable (c a)) => Foldable (GInfo c a) where + foldE v c x = do + cm' <- mapM (foldE v c) (cm x) + bs' <- foldE v c (bs x) + ae' <- foldE v c (ae x) return x { cm = cm', bs = bs', ae = ae' } -instance Visitable AxiomEnv where - visit v c x = do - eqs' <- mapM (visit v c) (aenvEqs x) - simpls' <- mapM (visit v c) (aenvSimpl x) +instance Foldable AxiomEnv where + foldE v c x = do + eqs' <- mapM (foldE v c) (aenvEqs x) + simpls' <- mapM (foldE v c) (aenvSimpl x) return x { aenvEqs = eqs' , aenvSimpl = simpls'} -instance Visitable Equation where - visit v c eq = do - body' <- visit v c (eqBody eq) +instance Foldable Equation where + foldE v c eq = do + body' <- foldE v c (eqBody eq) return eq { eqBody = body' } -instance Visitable Rewrite where - visit v c rw = do - body' <- visit v c (smBody rw) +instance Foldable Rewrite where + foldE v c rw = do + body' <- foldE v c (smBody rw) return rw { smBody = body' } --------------------------------------------------------------------------------- -visitExpr :: (Monoid a) => Visitor a ctx -> ctx -> Expr -> VisitM a Expr -visitExpr !v = vE +foldExpr :: (Monoid a) => Folder a ctx -> ctx -> Expr -> FoldM a Expr +foldExpr !v = vE where vE !c !e = do {- SCC "visitExpr.vE.accum" -} accum acc {- SCC "visitExpr.vE.step" -} step c' e' @@ -274,12 +275,12 @@ visitExpr !v = vE step _ p@(PKVar _ _) = return p step !c (PGrad k su i e) = PGrad k su i <$> vE c e -mapKVars :: Translatable t => (KVar -> Maybe Expr) -> t -> t +mapKVars :: Visitable t => (KVar -> Maybe Expr) -> t -> t mapKVars f = mapKVars' f' where f' (kv', _) = f kv' -mapKVars' :: Translatable t => ((KVar, Subst) -> Maybe Expr) -> t -> t +mapKVars' :: Visitable t => ((KVar, Subst) -> Maybe Expr) -> t -> t mapKVars' f = trans txK where txK (PKVar k su) @@ -290,14 +291,14 @@ mapKVars' f = trans txK -mapGVars' :: Translatable t => ((KVar, Subst) -> Maybe Expr) -> t -> t +mapGVars' :: Visitable t => ((KVar, Subst) -> Maybe Expr) -> t -> t mapGVars' f = trans txK where txK (PGrad k su _ _) | Just p' <- f (k, su) = subst su p' txK p = p -mapExpr :: Translatable t => (Expr -> Expr) -> t -> t +mapExpr :: Visitable t => (Expr -> Expr) -> t -> t mapExpr f = trans f -- | Specialized and faster version of mapExpr for expressions @@ -425,7 +426,7 @@ mapMExpr f = go go (PAnd ps) = f . PAnd =<< (go `traverse` ps) go (POr ps) = f . POr =<< (go `traverse` ps) -mapKVarSubsts :: Translatable t => (KVar -> Subst -> Subst) -> t -> t +mapKVarSubsts :: Visitable t => (KVar -> Subst -> Subst) -> t -> t mapKVarSubsts f = trans txK where txK (PKVar k su) = PKVar k (f k su) @@ -442,25 +443,25 @@ instance Monoid MInt where mappend :: MInt -> MInt -> MInt mappend = (<>) -size :: Visitable t => t -> Integer +size :: Foldable t => t -> Integer size t = n where MInt n = fold szV () mempty t - szV = (defaultVisitor :: Visitor MInt t) { accExpr = \ _ _ -> MInt 1 } + szV = (defaultFolder :: Folder MInt t) { accExpr = \ _ _ -> MInt 1 } -lamSize :: Visitable t => t -> Integer +lamSize :: Foldable t => t -> Integer lamSize t = n where MInt n = fold szV () mempty t - szV = (defaultVisitor :: Visitor MInt t) { accExpr = accum } + szV = (defaultFolder :: Folder MInt t) { accExpr = accum } accum _ (ELam _ _) = MInt 1 accum _ _ = MInt 0 -eapps :: Visitable t => t -> [Expr] +eapps :: Foldable t => t -> [Expr] eapps = fold eappVis () [] where - eappVis = (defaultVisitor :: Visitor [KVar] t) { accExpr = eapp' } + eappVis = (defaultFolder :: Folder [KVar] t) { accExpr = eapp' } eapp' _ e@(EApp _ _) = [e] eapp' _ _ = [] @@ -666,9 +667,9 @@ instance SymConsts Reft where instance SymConsts Expr where symConsts = getSymConsts -getSymConsts :: Visitable t => t -> [SymConst] +getSymConsts :: Foldable t => t -> [SymConst] getSymConsts = fold scVis () [] where - scVis = (defaultVisitor :: Visitor [SymConst] t) { accExpr = sc } + scVis = (defaultFolder :: Folder [SymConst] t) { accExpr = sc } sc _ (ESym c) = [c] sc _ _ = [] From eddd41f2fab1292e36763ba3b25074cc211f7328 Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Fri, 24 Jan 2025 07:28:16 -0800 Subject: [PATCH 08/33] add union to CheckM and start to take apply out of elab --- liquid-fixpoint.cabal | 1 + src/Language/Fixpoint/SortCheck.hs | 135 +++++++++++++++-------------- src/Language/Fixpoint/Union.hs | 33 +++++++ 3 files changed, 104 insertions(+), 65 deletions(-) create mode 100644 src/Language/Fixpoint/Union.hs diff --git a/liquid-fixpoint.cabal b/liquid-fixpoint.cabal index d42136211..c03821a09 100644 --- a/liquid-fixpoint.cabal +++ b/liquid-fixpoint.cabal @@ -100,6 +100,7 @@ library Language.Fixpoint.Solver.UniqifyKVars Language.Fixpoint.Solver.Worklist Language.Fixpoint.SortCheck + Language.Fixpoint.Union Language.Fixpoint.Types Language.Fixpoint.Types.Config Language.Fixpoint.Types.Constraints diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 140aa79f7..ea502017a 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -75,7 +75,7 @@ import qualified Data.HashMap.Strict as M import qualified Data.HashSet as S import Data.IORef import qualified Data.List as L -import Data.Maybe (mapMaybe, fromMaybe, catMaybes, isJust) +import Data.Maybe (mapMaybe, fromMaybe, isJust) import Language.Fixpoint.Types.PrettyPrint import Language.Fixpoint.Misc @@ -88,6 +88,7 @@ import Text.Printf import GHC.Stack import qualified Language.Fixpoint.Types as F import System.IO.Unsafe (unsafePerformIO) +import qualified Language.Fixpoint.Union as Union --import Debug.Trace as Debug @@ -375,7 +376,7 @@ instance Show ChError where show (ChError f) = show (f ()) instance Exception ChError where -data ChState = ChS {chCount :: IORef Int, chSpan :: SrcSpan, chTVSubst :: IORef (Maybe TVSubst)} +data ChState = ChS {chCount :: IORef Int, chSpan :: SrcSpan, ufM :: IORef Union.UF, chTVSubst :: IORef (Maybe TVSubst)} type Env = Symbol -> SESearch Sort type ElabEnv = (SymEnv, Env) @@ -410,8 +411,9 @@ varCounterRef = unsafePerformIO $ newIORef 42 -- value of counter. runCM0 :: SrcSpan -> CheckM a -> Either ChError a runCM0 sp act = unsafePerformIO $ do - ref <- newIORef Nothing - try (runReaderT act (ChS varCounterRef sp ref)) + suR <- newIORef Nothing + ufR <- newIORef Union.new + try (runReaderT act (ChS varCounterRef sp ufR suR)) fresh :: CheckM Int fresh = do @@ -546,9 +548,9 @@ elab f@(!_, !g) e@(EBin !o !e1 !e2) = do elab !f (EApp !e1 !e2) = do (!e1', !s1, !e2', !s2, !s) <- elabEApp f e1 e2 let !e = eAppC s (eCst e1' s1) (eCst e2' s2) - let !θ = unifyExpr (snd f) e - composeTVSubst θ - return (e, maybe s (`apply` s) θ) + -- let !θ = unifyExpr (snd f) e + -- composeTVSubst θ + return (e, s) elab !_ e@(ESym _) = @@ -693,14 +695,14 @@ elabAs f t e = notracepp _msg <$> go e -- DUPLICATION with `checkApp'` elabAppAs :: ElabEnv -> Sort -> Expr -> Expr -> CheckM Expr elabAppAs env@(_, f) t g e = do - gT <- checkExpr f g - eT <- checkExpr f e - (iT, oT, isu) <- checkFunSort gT + tg <- checkExpr f g + te <- checkExpr f e + (iT, oT) <- checkFunSort tg let ge = Just (EApp g e) - su <- unifyMany f ge isu [oT, iT] [t, eT] - let tg = apply su gT + _ <- unifyMany f ge emptySubst [oT, iT] [t, te] + -- let tg = apply su tg g' <- elabAs env tg g - let te = apply su eT + -- let te = apply su te e' <- elabAs env te e pure $ EApp (ECst g' tg) (ECst e' te) @@ -714,11 +716,11 @@ elabEApp f@(_, g) e1 e2 = do elabAppSort :: Env -> Expr -> Expr -> Sort -> Sort -> CheckM (Expr, Expr, Sort, Sort, Sort) elabAppSort f e1 e2 s1 s2 = do let e = Just (EApp e1 e2) - (sIn, sOut, su) <- checkFunSort s1 - su' <- unify1 f e su sIn s2 - composeTVSubst (Just su) - composeTVSubst (Just su') - return (e1 , e2, apply su' s1, apply su' s2, apply su' sOut) + (sIn, sOut) <- checkFunSort s1 + _ <- unify1 f e emptySubst sIn s2 + -- composeTVSubst (Just su) + -- composeTVSubst (Just su') + return (e1 , e2, s1, s2, sOut) -------------------------------------------------------------------------------- @@ -1039,16 +1041,16 @@ checkApp' :: Env -> Maybe Sort -> Expr -> Expr -> CheckM (TVSubst, Sort) checkApp' f to g e = do gt <- checkExpr f g et <- checkExpr f e - (it, ot, isu) <- checkFunSort gt + (it, ot) <- checkFunSort gt let ge = Just (EApp g e) - su <- unifyMany f ge isu [it] [et] - let t = apply su ot + su <- unifyMany f ge emptySubst [it] [et] + -- let t = apply su ot case to of - Nothing -> return (su, t) - Just t' -> do θ' <- unifyMany f ge su [t] [t'] - let ti = apply θ' et - _ <- checkExprAs f ti e - return (θ', apply θ' t) + Nothing -> return (su, ot) + Just t' -> do θ' <- unifyMany f ge su [ot] [t'] + -- let ti = apply θ' et + _ <- checkExprAs f et e + return (θ', ot) -- | Helper for checking binary (numeric) operations @@ -1162,26 +1164,26 @@ checkURel e s1 s2 = unless (b1 == b2) (throwErrorAt $ errRel e s1 s2) -- | Sort Unification on Expressions -------------------------------------------------------------------------------- -{-# SCC unifyExpr #-} -unifyExpr :: Env -> Expr -> Maybe TVSubst -unifyExpr f (EApp e1 e2) = Just $ mconcat $ catMaybes [θ1, θ2, θ] - where - θ1 = unifyExpr f e1 - θ2 = unifyExpr f e2 - θ = unifyExprApp f e1 e2 -unifyExpr f (ECst e _) - = unifyExpr f e -unifyExpr _ _ - = Nothing - -unifyExprApp :: Env -> Expr -> Expr -> Maybe TVSubst -unifyExprApp f e1 e2 = do - t1 <- getArg $ exprSortMaybe e1 - t2 <- exprSortMaybe e2 - unify f (Just $ EApp e1 e2) t1 t2 - where - getArg (Just (FFunc t1 _)) = Just t1 - getArg _ = Nothing +-- {-# SCC unifyExpr #-} +-- unifyExpr :: Env -> Expr -> Maybe TVSubst +-- unifyExpr f (EApp e1 e2) = Just $ mconcat $ catMaybes [θ1, θ2, θ] +-- where +-- θ1 = unifyExpr f e1 +-- θ2 = unifyExpr f e2 +-- θ = unifyExprApp f e1 e2 +-- unifyExpr f (ECst e _) +-- = unifyExpr f e +-- unifyExpr _ _ +-- = Nothing + +-- unifyExprApp :: Env -> Expr -> Expr -> Maybe TVSubst +-- unifyExprApp f e1 e2 = do +-- t1 <- getArg $ exprSortMaybe e1 +-- t2 <- exprSortMaybe e2 +-- unify f (Just $ EApp e1 e2) t1 t2 +-- where +-- getArg (Just (FFunc t1 _)) = Just t1 +-- getArg _ = Nothing -------------------------------------------------------------------------------- @@ -1367,23 +1369,23 @@ unifyVar f e θ !i !t -- | Update global subst to be applied to expressions -------------------------------------------------------------------------------- -updateTVSubst :: TVSubst -> CheckM () -updateTVSubst theta = do - refTheta <- asks chTVSubst - liftIO $ atomicModifyIORef' refTheta $ const (Just theta, ()) +-- updateTVSubst :: TVSubst -> CheckM () +-- updateTVSubst theta = do +-- refTheta <- asks chTVSubst +-- liftIO $ atomicModifyIORef' refTheta $ const (Just theta, ()) --- local (\s -> s {chTVSubst = theta}) (return ()) +-- -- local (\s -> s {chTVSubst = theta}) (return ()) -mergeTVSubst :: TVSubst -> Maybe TVSubst -> TVSubst -mergeTVSubst (Th m1) Nothing = Th m1 -mergeTVSubst (Th m1) (Just (Th m2)) = Th m1 <> Th m2 +-- mergeTVSubst :: TVSubst -> Maybe TVSubst -> TVSubst +-- mergeTVSubst (Th m1) Nothing = Th m1 +-- mergeTVSubst (Th m1) (Just (Th m2)) = Th m1 <> Th m2 -composeTVSubst :: Maybe TVSubst -> CheckM () -composeTVSubst Nothing = return () -composeTVSubst (Just theta1) = do - refTheta <- asks chTVSubst - theta <- liftIO $ readIORef refTheta - updateTVSubst (mergeTVSubst theta1 theta) +-- composeTVSubst :: Maybe TVSubst -> CheckM () +-- composeTVSubst Nothing = return () +-- composeTVSubst (Just theta1) = do +-- refTheta <- asks chTVSubst +-- theta <- liftIO $ readIORef refTheta +-- updateTVSubst (mergeTVSubst theta1 theta) -------------------------------------------------------------------------------- -- | Applying a Type Substitution ---------------------------------------------- @@ -1415,12 +1417,15 @@ _applyCoercion a t = Vis.mapSort f -------------------------------------------------------------------------------- -- | Deconstruct a function-sort ----------------------------------------------- -------------------------------------------------------------------------------- -checkFunSort :: Sort -> CheckM (Sort, Sort, TVSubst) +checkFunSort :: Sort -> CheckM (Sort, Sort) checkFunSort (FAbs _ t) = checkFunSort t -checkFunSort (FFunc t1 t2) = return (t1, t2, emptySubst) -checkFunSort (FVar i) = do j <- fresh - k <- fresh - return (FVar j, FVar k, updateVar i (FFunc (FVar j) (FVar k)) emptySubst) +checkFunSort (FFunc t1 t2) = return (t1, t2) +checkFunSort (FVar i) = do + k <- fresh + j <- fresh + ufRef <- asks ufM + _ <- liftIO $ atomicModifyIORef' ufRef $ \uf -> (Union.union uf i (FFunc (FVar j) (FVar k)), ()) + return (FVar j, FVar k) checkFunSort t = throwErrorAt (errNonFunction 1 t) -------------------------------------------------------------------------------- diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs new file mode 100644 index 000000000..9d28e9ec4 --- /dev/null +++ b/src/Language/Fixpoint/Union.hs @@ -0,0 +1,33 @@ +module Language.Fixpoint.Union where +import Data.HashMap.Strict (lookup, insert, HashMap, empty) +import Prelude hiding (lookup) +import Language.Fixpoint.Types.Sorts (Sort(..)) +next :: Sort -> Maybe Int +next (FVar i) = Just i +next _ = Nothing +unionVals :: UF -> Int -> Sort -> Sort -> UF +unionVals _ _ _ _ = error "todo" +-- unionVals ufM i s1 s2 = error "todo" +-- unionVals ufM _ FInt FInt = ufM +-- unionVals ufM _ FReal FReal = ufM +-- unionVals ufM _ FInt FReal = ufM +-- unionVals ufM _ FReal FInt = ufM +newtype UF = MkUF (HashMap Int Sort) deriving (Show) +new :: UF +new = MkUF empty +union :: UF -> Int -> Sort -> UF +union u@(MkUF ufM) tyv s = + -- find the root for tyv + let tyv_root = find (MkUF ufM) tyv in + case tyv_root of + -- if tyv not in union find, insert + Nothing -> MkUF (insert tyv s ufM) + -- otherwise, unify the current sort with + -- the new one and insert that + Just (i, s') -> unionVals u i s s' +find :: UF -> Int -> Maybe (Int, Sort) +find (MkUF ufM) k = do + s <- lookup k ufM + case next s of + Nothing -> Just (k, s) + Just i -> find (MkUF ufM) i \ No newline at end of file From d9ab837c89cc7bee7b35919c7f31a9cb93952b86 Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Thu, 30 Jan 2025 12:34:34 -0800 Subject: [PATCH 09/33] Extract unify to unifyUF --- src/Language/Fixpoint/SortCheck.hs | 119 +++++++++++++++++++++++------ 1 file changed, 94 insertions(+), 25 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index ea502017a..e2ce8df32 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -89,6 +89,7 @@ import GHC.Stack import qualified Language.Fixpoint.Types as F import System.IO.Unsafe (unsafePerformIO) import qualified Language.Fixpoint.Union as Union +import Language.Fixpoint.Union (UF) --import Debug.Trace as Debug @@ -276,7 +277,7 @@ elabExprE :: Located String -> SymEnv -> Expr -> Either Error Expr elabExprE msg env e = case runCM0 (srcSpan msg) $ do (!e', _) <- elab (env, envLookup) e - finalThetaRef <- asks chTVSubst + finalThetaRef <- asks chTVSubst finalTheta <- liftIO $ readIORef finalThetaRef return (applyExpr finalTheta e') of Left (ChError f') -> @@ -540,9 +541,13 @@ elab :: ElabEnv -> Expr -> CheckM (Expr, Sort) elab f@(!_, !g) e@(EBin !o !e1 !e2) = do (!e1', !s1) <- elab f e1 (!e2', !s2) <- elab f e2 - !s <- checkOpTy g e s1 s2 + -- !s <- checkOpTy g e s1 s2 + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyUF g uf (Just e) s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) let !result = EBin o (eCst e1' s1) (eCst e2' s2) - return (result, s) + return (result, s2) elab !f (EApp !e1 !e2) = do @@ -1163,27 +1168,84 @@ checkURel e s1 s2 = unless (b1 == b2) (throwErrorAt $ errRel e s1 s2) -------------------------------------------------------------------------------- -- | Sort Unification on Expressions -------------------------------------------------------------------------------- +unifyUF :: Env -> UF -> Maybe Expr -> Sort -> Sort -> CheckM UF +-------------------------------------------------------------------------------- +unifyUF f uf e t1 t2 + = unify1UF f e uf t1 t2 + +-------------------------------------------------------------------------------- +unifyTo1UF :: Env -> UF -> [Sort] -> CheckM UF +-------------------------------------------------------------------------------- +unifyTo1UF f uf ts + = unifyTo1MUF f uf ts + + +-------------------------------------------------------------------------------- +unifyTo1MUF :: Env -> UF -> [Sort] -> CheckM UF +-------------------------------------------------------------------------------- +unifyTo1MUF _ _ [] = panic "unifyTo1: empty list" +unifyTo1MUF f uf (t0:ts) = fst <$> foldM step (uf, t0) ts + where + step :: (UF, Sort) -> Sort -> CheckM (UF, Sort) + step (ufm, t) t' = do + ufm' <- unify1UF f Nothing ufm t t' + return (ufm', t) + +-------------------------------------------------------------------------------- +unifysUF :: HasCallStack => Env -> Maybe Expr -> UF -> [Sort] -> [Sort] -> CheckM UF +-------------------------------------------------------------------------------- +unifysUF f e uf = unifyManyUF f e uf + +unifyManyUF :: HasCallStack => Env -> Maybe Expr -> UF -> [Sort] -> [Sort] -> CheckM UF +unifyManyUF f e uf ts ts' + | length ts == length ts' = foldM (uncurry . unify1UF f e) uf $ zip ts ts' + | otherwise = throwErrorAt (errUnifyMany ts ts') + +unify1UF :: Env -> Maybe Expr -> UF -> Sort -> Sort -> CheckM UF +unify1UF f e !uf (FVar !i) !t + = unifyVarUF f e uf i t +unify1UF f e !uf !t (FVar !i) + = unifyVarUF f e uf i t +unify1UF f e !uf (FApp !t1 !t2) (FApp !t1' !t2') + = unifyManyUF f e uf [t1, t2] [t1', t2'] +unify1UF _ _ !θ (FTC !l1) (FTC !l2) + | isListTC l1 && isListTC l2 + = return θ +unify1UF f e !uf t1@(FAbs _ _) !t2 = do + !t1' <- instantiate t1 + unifyManyUF f e uf [t1'] [t2] +unify1UF f e !uf !t1 t2@(FAbs _ _) = do + !t2' <- instantiate t2 + unifyManyUF f e uf [t1] [t2'] +unify1UF _ _ !uf !s1 !s2 + | isString s1, isString s2 + = return uf +unify1UF _ _ !uf FInt FReal = return uf + +unify1UF _ _ !uf FReal FInt = return uf + +unify1UF f e !uf !t FInt = do + checkNumeric f t `withError` errUnify e t FInt + return uf + +unify1UF f e !uf FInt !t = do + checkNumeric f t `withError` errUnify e FInt t + return uf + +unify1UF f e !uf (FFunc !t1 !t2) (FFunc !t1' !t2') = + unifyManyUF f e uf [t1, t2] [t1', t2'] --- {-# SCC unifyExpr #-} --- unifyExpr :: Env -> Expr -> Maybe TVSubst --- unifyExpr f (EApp e1 e2) = Just $ mconcat $ catMaybes [θ1, θ2, θ] --- where --- θ1 = unifyExpr f e1 --- θ2 = unifyExpr f e2 --- θ = unifyExprApp f e1 e2 --- unifyExpr f (ECst e _) --- = unifyExpr f e --- unifyExpr _ _ --- = Nothing +unify1UF f e uf (FObj a) !t = + checkEqConstr f e uf a t --- unifyExprApp :: Env -> Expr -> Expr -> Maybe TVSubst --- unifyExprApp f e1 e2 = do --- t1 <- getArg $ exprSortMaybe e1 --- t2 <- exprSortMaybe e2 --- unify f (Just $ EApp e1 e2) t1 t2 --- where --- getArg (Just (FFunc t1 _)) = Just t1 --- getArg _ = Nothing +unify1UF f e uf !t (FObj a) = + checkEqConstr f e uf a t + +unify1UF _ e uf !t1 !t2 + | t1 == t2 + = return uf + | otherwise + = throwErrorAt (errUnify e t1 t2) -------------------------------------------------------------------------------- @@ -1364,6 +1426,12 @@ unifyVar f e θ !i !t Just !t' -> if t == t' then return θ else unify1 f e θ t t' Nothing -> return (updateVar i t θ) +unifyVarUF :: Env -> Maybe Expr -> UF -> Int -> Sort -> CheckM UF +unifyVarUF _ _ uf !_ t@(FVar !j) + = return (Union.union uf j t) + +unifyVarUF _ _ uf !i !t + = return (Union.union uf i t) -------------------------------------------------------------------------------- -- | Update global subst to be applied to expressions @@ -1420,12 +1488,12 @@ _applyCoercion a t = Vis.mapSort f checkFunSort :: Sort -> CheckM (Sort, Sort) checkFunSort (FAbs _ t) = checkFunSort t checkFunSort (FFunc t1 t2) = return (t1, t2) -checkFunSort (FVar i) = do - k <- fresh +checkFunSort (FVar i) = do + k <- fresh j <- fresh ufRef <- asks ufM _ <- liftIO $ atomicModifyIORef' ufRef $ \uf -> (Union.union uf i (FFunc (FVar j) (FVar k)), ()) - return (FVar j, FVar k) + return (FVar j, FVar k) checkFunSort t = throwErrorAt (errNonFunction 1 t) -------------------------------------------------------------------------------- @@ -1514,3 +1582,4 @@ errNonFractional l = printf "The sort %s is not fractional" (showpp l) errBoolSort :: Expr -> Sort -> String errBoolSort e s = printf "Expressions %s should have bool sort, but has %s" (showpp e) (showpp s) + From acb4f460cf4cc71b91253de7a465efdb801afe20 Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Thu, 30 Jan 2025 15:31:52 -0800 Subject: [PATCH 10/33] rip out all applies and applyExpr from elab --- src/Language/Fixpoint/SortCheck.hs | 112 +++++++++++++++++++---------- src/Language/Fixpoint/Union.hs | 22 +++--- 2 files changed, 88 insertions(+), 46 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index e2ce8df32..07ea3b419 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -277,9 +277,9 @@ elabExprE :: Located String -> SymEnv -> Expr -> Either Error Expr elabExprE msg env e = case runCM0 (srcSpan msg) $ do (!e', _) <- elab (env, envLookup) e - finalThetaRef <- asks chTVSubst - finalTheta <- liftIO $ readIORef finalThetaRef - return (applyExpr finalTheta e') of + fufRef <- asks ufM + finalUF <- liftIO $ readIORef fufRef + return (applyExprUF finalUF e') of Left (ChError f') -> let e' = f' () in Left $ err (srcSpan e') (d (val e')) @@ -377,7 +377,7 @@ instance Show ChError where show (ChError f) = show (f ()) instance Exception ChError where -data ChState = ChS {chCount :: IORef Int, chSpan :: SrcSpan, ufM :: IORef Union.UF, chTVSubst :: IORef (Maybe TVSubst)} +data ChState = ChS {chCount :: IORef Int, chSpan :: SrcSpan, ufM :: IORef Union.UF} type Env = Symbol -> SESearch Sort type ElabEnv = (SymEnv, Env) @@ -412,9 +412,8 @@ varCounterRef = unsafePerformIO $ newIORef 42 -- value of counter. runCM0 :: SrcSpan -> CheckM a -> Either ChError a runCM0 sp act = unsafePerformIO $ do - suR <- newIORef Nothing ufR <- newIORef Union.new - try (runReaderT act (ChS varCounterRef sp ufR suR)) + try (runReaderT act (ChS varCounterRef sp ufR)) fresh :: CheckM Int fresh = do @@ -532,6 +531,7 @@ addEnv f bs x Just s -> Found s Nothing -> f x +-------------------------------------------------------------------------------- -------------------------------------------------------------------------------- -- | Elaborate expressions with types to make polymorphic instantiation explicit. -------------------------------------------------------------------------------- @@ -590,16 +590,22 @@ elab f@(!_,!g) (ECst (EIte !p !e1 !e2) !t) = do (!p', !_) <- elab f p (!e1', !s1) <- elab f (eCst e1 t) (!e2', !s2) <- elab f (eCst e2 t) - !s <- checkIteTy g p e1' e2' s1 s2 - return (EIte p' (eCst e1' s) (eCst e2' s), t) + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + !uf' <- checkIteTyUF g uf p e1' e2' s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return (EIte p' (eCst e1' s1) (eCst e2' s2), t) elab f@(!_,!g) (EIte !p !e1 !e2) = do !t <- getIte g e1 e2 (!p', !_) <- elab f p (!e1', !s1) <- elab f (eCst e1 t) (!e2', !s2) <- elab f (eCst e2 t) - !s <- checkIteTy g p e1' e2' s1 s2 - return (EIte p' (eCst e1' s) (eCst e2' s), s) + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + !uf' <- checkIteTyUF g uf p e1' e2' s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return (EIte p' (eCst e1' s1) (eCst e2' s2), s2) elab !f (ECst !e !t) = do @@ -704,10 +710,11 @@ elabAppAs env@(_, f) t g e = do te <- checkExpr f e (iT, oT) <- checkFunSort tg let ge = Just (EApp g e) - _ <- unifyMany f ge emptySubst [oT, iT] [t, te] - -- let tg = apply su tg + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyManyUF f ge uf [oT, iT] [t, te] + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) g' <- elabAs env tg g - -- let te = apply su te e' <- elabAs env te e pure $ EApp (ECst g' tg) (ECst e' te) @@ -722,9 +729,10 @@ elabAppSort :: Env -> Expr -> Expr -> Sort -> Sort -> CheckM (Expr, Expr, Sort, elabAppSort f e1 e2 s1 s2 = do let e = Just (EApp e1 e2) (sIn, sOut) <- checkFunSort s1 - _ <- unify1 f e emptySubst sIn s2 - -- composeTVSubst (Just su) - -- composeTVSubst (Just su') + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unify1UF f e uf sIn s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) return (e1 , e2, s1, s2, sOut) @@ -1015,6 +1023,12 @@ checkIteTy f p e1 e2 t1 t2 = where e' = Just (EIte p e1 e2) +checkIteTyUF :: Env -> UF -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM UF +checkIteTyUF f uf p e1 e2 t1 t2 = + unifysUF f e' uf [t1] [t2] `withError` errIte e1 e2 t1 t2 + where + e' = Just (EIte p e1 e2) + -- | Helper for checking cast expressions checkCst :: Env -> Sort -> Expr -> CheckM Sort checkCst f t (EApp g e) @@ -1114,6 +1128,14 @@ checkEqConstr f e θ a t = Found tA -> unify1 f e θ tA t _ -> throwErrorAt $ errUnifyMsg (Just "ceq2") e (FObj a) t +checkEqConstrUF :: Env -> Maybe Expr -> UF -> Symbol -> Sort -> CheckM UF +checkEqConstrUF _ _ uf a (FObj b) + | a == b + = return uf +checkEqConstrUF f e uf a t = + case f a of + Found tA -> unify1UF f e uf tA t + _ -> throwErrorAt $ errUnifyMsg (Just "ceq2") e (FObj a) t -------------------------------------------------------------------------------- -- | Checking Predicates ------------------------------------------------------- -------------------------------------------------------------------------------- @@ -1173,23 +1195,23 @@ unifyUF :: Env -> UF -> Maybe Expr -> Sort -> Sort -> CheckM UF unifyUF f uf e t1 t2 = unify1UF f e uf t1 t2 --------------------------------------------------------------------------------- -unifyTo1UF :: Env -> UF -> [Sort] -> CheckM UF --------------------------------------------------------------------------------- -unifyTo1UF f uf ts - = unifyTo1MUF f uf ts +-- -------------------------------------------------------------------------------- +-- unifyTo1UF :: Env -> UF -> [Sort] -> CheckM UF +-- -------------------------------------------------------------------------------- +-- unifyTo1UF f uf ts +-- = unifyTo1MUF f uf ts --------------------------------------------------------------------------------- -unifyTo1MUF :: Env -> UF -> [Sort] -> CheckM UF --------------------------------------------------------------------------------- -unifyTo1MUF _ _ [] = panic "unifyTo1: empty list" -unifyTo1MUF f uf (t0:ts) = fst <$> foldM step (uf, t0) ts - where - step :: (UF, Sort) -> Sort -> CheckM (UF, Sort) - step (ufm, t) t' = do - ufm' <- unify1UF f Nothing ufm t t' - return (ufm', t) +-- -------------------------------------------------------------------------------- +-- unifyTo1MUF :: Env -> UF -> [Sort] -> CheckM UF +-- -------------------------------------------------------------------------------- +-- unifyTo1MUF _ _ [] = panic "unifyTo1: empty list" +-- unifyTo1MUF f uf (t0:ts) = fst <$> foldM step (uf, t0) ts +-- where +-- step :: (UF, Sort) -> Sort -> CheckM (UF, Sort) +-- step (ufm, t) t' = do +-- ufm' <- unify1UF f Nothing ufm t t' +-- return (ufm', t) -------------------------------------------------------------------------------- unifysUF :: HasCallStack => Env -> Maybe Expr -> UF -> [Sort] -> [Sort] -> CheckM UF @@ -1236,10 +1258,10 @@ unify1UF f e !uf (FFunc !t1 !t2) (FFunc !t1' !t2') = unifyManyUF f e uf [t1, t2] [t1', t2'] unify1UF f e uf (FObj a) !t = - checkEqConstr f e uf a t + checkEqConstrUF f e uf a t unify1UF f e uf !t (FObj a) = - checkEqConstr f e uf a t + checkEqConstrUF f e uf a t unify1UF _ e uf !t1 !t2 | t1 == t2 @@ -1465,11 +1487,27 @@ apply !θ = Vis.mapSort f f t@(FVar !i) = fromMaybe t (lookupVar i θ) f !t = t -applyExpr :: Maybe TVSubst -> Expr -> Expr -applyExpr Nothing e = e -applyExpr (Just θ) e = Vis.mapExprOnExpr f e +-- applyExpr :: Maybe TVSubst -> Expr -> Expr +-- applyExpr Nothing e = e +-- applyExpr (Just θ) e = Vis.mapExprOnExpr f e +-- where +-- f (ECst !e' !s) = ECst e' (apply θ s) +-- f !e' = e' + +-------------------------------------------------------------------------------- +-- | Applying the result of union find +-------------------------------------------------------------------------------- +applyUF :: UF -> Sort -> Sort +applyUF uf = Vis.mapSort f + where + f t@(FVar !i) = fromMaybe t (Union.find uf i) + f !t = t + +{-# SCC applyExprUF #-} +applyExprUF :: UF -> Expr -> Expr +applyExprUF uf e = Vis.mapExprOnExpr f e where - f (ECst !e' !s) = ECst e' (apply θ s) + f (ECst !e' !s) = ECst e' (applyUF uf s) f !e' = e' -------------------------------------------------------------------------------- diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index 9d28e9ec4..b5cb776fb 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -5,13 +5,14 @@ import Language.Fixpoint.Types.Sorts (Sort(..)) next :: Sort -> Maybe Int next (FVar i) = Just i next _ = Nothing -unionVals :: UF -> Int -> Sort -> Sort -> UF -unionVals _ _ _ _ = error "todo" +unionVals :: UF -> Sort -> Sort -> UF +unionVals _ _ _ = error "todo" -- unionVals ufM i s1 s2 = error "todo" -- unionVals ufM _ FInt FInt = ufM -- unionVals ufM _ FReal FReal = ufM -- unionVals ufM _ FInt FReal = ufM -- unionVals ufM _ FReal FInt = ufM + newtype UF = MkUF (HashMap Int Sort) deriving (Show) new :: UF new = MkUF empty @@ -24,10 +25,13 @@ union u@(MkUF ufM) tyv s = Nothing -> MkUF (insert tyv s ufM) -- otherwise, unify the current sort with -- the new one and insert that - Just (i, s') -> unionVals u i s s' -find :: UF -> Int -> Maybe (Int, Sort) -find (MkUF ufM) k = do - s <- lookup k ufM - case next s of - Nothing -> Just (k, s) - Just i -> find (MkUF ufM) i \ No newline at end of file + Just s' -> unionVals u s s' + +find :: UF -> Int -> Maybe Sort +find (MkUF ufM) = f + where + f k = do + s <- lookup k ufM + case next s of + Nothing -> Just s + Just i -> f i \ No newline at end of file From 3c64ac170a3a8c7ce8bfa0c698278be40faf9c0f Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Thu, 30 Jan 2025 16:19:34 -0800 Subject: [PATCH 11/33] hack together some union method for sorts --- src/Language/Fixpoint/SortCheck.hs | 6 +-- src/Language/Fixpoint/Union.hs | 73 ++++++++++++++++++++++++------ 2 files changed, 61 insertions(+), 18 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 07ea3b419..724f220b2 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -1187,7 +1187,7 @@ checkURel e s1 s2 = unless (b1 == b2) (throwErrorAt $ errRel e s1 s2) b1 = s1 == boolSort b2 = s2 == boolSort --------------------------------------------------------------------------------- +------------------------------------------------------------------- -- | Sort Unification on Expressions -------------------------------------------------------------------------------- unifyUF :: Env -> UF -> Maybe Expr -> Sort -> Sort -> CheckM UF @@ -1449,8 +1449,8 @@ unifyVar f e θ !i !t Nothing -> return (updateVar i t θ) unifyVarUF :: Env -> Maybe Expr -> UF -> Int -> Sort -> CheckM UF -unifyVarUF _ _ uf !_ t@(FVar !j) - = return (Union.union uf j t) +unifyVarUF _ _ uf !i (FVar !j) + = return (Union.union uf j (FVar i)) unifyVarUF _ _ uf !i !t = return (Union.union uf i t) diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index b5cb776fb..5ff83ebb2 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -2,16 +2,52 @@ module Language.Fixpoint.Union where import Data.HashMap.Strict (lookup, insert, HashMap, empty) import Prelude hiding (lookup) import Language.Fixpoint.Types.Sorts (Sort(..)) -next :: Sort -> Maybe Int -next (FVar i) = Just i -next _ = Nothing -unionVals :: UF -> Sort -> Sort -> UF -unionVals _ _ _ = error "todo" --- unionVals ufM i s1 s2 = error "todo" --- unionVals ufM _ FInt FInt = ufM --- unionVals ufM _ FReal FReal = ufM --- unionVals ufM _ FInt FReal = ufM --- unionVals ufM _ FReal FInt = ufM + +-- unionVals :: UF Sort -> Int -> Sort -> Sort -> UF Sort +-- unionVals ufM _ SInt SInt = ufM +-- unionVals ufM _ SFloat SFloat = ufM +-- unionVals ufM _ (SFVar j) s = Union.union ufM j s +-- unionVals ufM _ s (SFVar j) = Union.union ufM j s +-- unionVals u i (SFunc s1 s2) (SFunc s1' s2') = +-- let u' = unionFuncArgs u i s1 s1' in +-- unionFuncArgs u' i s2 s2' +-- unionVals _ _ s1 s2 = error ("Cannot unify " ++ show s1 ++ " " ++ show s2) +-- next s = case s of +-- SFVar i -> Just i +-- _ -> Nothing + +unionMany :: UF -> Int -> Sort -> Sort -> UF +unionMany uf i s1 s2 = case (s1, s2) of + (FVar i1, _) -> union uf i1 s2 + (_, FVar i2) -> union uf i2 s1 + (_, _) -> unionVals uf i s1 s2 + +-------------------------------------------------------------------------------- +-- | union for sorts in union find +-------------------------------------------------------------------------------- +unionVals :: UF -> Int -> Sort -> Sort -> UF +-------------------------------------------------------------------------------- +unionVals uf _ s1 s2 + | isNumericSort s1 && isNumericSort s2 = uf + where + isNumericSort FReal = True + isNumericSort FNum = True + isNumericSort FFrac = True + isNumericSort _ = False + +unionVals uf _ (FObj _) (FObj _) = uf +unionVals uf _ (FVar i) s = union uf i s +unionVals uf _ s (FVar i) = union uf i s +unionVals uf i (FFunc s1 s2) (FFunc s1' s2') = + let uf' = unionMany uf i s1 s1' in + unionMany uf' i s2 s2' +unionVals uf i (FApp s1 s2) (FApp s1' s2') = + let uf' = unionMany uf i s1 s1' in + unionMany uf' i s2 s2' +unionVals uf i (FAbs _ s) (FAbs _ s') = unionMany uf i s s' +unionVals uf _ (FTC _) (FTC _) = uf +unionVals _ _ _ _ = error "Cannot unify" + newtype UF = MkUF (HashMap Int Sort) deriving (Show) new :: UF @@ -19,19 +55,26 @@ new = MkUF empty union :: UF -> Int -> Sort -> UF union u@(MkUF ufM) tyv s = -- find the root for tyv - let tyv_root = find (MkUF ufM) tyv in + let tyv_root = findWithIndex (MkUF ufM) tyv in case tyv_root of -- if tyv not in union find, insert Nothing -> MkUF (insert tyv s ufM) -- otherwise, unify the current sort with -- the new one and insert that - Just s' -> unionVals u s s' + Just (i, s') -> unionVals u i s s' + +findWithIndex :: UF -> Int -> Maybe (Int, Sort) +findWithIndex u@(MkUF ufM) k = do + s <- lookup k ufM + case s of + FVar i -> findWithIndex u i + s' -> Just (k, s') find :: UF -> Int -> Maybe Sort find (MkUF ufM) = f where f k = do s <- lookup k ufM - case next s of - Nothing -> Just s - Just i -> f i \ No newline at end of file + case s of + FVar i -> f i + s' -> Just s' \ No newline at end of file From 3b592cc025e50913d668f5e0b677c2a49e2a8455 Mon Sep 17 00:00:00 2001 From: vrindisbacher Date: Thu, 30 Jan 2025 16:30:59 -0800 Subject: [PATCH 12/33] fixup mistake --- src/Language/Fixpoint/Union.hs | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index 5ff83ebb2..1bf264a04 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -3,19 +3,6 @@ import Data.HashMap.Strict (lookup, insert, HashMap, empty) import Prelude hiding (lookup) import Language.Fixpoint.Types.Sorts (Sort(..)) --- unionVals :: UF Sort -> Int -> Sort -> Sort -> UF Sort --- unionVals ufM _ SInt SInt = ufM --- unionVals ufM _ SFloat SFloat = ufM --- unionVals ufM _ (SFVar j) s = Union.union ufM j s --- unionVals ufM _ s (SFVar j) = Union.union ufM j s --- unionVals u i (SFunc s1 s2) (SFunc s1' s2') = --- let u' = unionFuncArgs u i s1 s1' in --- unionFuncArgs u' i s2 s2' --- unionVals _ _ s1 s2 = error ("Cannot unify " ++ show s1 ++ " " ++ show s2) --- next s = case s of --- SFVar i -> Just i --- _ -> Nothing - unionMany :: UF -> Int -> Sort -> Sort -> UF unionMany uf i s1 s2 = case (s1, s2) of (FVar i1, _) -> union uf i1 s2 @@ -33,6 +20,7 @@ unionVals uf _ s1 s2 isNumericSort FReal = True isNumericSort FNum = True isNumericSort FFrac = True + isNumericSort FInt = True isNumericSort _ = False unionVals uf _ (FObj _) (FObj _) = uf @@ -46,7 +34,7 @@ unionVals uf i (FApp s1 s2) (FApp s1' s2') = unionMany uf' i s2 s2' unionVals uf i (FAbs _ s) (FAbs _ s') = unionMany uf i s s' unionVals uf _ (FTC _) (FTC _) = uf -unionVals _ _ _ _ = error "Cannot unify" +unionVals _ _ s1 s2 = error ("Cannot unify " ++ show s1 ++ " and " ++ show s2) newtype UF = MkUF (HashMap Int Sort) deriving (Show) From 3cac458985ff99a1204cda0f13f99ecd1d1032bf Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Fri, 31 Jan 2025 10:08:32 -0800 Subject: [PATCH 13/33] Change union for FObj and FTC --- src/Language/Fixpoint/Union.hs | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index 1bf264a04..af9fbc3d8 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -3,8 +3,8 @@ import Data.HashMap.Strict (lookup, insert, HashMap, empty) import Prelude hiding (lookup) import Language.Fixpoint.Types.Sorts (Sort(..)) -unionMany :: UF -> Int -> Sort -> Sort -> UF -unionMany uf i s1 s2 = case (s1, s2) of +unionSub :: UF -> Int -> Sort -> Sort -> UF +unionSub uf i s1 s2 = case (s1, s2) of (FVar i1, _) -> union uf i1 s2 (_, FVar i2) -> union uf i2 s1 (_, _) -> unionVals uf i s1 s2 @@ -23,17 +23,19 @@ unionVals uf _ s1 s2 isNumericSort FInt = True isNumericSort _ = False -unionVals uf _ (FObj _) (FObj _) = uf +unionVals uf _ (FObj x) (FObj y) + | x == y = uf unionVals uf _ (FVar i) s = union uf i s unionVals uf _ s (FVar i) = union uf i s unionVals uf i (FFunc s1 s2) (FFunc s1' s2') = - let uf' = unionMany uf i s1 s1' in - unionMany uf' i s2 s2' + let uf' = unionSub uf i s1 s1' in + unionSub uf' i s2 s2' unionVals uf i (FApp s1 s2) (FApp s1' s2') = - let uf' = unionMany uf i s1 s1' in - unionMany uf' i s2 s2' -unionVals uf i (FAbs _ s) (FAbs _ s') = unionMany uf i s s' -unionVals uf _ (FTC _) (FTC _) = uf + let uf' = unionSub uf i s1 s1' in + unionSub uf' i s2 s2' +unionVals uf i (FAbs _ s) (FAbs _ s') = unionSub uf i s s' +unionVals uf _ (FTC s1) (FTC s2) + | s1 == s2 = uf unionVals _ _ s1 s2 = error ("Cannot unify " ++ show s1 ++ " and " ++ show s2) From 1111410a304769b1047a0ab03e27e025eb469453 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Fri, 31 Jan 2025 12:49:56 -0800 Subject: [PATCH 14/33] insert directly on free variables --- src/Language/Fixpoint/Union.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index af9fbc3d8..e480e4881 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -25,8 +25,8 @@ unionVals uf _ s1 s2 unionVals uf _ (FObj x) (FObj y) | x == y = uf -unionVals uf _ (FVar i) s = union uf i s -unionVals uf _ s (FVar i) = union uf i s +unionVals (MkUF uf) _ (FVar i) s = MkUF (insert i s uf) +unionVals (MkUF uf) _ s (FVar i) = MkUF (insert i s uf) unionVals uf i (FFunc s1 s2) (FFunc s1' s2') = let uf' = unionSub uf i s1 s1' in unionSub uf' i s2 s2' From 8300b802ae343bc9f1196ddd6bd60dbddfbc03de Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Fri, 31 Jan 2025 13:31:54 -0800 Subject: [PATCH 15/33] add check that we aren't unifying a type variable with itself --- src/Language/Fixpoint/SortCheck.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 724f220b2..9bd367de9 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -1449,8 +1449,8 @@ unifyVar f e θ !i !t Nothing -> return (updateVar i t θ) unifyVarUF :: Env -> Maybe Expr -> UF -> Int -> Sort -> CheckM UF -unifyVarUF _ _ uf !i (FVar !j) - = return (Union.union uf j (FVar i)) +unifyVarUF _ _ uf !i (FVar !j) = + if i == j then return uf else return (Union.union uf j (FVar i)) unifyVarUF _ _ uf !i !t = return (Union.union uf i t) From 1bde2dcfcfa41f6f096c7c10a680ef2b258934db Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 11:04:42 -0800 Subject: [PATCH 16/33] change unite --- src/Language/Fixpoint/SortCheck.hs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 9bd367de9..510e26087 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -541,7 +541,6 @@ elab :: ElabEnv -> Expr -> CheckM (Expr, Sort) elab f@(!_, !g) e@(EBin !o !e1 !e2) = do (!e1', !s1) <- elab f e1 (!e2', !s2) <- elab f e2 - -- !s <- checkOpTy g e s1 s2 ufRef <- asks ufM uf <- liftIO $ readIORef ufRef uf' <- unifyUF g uf (Just e) s1 s2 @@ -965,8 +964,11 @@ genSort t = t unite :: Env -> Expr -> Sort -> Sort -> CheckM (Sort, Sort) unite f e t1 t2 = do - su <- unifys f (Just e) [t1] [t2] - return (apply su t1, apply su t2) + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t1] [t2] + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return (t1, t2) throwErrorAt :: String -> CheckM a throwErrorAt ~err' = do -- Lazy pattern needed because we use LANGUAGE Strict in this module @@ -1084,7 +1086,6 @@ checkOp f e1 o e2 t2 <- checkExpr f e2 checkOpTy f (EBin o e1 e2) t1 t2 - checkOpTy :: Env -> Expr -> Sort -> Sort -> CheckM Sort checkOpTy _ _ FInt FInt = return FInt From f79afbc50baf97e13956eecec467c34890a19f70 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 11:33:06 -0800 Subject: [PATCH 17/33] add applyExprUF --- src/Language/Fixpoint/SortCheck.hs | 32 +++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 510e26087..69f9e0826 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -283,7 +283,9 @@ elabExprE msg env e = Left (ChError f') -> let e' = f' () in Left $ err (srcSpan e') (d (val e')) - Right s -> Right s + Right s -> + let !_ = unsafePerformIO $ print ("Got " ++ show s) in + Right s where sEnv = seSort env envLookup = (`lookupSEnvWithDistance` sEnv) @@ -552,8 +554,10 @@ elab f@(!_, !g) e@(EBin !o !e1 !e2) = do elab !f (EApp !e1 !e2) = do (!e1', !s1, !e2', !s2, !s) <- elabEApp f e1 e2 let !e = eAppC s (eCst e1' s1) (eCst e2' s2) - -- let !θ = unifyExpr (snd f) e - -- composeTVSubst θ + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyExprUF (snd f) uf e + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) return (e, s) @@ -1137,6 +1141,28 @@ checkEqConstrUF f e uf a t = case f a of Found tA -> unify1UF f e uf tA t _ -> throwErrorAt $ errUnifyMsg (Just "ceq2") e (FObj a) t + +{-# SCC unifyExprUF #-} +unifyExprUF :: Env -> UF -> Expr -> CheckM UF +unifyExprUF f uf (EApp e1 e2) = do + uf1 <- unifyExprUF f uf e1 + uf2 <- unifyExprUF f uf1 e2 + unifyExprAppUF f uf2 e1 e2 + +unifyExprUF f uf (ECst e _) + = unifyExprUF f uf e +unifyExprUF _ uf _ + = return uf + +unifyExprAppUF :: Env -> UF -> Expr -> Expr -> CheckM UF +unifyExprAppUF f uf e1 e2 = do + case (getArg $ exprSortMaybe e1, exprSortMaybe e2) of + (Just s1, Just s2) -> unifyUF f uf (Just $ EApp e1 e2) s1 s2 + _ -> return uf + where + getArg (Just (FFunc t1 _)) = Just t1 + getArg _ = Nothing + -------------------------------------------------------------------------------- -- | Checking Predicates ------------------------------------------------------- -------------------------------------------------------------------------------- From c676973e5ec947e03daf489bc0635d83c70b03f9 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 12:07:11 -0800 Subject: [PATCH 18/33] add a bunch of missing unifyUFs --- src/Language/Fixpoint/SortCheck.hs | 65 +++++++++++++++++++----------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 69f9e0826..a571578f4 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -275,6 +275,7 @@ elabExpr msg env e = case elabExprE msg env e of elabExprE :: Located String -> SymEnv -> Expr -> Either Error Expr elabExprE msg env e = + let !_ = unsafePerformIO $ print ("elab " ++ show e) in case runCM0 (srcSpan msg) $ do (!e', _) <- elab (env, envLookup) e fufRef <- asks ufM @@ -600,14 +601,16 @@ elab f@(!_,!g) (ECst (EIte !p !e1 !e2) !t) = do return (EIte p' (eCst e1' s1) (eCst e2' s2), t) elab f@(!_,!g) (EIte !p !e1 !e2) = do - !t <- getIte g e1 e2 - (!p', !_) <- elab f p - (!e1', !s1) <- elab f (eCst e1 t) - (!e2', !s2) <- elab f (eCst e2 t) ufRef <- asks ufM uf <- liftIO $ readIORef ufRef - !uf' <- checkIteTyUF g uf p e1' e2' s1 s2 + (!t, uf') <- getIteUF g uf e1 e2 liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + (!p', !_) <- elab f p + (!e1', !s1) <- elab f (eCst e1 t) + (!e2', !s2) <- elab f (eCst e2 t) + uf'' <- liftIO $ readIORef ufRef + !uf''' <- checkIteTyUF g uf'' p e1' e2' s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf''', ()) return (EIte p' (eCst e1' s1) (eCst e2' s2), s2) @@ -959,7 +962,8 @@ exprSortMaybe = go go (ELam (_, sx) e) = FFunc sx <$> go e go (EApp e ex) | Just (FFunc sx s) <- genSort <$> go e - = maybe s (`apply` s) . (`unifySorts` sx) <$> go ex + = + maybe s (`apply` s) . (`unifySorts` sx) <$> go ex go _ = Nothing genSort :: Sort -> Sort @@ -1017,11 +1021,11 @@ checkIte f p e1 e2 = do t2 <- checkExpr f e2 checkIteTy f p e1 e2 t1 t2 -getIte :: Env -> Expr -> Expr -> CheckM Sort -getIte f e1 e2 = do - t1 <- checkExpr f e1 - t2 <- checkExpr f e2 - (`apply` t1) <$> unifys f Nothing [t1] [t2] +-- getIte :: Env -> Expr -> Expr -> CheckM Sort +-- getIte f e1 e2 = do +-- t1 <- checkExpr f e1 +-- t2 <- checkExpr f e2 +-- (`apply` t1) <$> unifys f Nothing [t1] [t2] checkIteTy :: Env -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM Sort checkIteTy f p e1 e2 t1 t2 = @@ -1029,6 +1033,13 @@ checkIteTy f p e1 e2 t1 t2 = where e' = Just (EIte p e1 e2) +getIteUF :: Env -> UF -> Expr -> Expr -> CheckM (Sort, UF) +getIteUF f uf e1 e2 = do + t1 <- checkExpr f e1 + t2 <- checkExpr f e2 + uf' <- unifysUF f Nothing uf [t1] [t2] + return (t1, uf') + checkIteTyUF :: Env -> UF -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM UF checkIteTyUF f uf p e1 e2 t1 t2 = unifysUF f e' uf [t1] [t2] `withError` errIte e1 e2 t1 t2 @@ -1046,15 +1057,19 @@ checkCst f t e checkApp :: Env -> Maybe Sort -> Expr -> Expr -> CheckM Sort checkApp f to g es - = snd <$> checkApp' f to g es + = checkApp' f to g es checkExprAs :: Env -> Sort -> Expr -> CheckM Sort checkExprAs f t (EApp g e) = checkApp f (Just t) g e checkExprAs f t e - = do t' <- checkExpr f e - θ <- unifys f (Just e) [t'] [t] - pure $ apply θ t + = do + t' <- checkExpr f e + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t'] [t] + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return t -- | Helper for checking uninterpreted function applications -- | Checking function application should be curried, e.g. @@ -1062,20 +1077,24 @@ checkExprAs f t e -- RJ: The above comment makes no sense to me :( -- DUPLICATION with 'elabAppAs' -checkApp' :: Env -> Maybe Sort -> Expr -> Expr -> CheckM (TVSubst, Sort) +checkApp' :: Env -> Maybe Sort -> Expr -> Expr -> CheckM Sort checkApp' f to g e = do gt <- checkExpr f g et <- checkExpr f e (it, ot) <- checkFunSort gt let ge = Just (EApp g e) - su <- unifyMany f ge emptySubst [it] [et] - -- let t = apply su ot + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyManyUF f ge uf [it] [et] + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) case to of - Nothing -> return (su, ot) - Just t' -> do θ' <- unifyMany f ge su [ot] [t'] - -- let ti = apply θ' et - _ <- checkExprAs f et e - return (θ', ot) + Nothing -> return ot + Just t' -> do + uf'' <- liftIO $ readIORef ufRef + uf''' <- unifyManyUF f ge uf'' [ot] [t'] + liftIO $ atomicModifyIORef' ufRef $ const (uf''', ()) + _ <- checkExprAs f et e + return ot -- | Helper for checking binary (numeric) operations From 4b055d92223335c23a5568f8c088da38b7f2941c Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 12:46:56 -0800 Subject: [PATCH 19/33] add more missing uf --- src/Language/Fixpoint/SortCheck.hs | 86 +++++++++++++++++------------- 1 file changed, 48 insertions(+), 38 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index a571578f4..dc1fa9d73 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -596,7 +596,7 @@ elab f@(!_,!g) (ECst (EIte !p !e1 !e2) !t) = do (!e2', !s2) <- elab f (eCst e2 t) ufRef <- asks ufM uf <- liftIO $ readIORef ufRef - !uf' <- checkIteTyUF g uf p e1' e2' s1 s2 + (!_, !uf') <- checkIteTyUF g uf p e1' e2' s1 s2 liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) return (EIte p' (eCst e1' s1) (eCst e2' s2), t) @@ -609,7 +609,7 @@ elab f@(!_,!g) (EIte !p !e1 !e2) = do (!e1', !s1) <- elab f (eCst e1 t) (!e2', !s2) <- elab f (eCst e2 t) uf'' <- liftIO $ readIORef ufRef - !uf''' <- checkIteTyUF g uf'' p e1' e2' s1 s2 + (_, uf''') <- checkIteTyUF g uf'' p e1' e2' s1 s2 liftIO $ atomicModifyIORef' ufRef $ const (uf''', ()) return (EIte p' (eCst e1' s1) (eCst e2' s2), s2) @@ -1019,19 +1019,11 @@ checkIte f p e1 e2 = do checkPred f p t1 <- checkExpr f e1 t2 <- checkExpr f e2 - checkIteTy f p e1 e2 t1 t2 - --- getIte :: Env -> Expr -> Expr -> CheckM Sort --- getIte f e1 e2 = do --- t1 <- checkExpr f e1 --- t2 <- checkExpr f e2 --- (`apply` t1) <$> unifys f Nothing [t1] [t2] - -checkIteTy :: Env -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM Sort -checkIteTy f p e1 e2 t1 t2 = - ((`apply` t1) <$> unifys f e' [t1] [t2]) `withError` errIte e1 e2 t1 t2 - where - e' = Just (EIte p e1 e2) + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + (s, uf') <- checkIteTyUF f uf p e1 e2 t1 t2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return s getIteUF :: Env -> UF -> Expr -> Expr -> CheckM (Sort, UF) getIteUF f uf e1 e2 = do @@ -1040,9 +1032,10 @@ getIteUF f uf e1 e2 = do uf' <- unifysUF f Nothing uf [t1] [t2] return (t1, uf') -checkIteTyUF :: Env -> UF -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM UF -checkIteTyUF f uf p e1 e2 t1 t2 = - unifysUF f e' uf [t1] [t2] `withError` errIte e1 e2 t1 t2 +checkIteTyUF :: Env -> UF -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM (Sort, UF) +checkIteTyUF f uf p e1 e2 t1 t2 = do + uf' <- unifysUF f e' uf [t1] [t2] `withError` errIte e1 e2 t1 t2 + return (t1, uf') where e' = Just (EIte p e1 e2) @@ -1051,9 +1044,13 @@ checkCst :: Env -> Sort -> Expr -> CheckM Sort checkCst f t (EApp g e) = checkApp f (Just t) g e checkCst f t e - = do t' <- checkExpr f e - su <- unifys f (Just e) [t] [t'] `withError` errCast e t' t - pure (apply su t) + = do + t' <- checkExpr f e + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t] [t'] `withError` errCast e t' t + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + pure t checkApp :: Env -> Maybe Sort -> Expr -> Expr -> CheckM Sort checkApp f to g es @@ -1122,12 +1119,12 @@ checkOpTy _ _ FInt FReal checkOpTy _ _ FReal FInt = return FReal -checkOpTy f e t t' - | Just s <- unify f (Just e) t t' - = checkNumeric f (apply s t) >> return (apply s t) - -checkOpTy _ e t t' - = throwErrorAt (errOp e t t') +checkOpTy f e t t' = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyUF f uf (Just e) t t' + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + checkNumeric f t >> return t checkFractional :: Env -> Sort -> CheckM () checkFractional f s@(FObj l) @@ -1198,9 +1195,12 @@ checkRel :: HasCallStack => Env -> Brel -> Expr -> Expr -> CheckM () checkRel f Eq e1 e2 = do t1 <- checkExpr f e1 t2 <- checkExpr f e2 - su <- unifys f (Just e) [t1] [t2] `withError` errRel e t1 t2 - _ <- checkExprAs f (apply su t1) e1 - _ <- checkExprAs f (apply su t2) e2 + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t1] [t2] `withError` errRel e t1 t2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + _ <- checkExprAs f t1 e1 + _ <- checkExprAs f t2 e2 checkRelTy f e Eq t1 t2 where e = PAtom Eq e1 e2 @@ -1223,8 +1223,18 @@ checkRelTy f _ _ FInt s2 = checkNumeric f s2 `withError` errNonNumeric s2 checkRelTy f _ _ s1 FInt = checkNumeric f s1 `withError` errNonNumeric s1 checkRelTy f _ _ FReal s2 = checkFractional f s2 `withError` errNonFractional s2 checkRelTy f _ _ s1 FReal = checkFractional f s1 `withError` errNonFractional s1 -checkRelTy f e Eq t1 t2 = void (unifys f (Just e) [t1] [t2] `withError` errRel e t1 t2) -checkRelTy f e Ne t1 t2 = void (unifys f (Just e) [t1] [t2] `withError` errRel e t1 t2) +checkRelTy f e Eq t1 t2 = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t1] [t2] `withError` errRel e t1 t2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return () +checkRelTy f e Ne t1 t2 = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifysUF f (Just e) uf [t1] [t2] `withError` errRel e t1 t2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return () checkRelTy _ e _ t1 t2 = unless (t1 == t2) (throwErrorAt $ errRel e t1 t2) checkURel :: Expr -> Sort -> Sort -> CheckM () @@ -1633,12 +1643,12 @@ errRel e t1 t2 = traced $ printf "Invalid Relation %s with operand types %s and %s" (showpp e) (showpp t1) (showpp t2) -errOp :: Expr -> Sort -> Sort -> String -errOp e t t' - | t == t' = printf "Operands have non-numeric types %s in %s" - (showpp t) (showpp e) - | otherwise = printf "Operands have different types %s and %s in %s" - (showpp t) (showpp t') (showpp e) +-- errOp :: Expr -> Sort -> Sort -> String +-- errOp e t t' +-- | t == t' = printf "Operands have non-numeric types %s in %s" +-- (showpp t) (showpp e) +-- | otherwise = printf "Operands have different types %s and %s in %s" +-- (showpp t) (showpp t') (showpp e) errIte :: Expr -> Expr -> Sort -> Sort -> String errIte e1 e2 t1 t2 = printf "Mismatched branches in Ite: then %s : %s, else %s : %s" From 5985fa92ec79487adc3d422ccb09b5560b3e73db Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 15:24:23 -0800 Subject: [PATCH 20/33] Found mystery --- src/Language/Fixpoint/SortCheck.hs | 34 ++++++++++++++++-------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index dc1fa9d73..7d2462026 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -275,17 +275,17 @@ elabExpr msg env e = case elabExprE msg env e of elabExprE :: Located String -> SymEnv -> Expr -> Either Error Expr elabExprE msg env e = - let !_ = unsafePerformIO $ print ("elab " ++ show e) in case runCM0 (srcSpan msg) $ do (!e', _) <- elab (env, envLookup) e fufRef <- asks ufM finalUF <- liftIO $ readIORef fufRef + let !_ = unsafePerformIO $ print ("Final UF " ++ show finalUF) return (applyExprUF finalUF e') of Left (ChError f') -> let e' = f' () in Left $ err (srcSpan e') (d (val e')) Right s -> - let !_ = unsafePerformIO $ print ("Got " ++ show s) in + let !_ = unsafePerformIO $ print ("Result " ++ show s) in Right s where sEnv = seSort env @@ -648,6 +648,7 @@ elab f@(!_,!g) e@(PAtom !eq !e1 !e2) | eq == Eq || eq == Ne = do !e2' <- elabAs f t2' e2 !e1'' <- eCstAtom f e1' t1' !e2'' <- eCstAtom f e2' t2' + let !_ = unsafePerformIO $ print ("Right side is " ++ show e2'') return (PAtom eq e1'' e2'', boolSort) elab !f (PAtom !r !e1 !e2) @@ -963,6 +964,7 @@ exprSortMaybe = go go (EApp e ex) | Just (FFunc sx s) <- genSort <$> go e = + let !_ = unsafePerformIO $ print ("About to apply " ++ show sx) in maybe s (`apply` s) . (`unifySorts` sx) <$> go ex go _ = Nothing @@ -1029,7 +1031,10 @@ getIteUF :: Env -> UF -> Expr -> Expr -> CheckM (Sort, UF) getIteUF f uf e1 e2 = do t1 <- checkExpr f e1 t2 <- checkExpr f e2 - uf' <- unifysUF f Nothing uf [t1] [t2] + let !_ = unsafePerformIO $ print ("T1 is " ++ show t1) + let !_ = unsafePerformIO $ print ("T2 is " ++ show t1) + uf' <- unifyUF f uf Nothing t1 t2 + let !_ = unsafePerformIO $ print ("Unifying gives " ++ show uf') return (t1, uf') checkIteTyUF :: Env -> UF -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM (Sort, UF) @@ -1083,13 +1088,17 @@ checkApp' f to g e = do ufRef <- asks ufM uf <- liftIO $ readIORef ufRef uf' <- unifyManyUF f ge uf [it] [et] + let !_ = unsafePerformIO $ print ("Unified " ++ show it ++ " and " ++ show et) + let !_ = unsafePerformIO $ print ("Which gave " ++ show uf') liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) case to of Nothing -> return ot Just t' -> do - uf'' <- liftIO $ readIORef ufRef + ufRef' <- asks ufM + uf'' <- liftIO $ readIORef ufRef' uf''' <- unifyManyUF f ge uf'' [ot] [t'] - liftIO $ atomicModifyIORef' ufRef $ const (uf''', ()) + let !_ = unsafePerformIO $ print ("Unified " ++ show ot ++ " and " ++ show t') + liftIO $ atomicModifyIORef' ufRef' $ const (uf''', ()) _ <- checkExprAs f et e return ot @@ -1223,19 +1232,12 @@ checkRelTy f _ _ FInt s2 = checkNumeric f s2 `withError` errNonNumeric s2 checkRelTy f _ _ s1 FInt = checkNumeric f s1 `withError` errNonNumeric s1 checkRelTy f _ _ FReal s2 = checkFractional f s2 `withError` errNonFractional s2 checkRelTy f _ _ s1 FReal = checkFractional f s1 `withError` errNonFractional s1 -checkRelTy f e Eq t1 t2 = do +checkRelTy f e _ t1 t2 = do ufRef <- asks ufM - uf <- liftIO $ readIORef ufRef - uf' <- unifysUF f (Just e) uf [t1] [t2] `withError` errRel e t1 t2 - liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) - return () -checkRelTy f e Ne t1 t2 = do - ufRef <- asks ufM - uf <- liftIO $ readIORef ufRef + uf <- liftIO $ readIORef ufRef uf' <- unifysUF f (Just e) uf [t1] [t2] `withError` errRel e t1 t2 liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) return () -checkRelTy _ e _ t1 t2 = unless (t1 == t2) (throwErrorAt $ errRel e t1 t2) checkURel :: Expr -> Sort -> Sort -> CheckM () checkURel e s1 s2 = unless (b1 == b2) (throwErrorAt $ errRel e s1 s2) @@ -1286,9 +1288,9 @@ unify1UF f e !uf !t (FVar !i) = unifyVarUF f e uf i t unify1UF f e !uf (FApp !t1 !t2) (FApp !t1' !t2') = unifyManyUF f e uf [t1, t2] [t1', t2'] -unify1UF _ _ !θ (FTC !l1) (FTC !l2) +unify1UF _ _ !uf (FTC !l1) (FTC !l2) | isListTC l1 && isListTC l2 - = return θ + = return uf unify1UF f e !uf t1@(FAbs _ _) !t2 = do !t1' <- instantiate t1 unifyManyUF f e uf [t1'] [t2] From 24ca915897e70d7fafb3042cf7f9c49803400319 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 16:01:42 -0800 Subject: [PATCH 21/33] make sure union find is saved correctly --- src/Language/Fixpoint/SortCheck.hs | 41 +++++++++++------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 7d2462026..c4de9b3cc 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -279,14 +279,11 @@ elabExprE msg env e = (!e', _) <- elab (env, envLookup) e fufRef <- asks ufM finalUF <- liftIO $ readIORef fufRef - let !_ = unsafePerformIO $ print ("Final UF " ++ show finalUF) return (applyExprUF finalUF e') of Left (ChError f') -> let e' = f' () in Left $ err (srcSpan e') (d (val e')) - Right s -> - let !_ = unsafePerformIO $ print ("Result " ++ show s) in - Right s + Right s -> Right s where sEnv = seSort env envLookup = (`lookupSEnvWithDistance` sEnv) @@ -591,25 +588,23 @@ elab !f (ENeg !e) = do return (ENeg e', s) elab f@(!_,!g) (ECst (EIte !p !e1 !e2) !t) = do + ufRef <- asks ufM (!p', !_) <- elab f p (!e1', !s1) <- elab f (eCst e1 t) (!e2', !s2) <- elab f (eCst e2 t) - ufRef <- asks ufM uf <- liftIO $ readIORef ufRef (!_, !uf') <- checkIteTyUF g uf p e1' e2' s1 s2 liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) return (EIte p' (eCst e1' s1) (eCst e2' s2), t) elab f@(!_,!g) (EIte !p !e1 !e2) = do - ufRef <- asks ufM - uf <- liftIO $ readIORef ufRef - (!t, uf') <- getIteUF g uf e1 e2 - liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + !t <- getIteUF g e1 e2 (!p', !_) <- elab f p (!e1', !s1) <- elab f (eCst e1 t) (!e2', !s2) <- elab f (eCst e2 t) - uf'' <- liftIO $ readIORef ufRef - (_, uf''') <- checkIteTyUF g uf'' p e1' e2' s1 s2 + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + (_, uf''') <- checkIteTyUF g uf p e1' e2' s1 s2 liftIO $ atomicModifyIORef' ufRef $ const (uf''', ()) return (EIte p' (eCst e1' s1) (eCst e2' s2), s2) @@ -648,7 +643,6 @@ elab f@(!_,!g) e@(PAtom !eq !e1 !e2) | eq == Eq || eq == Ne = do !e2' <- elabAs f t2' e2 !e1'' <- eCstAtom f e1' t1' !e2'' <- eCstAtom f e2' t2' - let !_ = unsafePerformIO $ print ("Right side is " ++ show e2'') return (PAtom eq e1'' e2'', boolSort) elab !f (PAtom !r !e1 !e2) @@ -708,7 +702,7 @@ elabAs f t e = notracepp _msg <$> go e where _msg = "elabAs: t = " ++ showpp t ++ "; e = " ++ showpp e go (EApp e1 e2) = elabAppAs f t e1 e2 - go e' = fst <$> elab f e' + go e' = let !_ = unsafePerformIO $ print ("Down the elab path " ++ show e') in fst <$> elab f e' -- DUPLICATION with `checkApp'` elabAppAs :: ElabEnv -> Sort -> Expr -> Expr -> CheckM Expr @@ -1027,15 +1021,15 @@ checkIte f p e1 e2 = do liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) return s -getIteUF :: Env -> UF -> Expr -> Expr -> CheckM (Sort, UF) -getIteUF f uf e1 e2 = do +getIteUF :: Env -> Expr -> Expr -> CheckM Sort +getIteUF f e1 e2 = do t1 <- checkExpr f e1 t2 <- checkExpr f e2 - let !_ = unsafePerformIO $ print ("T1 is " ++ show t1) - let !_ = unsafePerformIO $ print ("T2 is " ++ show t1) + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef uf' <- unifyUF f uf Nothing t1 t2 - let !_ = unsafePerformIO $ print ("Unifying gives " ++ show uf') - return (t1, uf') + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return t1 checkIteTyUF :: Env -> UF -> Expr -> Expr -> Expr -> Sort -> Sort -> CheckM (Sort, UF) checkIteTyUF f uf p e1 e2 t1 t2 = do @@ -1088,17 +1082,12 @@ checkApp' f to g e = do ufRef <- asks ufM uf <- liftIO $ readIORef ufRef uf' <- unifyManyUF f ge uf [it] [et] - let !_ = unsafePerformIO $ print ("Unified " ++ show it ++ " and " ++ show et) - let !_ = unsafePerformIO $ print ("Which gave " ++ show uf') liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) case to of Nothing -> return ot Just t' -> do - ufRef' <- asks ufM - uf'' <- liftIO $ readIORef ufRef' - uf''' <- unifyManyUF f ge uf'' [ot] [t'] - let !_ = unsafePerformIO $ print ("Unified " ++ show ot ++ " and " ++ show t') - liftIO $ atomicModifyIORef' ufRef' $ const (uf''', ()) + uf'' <- unifyManyUF f ge uf' [ot] [t'] + liftIO $ atomicModifyIORef' ufRef $ const (uf'', ()) _ <- checkExprAs f et e return ot From 78f6048e235fcd465fb493f2df3a1410cb3b8a65 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 16:02:38 -0800 Subject: [PATCH 22/33] remove debug prints --- src/Language/Fixpoint/SortCheck.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index c4de9b3cc..a43219b70 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -702,7 +702,7 @@ elabAs f t e = notracepp _msg <$> go e where _msg = "elabAs: t = " ++ showpp t ++ "; e = " ++ showpp e go (EApp e1 e2) = elabAppAs f t e1 e2 - go e' = let !_ = unsafePerformIO $ print ("Down the elab path " ++ show e') in fst <$> elab f e' + go e' = fst <$> elab f e' -- DUPLICATION with `checkApp'` elabAppAs :: ElabEnv -> Sort -> Expr -> Expr -> CheckM Expr From d3b71f5086902bb8148096fa0f7232b534ff1c79 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 18:43:35 -0800 Subject: [PATCH 23/33] avoiding unifying self in sub sorts --- src/Language/Fixpoint/Union.hs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index e480e4881..90f62745e 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -1,10 +1,13 @@ +{-# LANGUAGE BangPatterns #-} module Language.Fixpoint.Union where import Data.HashMap.Strict (lookup, insert, HashMap, empty) import Prelude hiding (lookup) import Language.Fixpoint.Types.Sorts (Sort(..)) +import GHC.IO (unsafePerformIO) unionSub :: UF -> Int -> Sort -> Sort -> UF unionSub uf i s1 s2 = case (s1, s2) of + (FVar i1, FVar i2) -> if i1 == i2 then uf else union uf i1 s2 (FVar i1, _) -> union uf i1 s2 (_, FVar i2) -> union uf i2 s1 (_, _) -> unionVals uf i s1 s2 @@ -25,6 +28,7 @@ unionVals uf _ s1 s2 unionVals uf _ (FObj x) (FObj y) | x == y = uf +unionVals (MkUF uf) _ (FVar i) (FVar j) = if i == j then MkUF uf else MkUF (insert i (FVar j) uf) unionVals (MkUF uf) _ (FVar i) s = MkUF (insert i s uf) unionVals (MkUF uf) _ s (FVar i) = MkUF (insert i s uf) unionVals uf i (FFunc s1 s2) (FFunc s1' s2') = @@ -44,6 +48,7 @@ new :: UF new = MkUF empty union :: UF -> Int -> Sort -> UF union u@(MkUF ufM) tyv s = + let !_ = unsafePerformIO $ print ("union " ++ show tyv ++ " with " ++ show s ++ " with curr " ++ show ufM) in -- find the root for tyv let tyv_root = findWithIndex (MkUF ufM) tyv in case tyv_root of From 61e13704454e8bd0c06eea3a8559259a5faf6f0a Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 18:43:55 -0800 Subject: [PATCH 24/33] Cleanup --- src/Language/Fixpoint/Union.hs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index 90f62745e..b4699456b 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -1,9 +1,7 @@ -{-# LANGUAGE BangPatterns #-} module Language.Fixpoint.Union where import Data.HashMap.Strict (lookup, insert, HashMap, empty) import Prelude hiding (lookup) import Language.Fixpoint.Types.Sorts (Sort(..)) -import GHC.IO (unsafePerformIO) unionSub :: UF -> Int -> Sort -> Sort -> UF unionSub uf i s1 s2 = case (s1, s2) of @@ -48,7 +46,6 @@ new :: UF new = MkUF empty union :: UF -> Int -> Sort -> UF union u@(MkUF ufM) tyv s = - let !_ = unsafePerformIO $ print ("union " ++ show tyv ++ " with " ++ show s ++ " with curr " ++ show ufM) in -- find the root for tyv let tyv_root = findWithIndex (MkUF ufM) tyv in case tyv_root of From 2c148da482fedf8bc0959cd10aba58084f6f90d7 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 19:14:14 -0800 Subject: [PATCH 25/33] make sure unified type variables are not unified in a looping manner --- src/Language/Fixpoint/Union.hs | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index b4699456b..eb50c4032 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -1,7 +1,9 @@ +{-# LANGUAGE BangPatterns #-} module Language.Fixpoint.Union where import Data.HashMap.Strict (lookup, insert, HashMap, empty) import Prelude hiding (lookup) import Language.Fixpoint.Types.Sorts (Sort(..)) +import GHC.IO (unsafePerformIO) unionSub :: UF -> Int -> Sort -> Sort -> UF unionSub uf i s1 s2 = case (s1, s2) of @@ -44,8 +46,10 @@ unionVals _ _ s1 s2 = error ("Cannot unify " ++ show s1 ++ " and " ++ show s2) newtype UF = MkUF (HashMap Int Sort) deriving (Show) new :: UF new = MkUF empty -union :: UF -> Int -> Sort -> UF -union u@(MkUF ufM) tyv s = + + +unionSafe :: UF -> Int -> Sort -> UF +unionSafe u@(MkUF ufM) tyv s = -- find the root for tyv let tyv_root = findWithIndex (MkUF ufM) tyv in case tyv_root of @@ -53,7 +57,18 @@ union u@(MkUF ufM) tyv s = Nothing -> MkUF (insert tyv s ufM) -- otherwise, unify the current sort with -- the new one and insert that - Just (i, s') -> unionVals u i s s' + Just (i, s') -> + let !_ = unsafePerformIO $ print ("Here with " ++ show s' ++ " and " ++ show s) in + unionVals u i s s' + + +union :: UF -> Int -> Sort -> UF +union u tyv s = + case s of + FVar i -> case find u i of + Just (FVar j) -> if tyv == j then u else unionSafe u tyv s + _ -> unionSafe u tyv s + _ -> unionSafe u tyv s findWithIndex :: UF -> Int -> Maybe (Int, Sort) findWithIndex u@(MkUF ufM) k = do @@ -68,5 +83,7 @@ find (MkUF ufM) = f f k = do s <- lookup k ufM case s of - FVar i -> f i + FVar i -> case f i of + Nothing -> Just (FVar i) + Just s' -> Just s' s' -> Just s' \ No newline at end of file From 43c55bfb4bd8650d180cbf1f604cf6b9b40e9513 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 21:49:28 -0800 Subject: [PATCH 26/33] More cleanup --- src/Language/Fixpoint/Union.hs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index eb50c4032..f031d983e 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -1,9 +1,7 @@ -{-# LANGUAGE BangPatterns #-} module Language.Fixpoint.Union where import Data.HashMap.Strict (lookup, insert, HashMap, empty) import Prelude hiding (lookup) import Language.Fixpoint.Types.Sorts (Sort(..)) -import GHC.IO (unsafePerformIO) unionSub :: UF -> Int -> Sort -> Sort -> UF unionSub uf i s1 s2 = case (s1, s2) of @@ -57,9 +55,7 @@ unionSafe u@(MkUF ufM) tyv s = Nothing -> MkUF (insert tyv s ufM) -- otherwise, unify the current sort with -- the new one and insert that - Just (i, s') -> - let !_ = unsafePerformIO $ print ("Here with " ++ show s' ++ " and " ++ show s) in - unionVals u i s s' + Just (i, s') -> unionVals u i s s' union :: UF -> Int -> Sort -> UF From 39c85c57d3b02175bdf8fe26cc2a59cc388b1900 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Mon, 3 Feb 2025 22:05:14 -0800 Subject: [PATCH 27/33] wip --- src/Language/Fixpoint/SortCheck.hs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index a43219b70..dc93634fd 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -1122,6 +1122,7 @@ checkOpTy f e t t' = do uf <- liftIO $ readIORef ufRef uf' <- unifyUF f uf (Just e) t t' liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + -- probably need to unify this with a numeric sort for this to work properly checkNumeric f t >> return t checkFractional :: Env -> Sort -> CheckM () From 8165447648ec3b60d36dac38a80403d2244190ca Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Tue, 4 Feb 2025 08:19:26 -0800 Subject: [PATCH 28/33] remove checkNumeric for type variables and unify them instead --- src/Language/Fixpoint/SortCheck.hs | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index dc93634fd..91d65046e 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -1123,7 +1123,8 @@ checkOpTy f e t t' = do uf' <- unifyUF f uf (Just e) t t' liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) -- probably need to unify this with a numeric sort for this to work properly - checkNumeric f t >> return t + return t + -- checkNumeric f t >> return t checkFractional :: Env -> Sort -> CheckM () checkFractional f s@(FObj l) @@ -1218,9 +1219,33 @@ checkRelTy f _ _ s1@(FObj l) s2@(FObj l') | l /= l' checkRelTy _ _ _ FReal FReal = return () checkRelTy _ _ _ FInt FReal = return () checkRelTy _ _ _ FReal FInt = return () -checkRelTy f _ _ FInt s2 = checkNumeric f s2 `withError` errNonNumeric s2 +checkRelTy f e _ s1@FInt s2@(FVar _) = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyUF f uf (Just e) s1 s2 `withError` errUnify (Just e) s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return () +checkRelTy f _ _ FInt s2 = checkNumeric f s2 `withError` errNonNumeric s2 +checkRelTy f e _ s1@(FVar _) s2@FInt = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyUF f uf (Just e) s1 s2 `withError` errUnify (Just e) s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return () checkRelTy f _ _ s1 FInt = checkNumeric f s1 `withError` errNonNumeric s1 +checkRelTy f e _ s1@FReal s2@(FVar _) = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyUF f uf (Just e) s1 s2 `withError` errUnify (Just e) s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return () checkRelTy f _ _ FReal s2 = checkFractional f s2 `withError` errNonFractional s2 +checkRelTy f e _ s1@(FVar _) s2@FReal = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + uf' <- unifyUF f uf (Just e) s1 s2 `withError` errUnify (Just e) s1 s2 + liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + return () checkRelTy f _ _ s1 FReal = checkFractional f s1 `withError` errNonFractional s1 checkRelTy f e _ t1 t2 = do ufRef <- asks ufM From f783b09fe62e00b1d669579ddfa87bafeffc9074 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Tue, 4 Feb 2025 14:53:00 -0800 Subject: [PATCH 29/33] patch check book sort --- src/Language/Fixpoint/SortCheck.hs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 91d65046e..8fdb60f99 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -1186,6 +1186,13 @@ checkPred :: Env -> Expr -> CheckM () checkPred f e = checkExpr f e >>= checkBoolSort e checkBoolSort :: Expr -> Sort -> CheckM () + +checkBoolSort e s@(FVar i) = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + case Union.find uf i of + Nothing -> throwErrorAt (errBoolSort e s) + Just s' -> if s' == boolSort then return () else throwErrorAt (errBoolSort e s) checkBoolSort e s | s == boolSort = return () | otherwise = throwErrorAt (errBoolSort e s) From c2fb110d39cffc30a525df7d4f840f89f2ab7bfd Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Tue, 4 Feb 2025 20:13:44 -0800 Subject: [PATCH 30/33] refactor union find --- src/Language/Fixpoint/SortCheck.hs | 13 ++-- src/Language/Fixpoint/Union.hs | 100 ++++++++++++----------------- 2 files changed, 48 insertions(+), 65 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 8fdb60f99..6fb773ea9 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -1187,12 +1187,15 @@ checkPred f e = checkExpr f e >>= checkBoolSort e checkBoolSort :: Expr -> Sort -> CheckM () -checkBoolSort e s@(FVar i) = do +checkBoolSort e (FVar i) = do ufRef <- asks ufM uf <- liftIO $ readIORef ufRef - case Union.find uf i of - Nothing -> throwErrorAt (errBoolSort e s) - Just s' -> if s' == boolSort then return () else throwErrorAt (errBoolSort e s) + let s = Union.find uf i + if s == boolSort then return () else throwErrorAt (errBoolSort e s) + + -- case Union.find uf i of + -- Nothing -> throwErrorAt (errBoolSort e s) + -- Just s' -> if s' == boolSort then return () else throwErrorAt (errBoolSort e s) checkBoolSort e s | s == boolSort = return () | otherwise = throwErrorAt (errBoolSort e s) @@ -1580,7 +1583,7 @@ apply !θ = Vis.mapSort f applyUF :: UF -> Sort -> Sort applyUF uf = Vis.mapSort f where - f t@(FVar !i) = fromMaybe t (Union.find uf i) + f (FVar !i) = Union.find uf i f !t = t {-# SCC applyExprUF #-} diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index f031d983e..d1c8eebc3 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -1,21 +1,23 @@ -module Language.Fixpoint.Union where +{-# LANGUAGE BangPatterns #-} +module Language.Fixpoint.Union where import Data.HashMap.Strict (lookup, insert, HashMap, empty) import Prelude hiding (lookup) -import Language.Fixpoint.Types.Sorts (Sort(..)) +import Language.Fixpoint.Types (Sort(..)) -unionSub :: UF -> Int -> Sort -> Sort -> UF -unionSub uf i s1 s2 = case (s1, s2) of - (FVar i1, FVar i2) -> if i1 == i2 then uf else union uf i1 s2 - (FVar i1, _) -> union uf i1 s2 - (_, FVar i2) -> union uf i2 s1 - (_, _) -> unionVals uf i s1 s2 +unionSub :: UF -> Sort -> Sort -> UF +unionSub uf s1 s2 = + -- get the parents and call unionVals + let s1_root = getRep uf s1 + s2_root = getRep uf s2 + in + unionVals uf s1_root s2_root -------------------------------------------------------------------------------- -- | union for sorts in union find -------------------------------------------------------------------------------- -unionVals :: UF -> Int -> Sort -> Sort -> UF +unionVals :: UF -> Sort -> Sort -> UF -------------------------------------------------------------------------------- -unionVals uf _ s1 s2 +unionVals uf s1 s2 | isNumericSort s1 && isNumericSort s2 = uf where isNumericSort FReal = True @@ -24,62 +26,40 @@ unionVals uf _ s1 s2 isNumericSort FInt = True isNumericSort _ = False -unionVals uf _ (FObj x) (FObj y) +unionVals uf (FObj x) (FObj y) | x == y = uf -unionVals (MkUF uf) _ (FVar i) (FVar j) = if i == j then MkUF uf else MkUF (insert i (FVar j) uf) -unionVals (MkUF uf) _ (FVar i) s = MkUF (insert i s uf) -unionVals (MkUF uf) _ s (FVar i) = MkUF (insert i s uf) -unionVals uf i (FFunc s1 s2) (FFunc s1' s2') = - let uf' = unionSub uf i s1 s1' in - unionSub uf' i s2 s2' -unionVals uf i (FApp s1 s2) (FApp s1' s2') = - let uf' = unionSub uf i s1 s1' in - unionSub uf' i s2 s2' -unionVals uf i (FAbs _ s) (FAbs _ s') = unionSub uf i s s' -unionVals uf _ (FTC s1) (FTC s2) - | s1 == s2 = uf -unionVals _ _ s1 s2 = error ("Cannot unify " ++ show s1 ++ " and " ++ show s2) - +-- unionVals u@(MkUF uf) (FVar i) (FVar j) = if i == j then u else MkUF (insert i (FVar j) uf) +unionVals (MkUF uf) (FVar i) s = MkUF (insert i s uf) +unionVals (MkUF uf) s (FVar i) = MkUF (insert i s uf) +unionVals uf (FFunc s1 s2) (FFunc s1' s2') = let uf' = unionSub uf s1 s1' in unionSub uf' s2 s2' +unionVals uf (FApp s1 s2) (FApp s1' s2') = let uf' = unionSub uf s1 s1' in unionSub uf' s2 s2' +unionVals uf (FAbs _ s) (FAbs _ s') = unionSub uf s s' +unionVals uf (FTC s1) (FTC s2) | s1 == s2 = uf +unionVals _ s1 s2 = error ("Cannot unify " ++ show s1 ++ " and " ++ show s2) + newtype UF = MkUF (HashMap Int Sort) deriving (Show) new :: UF new = MkUF empty - -unionSafe :: UF -> Int -> Sort -> UF -unionSafe u@(MkUF ufM) tyv s = - -- find the root for tyv - let tyv_root = findWithIndex (MkUF ufM) tyv in - case tyv_root of - -- if tyv not in union find, insert - Nothing -> MkUF (insert tyv s ufM) - -- otherwise, unify the current sort with - -- the new one and insert that - Just (i, s') -> unionVals u i s s' - - union :: UF -> Int -> Sort -> UF -union u tyv s = - case s of - FVar i -> case find u i of - Just (FVar j) -> if tyv == j then u else unionSafe u tyv s - _ -> unionSafe u tyv s - _ -> unionSafe u tyv s +union !u !tyv !s = + let tyv_root = find u tyv + sort_root = getRep u s + in + if tyv_root == sort_root then u else unionVals u tyv_root sort_root -findWithIndex :: UF -> Int -> Maybe (Int, Sort) -findWithIndex u@(MkUF ufM) k = do - s <- lookup k ufM - case s of - FVar i -> findWithIndex u i - s' -> Just (k, s') +getRep :: UF -> Sort -> Sort +getRep u s = + case s of + FVar i -> find u i + _ -> s -find :: UF -> Int -> Maybe Sort -find (MkUF ufM) = f - where - f k = do - s <- lookup k ufM - case s of - FVar i -> case f i of - Nothing -> Just (FVar i) - Just s' -> Just s' - s' -> Just s' \ No newline at end of file +find :: UF -> Int -> Sort +find (MkUF ufM) = f + where + f k = do + case lookup k ufM of + Nothing -> FVar k + Just (FVar i) -> f i + Just s -> s \ No newline at end of file From 8482e2806993ffb2726fe6de0d3c00617a9047ae Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Tue, 4 Feb 2025 21:09:30 -0800 Subject: [PATCH 31/33] rip out some things --- src/Language/Fixpoint/SortCheck.hs | 48 +++++++++++++++--------------- src/Language/Fixpoint/Union.hs | 18 ++++------- 2 files changed, 30 insertions(+), 36 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 6fb773ea9..8c11bbcc2 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -552,10 +552,10 @@ elab f@(!_, !g) e@(EBin !o !e1 !e2) = do elab !f (EApp !e1 !e2) = do (!e1', !s1, !e2', !s2, !s) <- elabEApp f e1 e2 let !e = eAppC s (eCst e1' s1) (eCst e2' s2) - ufRef <- asks ufM - uf <- liftIO $ readIORef ufRef - uf' <- unifyExprUF (snd f) uf e - liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) + -- ufRef <- asks ufM + -- uf <- liftIO $ readIORef ufRef + -- uf' <- unifyExprUF (snd f) uf e + -- liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) return (e, s) @@ -1158,26 +1158,26 @@ checkEqConstrUF f e uf a t = Found tA -> unify1UF f e uf tA t _ -> throwErrorAt $ errUnifyMsg (Just "ceq2") e (FObj a) t -{-# SCC unifyExprUF #-} -unifyExprUF :: Env -> UF -> Expr -> CheckM UF -unifyExprUF f uf (EApp e1 e2) = do - uf1 <- unifyExprUF f uf e1 - uf2 <- unifyExprUF f uf1 e2 - unifyExprAppUF f uf2 e1 e2 - -unifyExprUF f uf (ECst e _) - = unifyExprUF f uf e -unifyExprUF _ uf _ - = return uf - -unifyExprAppUF :: Env -> UF -> Expr -> Expr -> CheckM UF -unifyExprAppUF f uf e1 e2 = do - case (getArg $ exprSortMaybe e1, exprSortMaybe e2) of - (Just s1, Just s2) -> unifyUF f uf (Just $ EApp e1 e2) s1 s2 - _ -> return uf - where - getArg (Just (FFunc t1 _)) = Just t1 - getArg _ = Nothing +-- {-# SCC unifyExprUF #-} +-- unifyExprUF :: Env -> UF -> Expr -> CheckM UF +-- unifyExprUF f uf (EApp e1 e2) = do +-- uf1 <- unifyExprUF f uf e1 +-- uf2 <- unifyExprUF f uf1 e2 +-- unifyExprAppUF f uf2 e1 e2 + +-- unifyExprUF f uf (ECst e _) +-- = unifyExprUF f uf e +-- unifyExprUF _ uf _ +-- = return uf + +-- unifyExprAppUF :: Env -> UF -> Expr -> Expr -> CheckM UF +-- unifyExprAppUF f uf e1 e2 = do +-- case (getArg $ exprSortMaybe e1, exprSortMaybe e2) of +-- (Just s1, Just s2) -> unifyUF f uf (Just $ EApp e1 e2) s1 s2 +-- _ -> return uf +-- where +-- getArg (Just (FFunc t1 _)) = Just t1 +-- getArg _ = Nothing -------------------------------------------------------------------------------- -- | Checking Predicates ------------------------------------------------------- diff --git a/src/Language/Fixpoint/Union.hs b/src/Language/Fixpoint/Union.hs index d1c8eebc3..12e4b9e9b 100644 --- a/src/Language/Fixpoint/Union.hs +++ b/src/Language/Fixpoint/Union.hs @@ -4,14 +4,6 @@ import Data.HashMap.Strict (lookup, insert, HashMap, empty) import Prelude hiding (lookup) import Language.Fixpoint.Types (Sort(..)) -unionSub :: UF -> Sort -> Sort -> UF -unionSub uf s1 s2 = - -- get the parents and call unionVals - let s1_root = getRep uf s1 - s2_root = getRep uf s2 - in - unionVals uf s1_root s2_root - -------------------------------------------------------------------------------- -- | union for sorts in union find -------------------------------------------------------------------------------- @@ -28,12 +20,14 @@ unionVals uf s1 s2 unionVals uf (FObj x) (FObj y) | x == y = uf --- unionVals u@(MkUF uf) (FVar i) (FVar j) = if i == j then u else MkUF (insert i (FVar j) uf) +unionVals u@(MkUF uf) (FVar i) (FVar j) = if i == j then u else MkUF (insert i (FVar j) uf) unionVals (MkUF uf) (FVar i) s = MkUF (insert i s uf) unionVals (MkUF uf) s (FVar i) = MkUF (insert i s uf) -unionVals uf (FFunc s1 s2) (FFunc s1' s2') = let uf' = unionSub uf s1 s1' in unionSub uf' s2 s2' -unionVals uf (FApp s1 s2) (FApp s1' s2') = let uf' = unionSub uf s1 s1' in unionSub uf' s2 s2' -unionVals uf (FAbs _ s) (FAbs _ s') = unionSub uf s s' +unionVals uf (FFunc s1 s2) (FFunc s1' s2') = unionVals uf' (getRep uf' s2) (getRep uf' s2') + where uf' = unionVals uf (getRep uf s1) (getRep uf s1') +unionVals uf (FApp s1 s2) (FApp s1' s2') = unionVals uf' (getRep uf' s2) (getRep uf' s2') + where uf' = unionVals uf (getRep uf s1) (getRep uf s1') +unionVals uf (FAbs _ s) (FAbs _ s') = unionVals uf (getRep uf s) (getRep uf s') unionVals uf (FTC s1) (FTC s2) | s1 == s2 = uf unionVals _ s1 s2 = error ("Cannot unify " ++ show s1 ++ " and " ++ show s2) From 9d80a4773675dbcf1c44ea8a721b1d8210200734 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Tue, 4 Feb 2025 22:21:25 -0800 Subject: [PATCH 32/33] fix check numeric --- src/Language/Fixpoint/SortCheck.hs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 8c11bbcc2..3d7caf486 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -1137,6 +1137,11 @@ checkNumeric :: Env -> Sort -> CheckM () checkNumeric f s@(FObj l) = do t <- checkSym f l unless (t `elem` [FNum, FFrac, intSort, FInt]) (throwErrorAt $ errNonNumeric s) +checkNumeric _ s@(FVar i) = do + ufRef <- asks ufM + uf <- liftIO $ readIORef ufRef + unless (isNumeric (Union.find uf i)) (throwErrorAt $ errNonNumeric s) + checkNumeric _ s = unless (isNumeric s) (throwErrorAt $ errNonNumeric s) From 6de29b5717627ed18f40b5b0d974959597450ec5 Mon Sep 17 00:00:00 2001 From: Vivien Rindisbacher Date: Tue, 4 Feb 2025 22:32:11 -0800 Subject: [PATCH 33/33] rip out unecessary cases in checRelTy --- src/Language/Fixpoint/SortCheck.hs | 55 +----------------------------- 1 file changed, 1 insertion(+), 54 deletions(-) diff --git a/src/Language/Fixpoint/SortCheck.hs b/src/Language/Fixpoint/SortCheck.hs index 3d7caf486..d66740b50 100644 --- a/src/Language/Fixpoint/SortCheck.hs +++ b/src/Language/Fixpoint/SortCheck.hs @@ -552,10 +552,6 @@ elab f@(!_, !g) e@(EBin !o !e1 !e2) = do elab !f (EApp !e1 !e2) = do (!e1', !s1, !e2', !s2, !s) <- elabEApp f e1 e2 let !e = eAppC s (eCst e1' s1) (eCst e2' s2) - -- ufRef <- asks ufM - -- uf <- liftIO $ readIORef ufRef - -- uf' <- unifyExprUF (snd f) uf e - -- liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) return (e, s) @@ -1163,27 +1159,6 @@ checkEqConstrUF f e uf a t = Found tA -> unify1UF f e uf tA t _ -> throwErrorAt $ errUnifyMsg (Just "ceq2") e (FObj a) t --- {-# SCC unifyExprUF #-} --- unifyExprUF :: Env -> UF -> Expr -> CheckM UF --- unifyExprUF f uf (EApp e1 e2) = do --- uf1 <- unifyExprUF f uf e1 --- uf2 <- unifyExprUF f uf1 e2 --- unifyExprAppUF f uf2 e1 e2 - --- unifyExprUF f uf (ECst e _) --- = unifyExprUF f uf e --- unifyExprUF _ uf _ --- = return uf - --- unifyExprAppUF :: Env -> UF -> Expr -> Expr -> CheckM UF --- unifyExprAppUF f uf e1 e2 = do --- case (getArg $ exprSortMaybe e1, exprSortMaybe e2) of --- (Just s1, Just s2) -> unifyUF f uf (Just $ EApp e1 e2) s1 s2 --- _ -> return uf --- where --- getArg (Just (FFunc t1 _)) = Just t1 --- getArg _ = Nothing - -------------------------------------------------------------------------------- -- | Checking Predicates ------------------------------------------------------- -------------------------------------------------------------------------------- @@ -1196,11 +1171,7 @@ checkBoolSort e (FVar i) = do ufRef <- asks ufM uf <- liftIO $ readIORef ufRef let s = Union.find uf i - if s == boolSort then return () else throwErrorAt (errBoolSort e s) - - -- case Union.find uf i of - -- Nothing -> throwErrorAt (errBoolSort e s) - -- Just s' -> if s' == boolSort then return () else throwErrorAt (errBoolSort e s) + if s == boolSort then return () else throwErrorAt (errBoolSort e s) checkBoolSort e s | s == boolSort = return () | otherwise = throwErrorAt (errBoolSort e s) @@ -1234,33 +1205,9 @@ checkRelTy f _ _ s1@(FObj l) s2@(FObj l') | l /= l' checkRelTy _ _ _ FReal FReal = return () checkRelTy _ _ _ FInt FReal = return () checkRelTy _ _ _ FReal FInt = return () -checkRelTy f e _ s1@FInt s2@(FVar _) = do - ufRef <- asks ufM - uf <- liftIO $ readIORef ufRef - uf' <- unifyUF f uf (Just e) s1 s2 `withError` errUnify (Just e) s1 s2 - liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) - return () checkRelTy f _ _ FInt s2 = checkNumeric f s2 `withError` errNonNumeric s2 -checkRelTy f e _ s1@(FVar _) s2@FInt = do - ufRef <- asks ufM - uf <- liftIO $ readIORef ufRef - uf' <- unifyUF f uf (Just e) s1 s2 `withError` errUnify (Just e) s1 s2 - liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) - return () checkRelTy f _ _ s1 FInt = checkNumeric f s1 `withError` errNonNumeric s1 -checkRelTy f e _ s1@FReal s2@(FVar _) = do - ufRef <- asks ufM - uf <- liftIO $ readIORef ufRef - uf' <- unifyUF f uf (Just e) s1 s2 `withError` errUnify (Just e) s1 s2 - liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) - return () checkRelTy f _ _ FReal s2 = checkFractional f s2 `withError` errNonFractional s2 -checkRelTy f e _ s1@(FVar _) s2@FReal = do - ufRef <- asks ufM - uf <- liftIO $ readIORef ufRef - uf' <- unifyUF f uf (Just e) s1 s2 `withError` errUnify (Just e) s1 s2 - liftIO $ atomicModifyIORef' ufRef $ const (uf', ()) - return () checkRelTy f _ _ s1 FReal = checkFractional f s1 `withError` errNonFractional s1 checkRelTy f e _ t1 t2 = do ufRef <- asks ufM