diff --git a/amuletml.cabal b/amuletml.cabal index 8376eb73f..cf332eace 100644 --- a/amuletml.cabal +++ b/amuletml.cabal @@ -281,6 +281,7 @@ library , Core.Lower , Core.Lower.Basic , Core.Lower.Pattern + , Core.Lower.TypeRepr , Core.Types , Core.Builtin , Core.Optimise @@ -291,7 +292,6 @@ library , Core.Optimise.Reduce.Inline , Core.Optimise.Reduce.Pattern , Core.Optimise.Sinking - , Core.Optimise.Newtype , Core.Optimise.Uncurry , Core.Optimise.DeadCode , Core.Optimise.CommonExpElim diff --git a/bin/Amc.hs b/bin/Amc.hs index fe1379589..2295e1faf 100644 --- a/bin/Amc.hs +++ b/bin/Amc.hs @@ -22,7 +22,6 @@ import Backend.Lua import qualified Syntax.Builtin as Bi import Core.Optimise.Reduce (reducePass) -import Core.Optimise.Newtype (killNewtypePass) import Core.Optimise.DeadCode (deadCodePass) import Core.Simplify (optimise) import Core.Core (Stmt) @@ -65,7 +64,7 @@ runCompile opt (DoLint lint) dconfig file = do Opt -> optimise lint core NoOpt -> do lintIt "Lower" (checkStmt emptyScope core) (pure ()) - (lintIt "Optimised" =<< checkStmt emptyScope) . deadCodePass <$> (reducePass =<< killNewtypePass core) + (lintIt "Optimised" =<< checkStmt emptyScope) . deadCodePass <$> reducePass core lua = compileProgram optimised in ( Just (env, core, optimised, lua) , errors diff --git a/src/Core/Lower.hs b/src/Core/Lower.hs index 2df18f714..594340e46 100644 --- a/src/Core/Lower.hs +++ b/src/Core/Lower.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE LambdaCase, TupleSections, +{-# LANGUAGE TupleSections, PatternSynonyms, RankNTypes, ScopedTypeVariables, FlexibleContexts, ConstraintKinds, OverloadedStrings, TypeFamilies #-} module Core.Lower @@ -32,6 +32,7 @@ import Core.Optimise (substituteInType, substituteInTys, fresh, freshFrom) import Core.Core hiding (Atom, Term, Stmt, Type, Pattern, Arm) import Core.Core (pattern Atom) import Core.Types (unify, unifyClosed, replaceTy) +import Core.Lower.TypeRepr import Core.Lower.Pattern import Core.Lower.Basic import Core.Var @@ -40,7 +41,7 @@ import qualified Syntax as S import Syntax.Let import Syntax.Var (Var, Typed, VarResolved(..)) import Syntax.Transform -import Syntax (Expr(..), Pattern(..), Skolem(..), ModuleTerm(..), Toplevel(..), Constructor(..), Arm(..)) +import Syntax (Expr(..), Pattern(..), Skolem(..), ModuleTerm(..), Toplevel(..), Arm(..)) import Text.Pretty.Semantic (pretty) @@ -52,13 +53,18 @@ type Stmt = C.Stmt CoVar type Lower = ContT Term defaultState :: LowerState -defaultState = LS mempty ctors mempty where +defaultState = LS mempty ctors types where ctors :: VarMap.Map (C.Type CoVar) ctors = VarMap.fromList [ (C.vCONS, ForallTy (Relevant name) StarTy $ VarTy name `prodTy` AppTy C.tyList (VarTy name) `arrTy` AppTy C.tyList (VarTy name)) , (C.vNIL, ForallTy (Relevant name) StarTy $ AppTy C.tyList (VarTy name))] + types :: VarMap.Map TypeRepr + types = VarMap.fromList + ( (C.vList, SumTy (VarSet.fromList [C.vCONS, C.vNIL])) + : map (,OpaqueTy) [ C.vBool, C.vInt, C.vString, C.vFloat, C.vUnit + , C.vLazy, C.vArrow, C.vProduct, C.vRefTy ] ) name = C.tyvarA arrTy = ForallTy Irrelevant prodTy a b = RowsTy NilTy [("_1", a), ("_2", b)] @@ -365,22 +371,11 @@ lowerProg' (LetStmt _ vs:prg) = do vs' <- lowerLet vs foldr ((.) . ((:) . C.StmtLet)) id vs' <$$> lowerProg' prg -lowerProg' (TypeDecl _ var _ Nothing _:prg) = - (C.Type (mkType var) []:) <$$> lowerProg' prg -lowerProg' (TypeDecl _ var _ (Just cons) _:prg) = do - let cons' = map (\case - UnitCon _ p (_, t) -> (p, mkCon p, lowerType t) - ArgCon _ p _ (_, t) -> (p, mkCon p, lowerType t) - GadtCon _ p t _ -> (p, mkCon p, lowerType t)) - cons - ccons = map (\(_, a, b) -> (a, b)) cons' - scons = map (\(a, _, b) -> (mkCon a, b)) cons' - - conset = VarSet.fromList (map fst scons) - - (C.Type (mkType var) ccons:) <$$> local (\s -> - s { ctors = VarMap.union (VarMap.fromList scons) (ctors s) - , types = VarMap.insert (mkType var) conset (types s) +lowerProg' (TypeDecl _ var _ cons _:prg) = do + ~(tyStmts@(C.Type _ cs:_), repr) <- getTypeRepr (mkType var) cons + (tyStmts++) <$$> local (\s -> + s { ctors = VarMap.union (VarMap.fromList cs) (ctors s) + , types = VarMap.insert (mkType var) repr (types s) }) (lowerProg' prg) lowerLet :: MonadLower m => [S.Binding Typed] -> m [Binding CoVar] diff --git a/src/Core/Lower/Basic.hs b/src/Core/Lower/Basic.hs index 8dbaa0404..0b7343c74 100644 --- a/src/Core/Lower/Basic.hs +++ b/src/Core/Lower/Basic.hs @@ -1,6 +1,7 @@ {-# LANGUAGE OverloadedStrings, ConstraintKinds, FlexibleContexts #-} module Core.Lower.Basic - ( LowerState(..) + ( TypeRepr(..) + , LowerState(..) , LowerTrack , MonadLower , mkTyvar, mkVal, mkType, mkCo, mkCon, mkVar @@ -24,12 +25,23 @@ import qualified Syntax as S import Syntax.Var (VarResolved(..), Var, Resolved, Typed) import Syntax (Lit(..), Skolem(..)) +data TypeRepr + = OpaqueTy -- ^ An opaque type, for interfacing with foreign values. + | SumTy VarSet.Set -- ^ A sum type, with the set of constructors. + + -- | A type which just wraps another. + -- + -- This holds the name of the constructor, and the inner and outer + -- type, both sharing their free variables. + | WrapperTy CoVar (C.Type CoVar) (C.Type CoVar) + deriving (Show, Eq) + data LowerState = LS { vars :: VarMap.Map (C.Type CoVar) , ctors :: VarMap.Map (C.Type CoVar) -- | The map of types to their constructors /if they have any/. - , types :: VarMap.Map VarSet.Set + , types :: VarMap.Map TypeRepr } deriving (Eq, Show) instance Semigroup LowerState where diff --git a/src/Core/Lower/Pattern.hs b/src/Core/Lower/Pattern.hs index 8d4cd59d1..d7fd2e9bc 100644 --- a/src/Core/Lower/Pattern.hs +++ b/src/Core/Lower/Pattern.hs @@ -31,6 +31,7 @@ import Data.Bifunctor import Data.Foldable import Data.Triple import Data.Maybe +import Data.Span import qualified Core.Core as C import Core.Optimise (substituteInType, substituteInTys, fresh, freshFrom) @@ -85,6 +86,12 @@ data ArmNode , nodeNodes :: [(Pattern CoVar, ArmNode)] -- ^ The child nodes, and their associated pattern. } + | ArmLet + { nodeArms :: ArmSet + , nodeSuccess :: [ArmLeaf] -- Should be empty + , nodeBind :: (CoVar, Type CoVar, Term CoVar) + , nodeBody :: ArmNode + } deriving (Show) instance Pretty ArmNode where @@ -92,6 +99,10 @@ instance Pretty ArmNode where pretty (ArmMatch arms success atom nodes) = "Match" <+> parens (shown arms) <+> shown success <+> pretty atom <#> (indent 2 . vsep $ map (\(p, n) -> pretty p <+> "=>" <#> indent 2 (pretty n)) nodes) + pretty (ArmLet arms success (v, ty, x) node) + = "Let" <+> parens (shown arms) <+> shown success <+> pretty v <+> colon <+> pretty ty <+> equals <+> pretty x + <#> indent 2 (pretty node) + -- | A of a single case in a match expression. data PatternRow @@ -197,6 +208,8 @@ flattenNode bodies guards (ArmMatch _ leafs atom' children) = do let branches = foldr (flip (HSet.foldr add) . nodeArms . snd) mempty children in foldr (add . leafArm) branches leafs where add k = HMap.insertWith (+) k (1 :: Int) +flattenNode bodies guards (ArmLet _ _ bind child) = + C.Let (One bind) <$> flattenNode bodies guards child -- | Lift a pattern match into a lambda, passing arguments as values. generateBinds :: forall m. MonadLower m @@ -359,12 +372,11 @@ lowerOne tys rss = do getCtors v = do ctor <- VarMap.lookup v (ctors state) ty <- getType ctor - VarMap.lookup ty (types state) - - getType (ForallTy _ _ t) = getType t - getType (ConTy a) = pure a - getType (AppTy f _) = getType f - getType _ = Nothing + case VarMap.lookup ty (types state) of + Nothing -> error ("Cannot find " ++ show ty) + Just OpaqueTy -> Nothing + Just (SumTy ctors) -> Just ctors + Just (WrapperTy ctor _ _) -> Just (VarSet.singleton ctor) -- | Compute the "arity" heuristic for a given row variable. -- @@ -415,8 +427,8 @@ lowerOneOf preLeafs var ty tys = go [] . map prepare go unc [] = lowerOne tys (reverse unc) go unc rs@((S.PRecord{},_):_) = goRows unc mempty rs - go unc rs@((S.Destructure{},_):_) = goCtors unc mempty rs - go unc rs@((S.PGadtCon{},_):_) = goCtors unc mempty rs + go unc rs@((S.Destructure{},_):_) = goCtorsWith unc rs + go unc rs@((S.PGadtCon{},_):_) = goCtorsWith unc rs go unc rs@((S.PLiteral{},_):_) = goLiterals unc mempty rs go unc ((p, r):rs) = go (goGeneric p r:unc) rs @@ -464,6 +476,46 @@ lowerOneOf preLeafs var ty tys = go [] . map prepare pure ( Map.insert f (v, lowerType (S.getType p)) fs , VarMap.insert v p ps ) + goCtorsWith unc rs = do + let Just tyName = getType ty + repr <- asks (fromMaybe (error ("Cannot find " ++ show tyName)) . VarMap.lookup tyName . types) + case repr of + OpaqueTy -> error "Impossible matching on opaque type" + SumTy _ -> goCtors unc mempty rs + WrapperTy _ from to -> do + let Just map = unify to ty + from' = substituteInType map from + coVar <- case rs of + ((S.PGadtCon _ _ _ (Just child) _, _):_) -> freshFromPat child + _ -> fresh ValueVar + node <- goNewtype unc (Capture coVar from') rs + pure (ArmLet (nodeArms node) mempty + (coVar, from', Cast (Ref var ty) from' (SameRepr ty from')) + node) + + -- | Split patterns into those matching against the constructor and those not + goNewtype :: [PatternRow] -> Capture CoVar + -> [(S.Pattern Typed, PatternRow)] + -> m ArmNode + goNewtype unc (Capture c cty) [] = + lowerOne (VarMap.insert c cty tys) (reverse unc) + + goNewtype unc cap@(Capture c _) (( S.PGadtCon _ _ [] (Just p) _ + , PR arm pats gd vBind tyBind ):rs) = + -- The wrapped value is matched by the pattern - focus on that next. + let r' = PR arm (VarMap.insert c p pats) gd vBind tyBind + in goNewtype (r':unc) cap rs + + goNewtype unc cap@(Capture c _) (( S.PGadtCon _ _ [(v, t)] Nothing _ + , PR arm pats gd vBind tyBind ):rs) = + -- The wrapped value is the dictionary - just add a wildcard pattern. + let r' = PR arm (VarMap.insert c (S.Capture v (internal, t)) pats) gd vBind tyBind + in goNewtype (r':unc) cap rs + + goNewtype _ _ ((S.PGadtCon{}, _):_) = error "Impossible: Malformed pattern for newtype." + + goNewtype unc cap ((p, r):rs) = goNewtype (goGeneric p r:unc) cap rs + -- | Build up a mapping of (constructors -> (contents variable, rows)). goCtors :: [PatternRow] -> VarMap.Map ([Capture CoVar], [PatternRow]) -> [(S.Pattern Typed, PatternRow)] @@ -577,3 +629,9 @@ dropNForalls :: Int -> Type a -> Type a dropNForalls 0 t = t dropNForalls x (ForallTy _ _ t) = dropNForalls (x - 1) t dropNForalls _ _ = undefined + +getType :: Type a -> Maybe a +getType (ForallTy _ _ t) = getType t +getType (ConTy a) = pure a +getType (AppTy f _) = getType f +getType _ = Nothing diff --git a/src/Core/Lower/TypeRepr.hs b/src/Core/Lower/TypeRepr.hs new file mode 100644 index 000000000..067e77c73 --- /dev/null +++ b/src/Core/Lower/TypeRepr.hs @@ -0,0 +1,62 @@ +{-# LANGUAGE LambdaCase, ScopedTypeVariables #-} +module Core.Lower.TypeRepr (getTypeRepr) where + +import qualified Data.VarSet as VarSet + +import Control.Monad.Namey +import Control.Monad + +import Core.Lower.Basic +import Core.Optimise + +import Syntax (Constructor(..)) +import Syntax.Var (Typed) + +getTypeRepr :: MonadNamey m + => CoVar -> Maybe [Constructor Typed] + -> m ([Stmt CoVar], TypeRepr) +getTypeRepr var Nothing = pure ([Type var []], OpaqueTy) +getTypeRepr var (Just ctors) = + let ctors' = map (\case + UnitCon _ p (_, t) -> (mkCon p, lowerType t) + ArgCon _ p _ (_, t) -> (mkCon p, lowerType t) + GadtCon _ p t _ -> (mkCon p, lowerType t)) ctors + in case ctors' of + [(ctor, ty)] | Just nt@(Spine _ dom cod) <- isNewtype ty -> do + let CoVar name id _ = ctor + + wrapper <- newtypeWorker nt + pure ( [ Type var [], StmtLet (One (CoVar name id ValueVar, ty, wrapper))] + , WrapperTy ctor dom cod ) + + _ -> pure ( [ Type var ctors' ] + , SumTy (VarSet.fromList (map fst ctors')) ) + +isNewtype :: IsVar a => Type a -> Maybe (Spine a) +isNewtype (ForallTy Irrelevant _ ForallTy{}) = Nothing -- Cannot have multiple relevant arguments +isNewtype (ForallTy Irrelevant from to) = + pure (Spine [(Irrelevant, from)] from to) +isNewtype (ForallTy (Relevant var) k rest) = do + (Spine tys from to) <- isNewtype rest + guard (var `occursInTy` to) + pure (Spine ((Relevant var, k):tys) from to) +isNewtype _ = Nothing + +data Spine a = + Spine [(BoundTv a, Type a)] (Type a) (Type a) + deriving (Eq, Show, Ord) + +newtypeWorker :: forall a m. (IsVar a, MonadNamey m) + => Spine a -> m (Term a) +newtypeWorker (Spine tys dom cod) = do + let wrap :: [(BoundTv a, Type a)] -> (a -> Type a -> Term a) -> m (Term a) + wrap ((Relevant v, c):ts) ex = Lam (TypeArgument v c) <$> wrap ts ex + wrap [(Irrelevant, c)] ex = do + v <- fresh ValueVar + Lam (TermArgument (fromVar v) c) <$> pure (ex (fromVar v) c) + wrap _ _ = undefined + + work :: a -> Type a -> Term a + work var ty = Cast (Ref var ty) cod (SameRepr dom cod) + + wrap tys work diff --git a/src/Core/Optimise/Newtype.hs b/src/Core/Optimise/Newtype.hs deleted file mode 100644 index dadb4e898..000000000 --- a/src/Core/Optimise/Newtype.hs +++ /dev/null @@ -1,119 +0,0 @@ -{-# LANGUAGE ScopedTypeVariables #-} - -{- | Eliminate new types with a single case, converting constructors and - pattern matches on them into a coercion. --} -module Core.Optimise.Newtype (killNewtypePass) where - -import Control.Monad.Namey -import Control.Monad -import Control.Lens - -import qualified Data.VarMap as V -import Data.Triple - -import Core.Optimise -import Core.Types - --- | Run the new-type elimination pass. -killNewtypePass :: forall a m. (IsVar a, MonadNamey m) => [Stmt a] -> m [Stmt a] -killNewtypePass = go mempty mempty where - go :: V.Map (Atom a) -> V.Map (Coercion a) -> [Stmt a] -> m [Stmt a] - go ss m (Type n cs:xs) = case cs of - [(var, tp)] | Just nt <- isNewtype tp -> do - (con, phi, sub) <- newtypeWorker (var, tp) nt - (Type n [] :) . (con :) <$> go (ss <> sub) (V.insert (toVar var) phi m) xs - _ -> (Type n cs:) <$> go ss m xs - go ss m (x@Foreign{}:xs) = (x:) <$> go ss m xs - go ss m (StmtLet (Many vs):xs) = do - vs' <- goBinding ss m vs - xs' <- go ss m xs - pure (StmtLet (Many vs'):xs') - go ss m (StmtLet (One v):xs) = do - ~[v'] <- goBinding ss m [v] - xs' <- go ss m xs - pure (StmtLet (One v'):xs') - - go _ _ [] = pure [] - -isNewtype :: IsVar a => Type a -> Maybe (Spine a) -isNewtype (ForallTy Irrelevant from to) = - case isNewtype to of - Just _ -> Nothing - _ -> Just $ Spine [(Irrelevant, from)] to -isNewtype (ForallTy (Relevant var) k t) = do - (Spine tys res) <- isNewtype t - guard (var `occursInTy` res) - pure (Spine ((Relevant var, k):tys) res) -isNewtype _ = Nothing - -newtypeMatch :: IsVar a => V.Map (Coercion a) -> [Arm a] -> Maybe (Coercion a, Arm a) -newtypeMatch m (it@Arm { _armPtrn = Destr c _, _armTy = ty }:xs) - | Just phi@(SameRepr _ cod) <- V.lookup (toVar c) m = - case unify cod ty of - Just map -> pure (substituteInCo map (Symmetry phi), it) - Nothing -> error $ "failed to match newtype-constructor types " ++ show cod ++ " and " ++ show ty - | otherwise = newtypeMatch m xs -newtypeMatch m (_:xs) = newtypeMatch m xs -newtypeMatch _ [] = Nothing - --- Note: (again) the order of parameters to unify matters! The --- substitution is always in terms of the *first* parameter. Here, the --- results were backwards, so the solution wasn't being applied properly --- and the generated code was wrong. - -newtypeWorker :: forall a m. (IsVar a, MonadNamey m) - => (a, Type a) -> Spine a -> m (Stmt a, Coercion a, V.Map (Atom a)) -newtypeWorker (cn, tp) (Spine tys cod) = do - let CoVar nam id _ = toVar cn - (Irrelevant, dom) = last tys - - cname :: a - cname = fromVar (CoVar nam id ValueVar) - - phi = SameRepr dom cod - - wrap ((Relevant v, c):ts) ex = Lam (TypeArgument v c) <$> wrap ts ex - wrap [(Irrelevant, c)] ex = do - v <- fresh ValueVar - Lam (TermArgument (fromVar v) c) <$> pure (ex (fromVar v) c) - wrap _ _ = undefined - - work var ty = Cast (Ref var ty) cod phi - - work :: a -> Type a -> Term a - wrap :: [(BoundTv a, Type a)] -> (a -> Type a -> Term a) -> m (Term a) - - wrapper <- wrap tys work - let con = ( cname, tp, wrapper ) - pure (StmtLet (One con), phi, V.singleton (toVar cn) (Ref (fromVar (CoVar nam id ValueVar)) tp)) - -goBinding :: forall a m. (IsVar a, MonadNamey m) - => V.Map (Atom a) -> V.Map (Coercion a) -> [(a, Type a, Term a)] -> m [(a, Type a, Term a)] -goBinding ss m = traverse (third3A (fmap (substitute ss) . goTerm)) where - goTerm :: Term a -> m (Term a) - goTerm e@Atom{} = pure e - goTerm e@App{} = pure e - goTerm e@Extend{} = pure e - goTerm e@Values{} = pure e - goTerm e@Values{} = pure e - goTerm e@Cast{} = pure e - goTerm e@TyApp{} = pure e - - goTerm (Lam arg e) = Lam arg <$> goTerm e - goTerm (Let (Many vs) e) = Let . Many <$> goBinding ss m vs <*> goTerm e - goTerm (Let (One (v, t, e)) b) = do - e' <- goTerm e - Let (One (v, t, e')) <$> goTerm b - - goTerm (Match a x) = case newtypeMatch m x of - Just (phi, Arm { _armPtrn = Destr _ [Capture v ty], _armBody = bd }) -> do - var <- fresh ValueVar - let Just (_, castCodomain) = relates phi - bd' <- goTerm (Let (One (v, ty, Atom (Ref (fromVar var) castCodomain))) bd) - pure $ Let (One (fromVar var, castCodomain, Cast a castCodomain phi)) bd' - _ -> Match a <$> traverse (armBody %%~ goTerm) x - -data Spine a = - Spine [(BoundTv a, Type a)] (Type a) - deriving (Eq, Show, Ord) diff --git a/src/Core/Simplify.hs b/src/Core/Simplify.hs index d9d3997c2..81131f938 100644 --- a/src/Core/Simplify.hs +++ b/src/Core/Simplify.hs @@ -5,7 +5,6 @@ module Core.Simplify ) where import Core.Optimise.CommonExpElim -import Core.Optimise.Newtype import Core.Optimise.DeadCode import Core.Optimise.Sinking import Core.Optimise.Reduce @@ -38,13 +37,10 @@ linted pass fn = fmap (runLint pass =<< checkStmt emptyScope) . fn -- | Run the optimiser multiple times over the input core. optimise :: forall m. Monad m => Bool -> [Stmt CoVar] -> NameyT m [Stmt CoVar] -optimise lint = go 10 <=< prepasses <=< linting "Lower" pure where +optimise lint = go 10 <=< linting "Lower" pure where go :: Integer -> [Stmt CoVar] -> NameyT m [Stmt CoVar] go k sts | k > 0 = go (k - 1) =<< optmOnce lint sts | otherwise = pure sts - prepasses :: [Stmt CoVar] -> NameyT m [Stmt CoVar] - prepasses = linting "Newtype" killNewtypePass - linting = if lint then linted else flip const diff --git a/tests/lua/monoid.lua b/tests/lua/monoid.lua index 298041381..48c6b4022 100644 --- a/tests/lua/monoid.lua +++ b/tests/lua/monoid.lua @@ -25,7 +25,7 @@ do local function _dollar_d7(cjk, x, ys) if x.__tag == "Nil" then return ys end local tmp = x[1] - return { { _1 = tmp._1, _2 = _dollar_d7(nil, tmp._2, ys) }, __tag = "Cons" } + return { { _2 = _dollar_d7(nil, tmp._2, ys), _1 = tmp._1 }, __tag = "Cons" } end local tmp = { _1 = 1, _2 = nil } writeln(_dollarshow(function(x) diff --git a/tests/lua/promotion.lua b/tests/lua/promotion.lua index 823b0df68..e26385482 100644 --- a/tests/lua/promotion.lua +++ b/tests/lua/promotion.lua @@ -1,8 +1,8 @@ do local ignore = function(x) end local function main(f) - local tmp = f(nil) - return { _1 = tmp._1, _2 = tmp._2 } + local x = f(nil) + return { _1 = x._1, _2 = x._2 } end ignore(main) end