Skip to content
This repository has been archived by the owner on Oct 18, 2021. It is now read-only.

Run newtype optimisation as part of lower #210

Merged
merged 1 commit into from
Oct 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion amuletml.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ library
, Core.Lower
, Core.Lower.Basic
, Core.Lower.Pattern
, Core.Lower.TypeRepr
, Core.Types
, Core.Builtin
, Core.Optimise
Expand All @@ -291,7 +292,6 @@ library
, Core.Optimise.Reduce.Inline
, Core.Optimise.Reduce.Pattern
, Core.Optimise.Sinking
, Core.Optimise.Newtype
, Core.Optimise.Uncurry
, Core.Optimise.DeadCode
, Core.Optimise.CommonExpElim
Expand Down
3 changes: 1 addition & 2 deletions bin/Amc.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import Backend.Lua
import qualified Syntax.Builtin as Bi

import Core.Optimise.Reduce (reducePass)
import Core.Optimise.Newtype (killNewtypePass)
import Core.Optimise.DeadCode (deadCodePass)
import Core.Simplify (optimise)
import Core.Core (Stmt)
Expand Down Expand Up @@ -65,7 +64,7 @@ runCompile opt (DoLint lint) dconfig file = do
Opt -> optimise lint core
NoOpt -> do
lintIt "Lower" (checkStmt emptyScope core) (pure ())
(lintIt "Optimised" =<< checkStmt emptyScope) . deadCodePass <$> (reducePass =<< killNewtypePass core)
(lintIt "Optimised" =<< checkStmt emptyScope) . deadCodePass <$> reducePass core
lua = compileProgram optimised
in ( Just (env, core, optimised, lua)
, errors
Expand Down
33 changes: 14 additions & 19 deletions src/Core/Lower.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE LambdaCase, TupleSections,
{-# LANGUAGE TupleSections,
PatternSynonyms, RankNTypes, ScopedTypeVariables, FlexibleContexts,
ConstraintKinds, OverloadedStrings, TypeFamilies #-}
module Core.Lower
Expand Down Expand Up @@ -32,6 +32,7 @@ import Core.Optimise (substituteInType, substituteInTys, fresh, freshFrom)
import Core.Core hiding (Atom, Term, Stmt, Type, Pattern, Arm)
import Core.Core (pattern Atom)
import Core.Types (unify, unifyClosed, replaceTy)
import Core.Lower.TypeRepr
import Core.Lower.Pattern
import Core.Lower.Basic
import Core.Var
Expand All @@ -40,7 +41,7 @@ import qualified Syntax as S
import Syntax.Let
import Syntax.Var (Var, Typed, VarResolved(..))
import Syntax.Transform
import Syntax (Expr(..), Pattern(..), Skolem(..), ModuleTerm(..), Toplevel(..), Constructor(..), Arm(..))
import Syntax (Expr(..), Pattern(..), Skolem(..), ModuleTerm(..), Toplevel(..), Arm(..))

import Text.Pretty.Semantic (pretty)

Expand All @@ -52,13 +53,18 @@ type Stmt = C.Stmt CoVar
type Lower = ContT Term

defaultState :: LowerState
defaultState = LS mempty ctors mempty where
defaultState = LS mempty ctors types where
ctors :: VarMap.Map (C.Type CoVar)
ctors = VarMap.fromList
[ (C.vCONS,
ForallTy (Relevant name) StarTy $
VarTy name `prodTy` AppTy C.tyList (VarTy name) `arrTy` AppTy C.tyList (VarTy name))
, (C.vNIL, ForallTy (Relevant name) StarTy $ AppTy C.tyList (VarTy name))]
types :: VarMap.Map TypeRepr
types = VarMap.fromList
( (C.vList, SumTy (VarSet.fromList [C.vCONS, C.vNIL]))
: map (,OpaqueTy) [ C.vBool, C.vInt, C.vString, C.vFloat, C.vUnit
, C.vLazy, C.vArrow, C.vProduct, C.vRefTy ] )
name = C.tyvarA
arrTy = ForallTy Irrelevant
prodTy a b = RowsTy NilTy [("_1", a), ("_2", b)]
Expand Down Expand Up @@ -365,22 +371,11 @@ lowerProg' (LetStmt _ vs:prg) = do
vs' <- lowerLet vs
foldr ((.) . ((:) . C.StmtLet)) id vs' <$$> lowerProg' prg

lowerProg' (TypeDecl _ var _ Nothing _:prg) =
(C.Type (mkType var) []:) <$$> lowerProg' prg
lowerProg' (TypeDecl _ var _ (Just cons) _:prg) = do
let cons' = map (\case
UnitCon _ p (_, t) -> (p, mkCon p, lowerType t)
ArgCon _ p _ (_, t) -> (p, mkCon p, lowerType t)
GadtCon _ p t _ -> (p, mkCon p, lowerType t))
cons
ccons = map (\(_, a, b) -> (a, b)) cons'
scons = map (\(a, _, b) -> (mkCon a, b)) cons'

conset = VarSet.fromList (map fst scons)

(C.Type (mkType var) ccons:) <$$> local (\s ->
s { ctors = VarMap.union (VarMap.fromList scons) (ctors s)
, types = VarMap.insert (mkType var) conset (types s)
lowerProg' (TypeDecl _ var _ cons _:prg) = do
~(tyStmts@(C.Type _ cs:_), repr) <- getTypeRepr (mkType var) cons
(tyStmts++) <$$> local (\s ->
s { ctors = VarMap.union (VarMap.fromList cs) (ctors s)
, types = VarMap.insert (mkType var) repr (types s)
}) (lowerProg' prg)

lowerLet :: MonadLower m => [S.Binding Typed] -> m [Binding CoVar]
Expand Down
16 changes: 14 additions & 2 deletions src/Core/Lower/Basic.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE OverloadedStrings, ConstraintKinds, FlexibleContexts #-}
module Core.Lower.Basic
( LowerState(..)
( TypeRepr(..)
, LowerState(..)
, LowerTrack
, MonadLower
, mkTyvar, mkVal, mkType, mkCo, mkCon, mkVar
Expand All @@ -24,12 +25,23 @@ import qualified Syntax as S
import Syntax.Var (VarResolved(..), Var, Resolved, Typed)
import Syntax (Lit(..), Skolem(..))

data TypeRepr
= OpaqueTy -- ^ An opaque type, for interfacing with foreign values.
| SumTy VarSet.Set -- ^ A sum type, with the set of constructors.

-- | A type which just wraps another.
--
-- This holds the name of the constructor, and the inner and outer
-- type, both sharing their free variables.
| WrapperTy CoVar (C.Type CoVar) (C.Type CoVar)
deriving (Show, Eq)

data LowerState
= LS
{ vars :: VarMap.Map (C.Type CoVar)
, ctors :: VarMap.Map (C.Type CoVar)
-- | The map of types to their constructors /if they have any/.
, types :: VarMap.Map VarSet.Set
, types :: VarMap.Map TypeRepr
} deriving (Eq, Show)

instance Semigroup LowerState where
Expand Down
74 changes: 66 additions & 8 deletions src/Core/Lower/Pattern.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import Data.Bifunctor
import Data.Foldable
import Data.Triple
import Data.Maybe
import Data.Span

import qualified Core.Core as C
import Core.Optimise (substituteInType, substituteInTys, fresh, freshFrom)
Expand Down Expand Up @@ -85,13 +86,23 @@ data ArmNode
, nodeNodes :: [(Pattern CoVar, ArmNode)]
-- ^ The child nodes, and their associated pattern.
}
| ArmLet
{ nodeArms :: ArmSet
, nodeSuccess :: [ArmLeaf] -- Should be empty
, nodeBind :: (CoVar, Type CoVar, Term CoVar)
, nodeBody :: ArmNode
}
deriving (Show)

instance Pretty ArmNode where
pretty (ArmComplete arms success) = "Complete" <+> parens (shown arms) <+> shown success
pretty (ArmMatch arms success atom nodes)
= "Match" <+> parens (shown arms) <+> shown success <+> pretty atom
<#> (indent 2 . vsep $ map (\(p, n) -> pretty p <+> "=>" <#> indent 2 (pretty n)) nodes)
pretty (ArmLet arms success (v, ty, x) node)
= "Let" <+> parens (shown arms) <+> shown success <+> pretty v <+> colon <+> pretty ty <+> equals <+> pretty x
<#> indent 2 (pretty node)


-- | A of a single case in a match expression.
data PatternRow
Expand Down Expand Up @@ -197,6 +208,8 @@ flattenNode bodies guards (ArmMatch _ leafs atom' children) = do
let branches = foldr (flip (HSet.foldr add) . nodeArms . snd) mempty children
in foldr (add . leafArm) branches leafs
where add k = HMap.insertWith (+) k (1 :: Int)
flattenNode bodies guards (ArmLet _ _ bind child) =
C.Let (One bind) <$> flattenNode bodies guards child

-- | Lift a pattern match into a lambda, passing arguments as values.
generateBinds :: forall m. MonadLower m
Expand Down Expand Up @@ -359,12 +372,11 @@ lowerOne tys rss = do
getCtors v = do
ctor <- VarMap.lookup v (ctors state)
ty <- getType ctor
VarMap.lookup ty (types state)

getType (ForallTy _ _ t) = getType t
getType (ConTy a) = pure a
getType (AppTy f _) = getType f
getType _ = Nothing
case VarMap.lookup ty (types state) of
Nothing -> error ("Cannot find " ++ show ty)
Just OpaqueTy -> Nothing
Just (SumTy ctors) -> Just ctors
Just (WrapperTy ctor _ _) -> Just (VarSet.singleton ctor)

-- | Compute the "arity" heuristic for a given row variable.
--
Expand Down Expand Up @@ -415,8 +427,8 @@ lowerOneOf preLeafs var ty tys = go [] . map prepare

go unc [] = lowerOne tys (reverse unc)
go unc rs@((S.PRecord{},_):_) = goRows unc mempty rs
go unc rs@((S.Destructure{},_):_) = goCtors unc mempty rs
go unc rs@((S.PGadtCon{},_):_) = goCtors unc mempty rs
go unc rs@((S.Destructure{},_):_) = goCtorsWith unc rs
go unc rs@((S.PGadtCon{},_):_) = goCtorsWith unc rs
go unc rs@((S.PLiteral{},_):_) = goLiterals unc mempty rs
go unc ((p, r):rs) = go (goGeneric p r:unc) rs

Expand Down Expand Up @@ -464,6 +476,46 @@ lowerOneOf preLeafs var ty tys = go [] . map prepare
pure ( Map.insert f (v, lowerType (S.getType p)) fs
, VarMap.insert v p ps )

goCtorsWith unc rs = do
let Just tyName = getType ty
repr <- asks (fromMaybe (error ("Cannot find " ++ show tyName)) . VarMap.lookup tyName . types)
case repr of
OpaqueTy -> error "Impossible matching on opaque type"
SumTy _ -> goCtors unc mempty rs
WrapperTy _ from to -> do
let Just map = unify to ty
from' = substituteInType map from
coVar <- case rs of
((S.PGadtCon _ _ _ (Just child) _, _):_) -> freshFromPat child
_ -> fresh ValueVar
node <- goNewtype unc (Capture coVar from') rs
pure (ArmLet (nodeArms node) mempty
(coVar, from', Cast (Ref var ty) from' (SameRepr ty from'))
node)

-- | Split patterns into those matching against the constructor and those not
goNewtype :: [PatternRow] -> Capture CoVar
-> [(S.Pattern Typed, PatternRow)]
-> m ArmNode
goNewtype unc (Capture c cty) [] =
lowerOne (VarMap.insert c cty tys) (reverse unc)

goNewtype unc cap@(Capture c _) (( S.PGadtCon _ _ [] (Just p) _
, PR arm pats gd vBind tyBind ):rs) =
-- The wrapped value is matched by the pattern - focus on that next.
let r' = PR arm (VarMap.insert c p pats) gd vBind tyBind
in goNewtype (r':unc) cap rs

goNewtype unc cap@(Capture c _) (( S.PGadtCon _ _ [(v, t)] Nothing _
, PR arm pats gd vBind tyBind ):rs) =
-- The wrapped value is the dictionary - just add a wildcard pattern.
let r' = PR arm (VarMap.insert c (S.Capture v (internal, t)) pats) gd vBind tyBind
in goNewtype (r':unc) cap rs

goNewtype _ _ ((S.PGadtCon{}, _):_) = error "Impossible: Malformed pattern for newtype."

goNewtype unc cap ((p, r):rs) = goNewtype (goGeneric p r:unc) cap rs

-- | Build up a mapping of (constructors -> (contents variable, rows)).
goCtors :: [PatternRow] -> VarMap.Map ([Capture CoVar], [PatternRow])
-> [(S.Pattern Typed, PatternRow)]
Expand Down Expand Up @@ -577,3 +629,9 @@ dropNForalls :: Int -> Type a -> Type a
dropNForalls 0 t = t
dropNForalls x (ForallTy _ _ t) = dropNForalls (x - 1) t
dropNForalls _ _ = undefined

getType :: Type a -> Maybe a
getType (ForallTy _ _ t) = getType t
getType (ConTy a) = pure a
getType (AppTy f _) = getType f
getType _ = Nothing
62 changes: 62 additions & 0 deletions src/Core/Lower/TypeRepr.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
{-# LANGUAGE LambdaCase, ScopedTypeVariables #-}
module Core.Lower.TypeRepr (getTypeRepr) where

import qualified Data.VarSet as VarSet

import Control.Monad.Namey
import Control.Monad

import Core.Lower.Basic
import Core.Optimise

import Syntax (Constructor(..))
import Syntax.Var (Typed)

getTypeRepr :: MonadNamey m
=> CoVar -> Maybe [Constructor Typed]
-> m ([Stmt CoVar], TypeRepr)
getTypeRepr var Nothing = pure ([Type var []], OpaqueTy)
getTypeRepr var (Just ctors) =
let ctors' = map (\case
UnitCon _ p (_, t) -> (mkCon p, lowerType t)
ArgCon _ p _ (_, t) -> (mkCon p, lowerType t)
GadtCon _ p t _ -> (mkCon p, lowerType t)) ctors
in case ctors' of
[(ctor, ty)] | Just nt@(Spine _ dom cod) <- isNewtype ty -> do
let CoVar name id _ = ctor

wrapper <- newtypeWorker nt
pure ( [ Type var [], StmtLet (One (CoVar name id ValueVar, ty, wrapper))]
, WrapperTy ctor dom cod )

_ -> pure ( [ Type var ctors' ]
, SumTy (VarSet.fromList (map fst ctors')) )

isNewtype :: IsVar a => Type a -> Maybe (Spine a)
isNewtype (ForallTy Irrelevant _ ForallTy{}) = Nothing -- Cannot have multiple relevant arguments
isNewtype (ForallTy Irrelevant from to) =
pure (Spine [(Irrelevant, from)] from to)
isNewtype (ForallTy (Relevant var) k rest) = do
(Spine tys from to) <- isNewtype rest
guard (var `occursInTy` to)
pure (Spine ((Relevant var, k):tys) from to)
isNewtype _ = Nothing

data Spine a =
Spine [(BoundTv a, Type a)] (Type a) (Type a)
deriving (Eq, Show, Ord)

newtypeWorker :: forall a m. (IsVar a, MonadNamey m)
=> Spine a -> m (Term a)
newtypeWorker (Spine tys dom cod) = do
let wrap :: [(BoundTv a, Type a)] -> (a -> Type a -> Term a) -> m (Term a)
wrap ((Relevant v, c):ts) ex = Lam (TypeArgument v c) <$> wrap ts ex
wrap [(Irrelevant, c)] ex = do
v <- fresh ValueVar
Lam (TermArgument (fromVar v) c) <$> pure (ex (fromVar v) c)
wrap _ _ = undefined

work :: a -> Type a -> Term a
work var ty = Cast (Ref var ty) cod (SameRepr dom cod)

wrap tys work
Loading