Skip to content

Commit

Permalink
Conversion of Nat representation to JuvixCore integers (#1661)
Browse files Browse the repository at this point in the history
* nat to int wip

* nat to int wip

* fix condition

* nats in core

* bugfixes

* tests

* make ormolu happy

* fix case
  • Loading branch information
lukaszcz authored Dec 20, 2022
1 parent af0379a commit 445376e
Show file tree
Hide file tree
Showing 26 changed files with 386 additions and 107 deletions.
12 changes: 6 additions & 6 deletions src/Juvix/Compiler/Core/Data/BinderList.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Juvix.Compiler.Core.Data.BinderList where

import Juvix.Compiler.Core.Language hiding (cons, lookup, uncons)
import Juvix.Compiler.Core.Language hiding (cons, drop, lookup, uncons)
import Juvix.Prelude qualified as Prelude

-- | if we have \x\y. b, the binderlist in b is [y, x]
Expand All @@ -14,11 +14,11 @@ makeLenses ''BinderList
fromList :: [a] -> BinderList a
fromList l = BinderList (length l) l

drop' :: Int -> BinderList a -> BinderList a
drop' k (BinderList n l) = BinderList (n - k) (dropExact k l)
drop :: Int -> BinderList a -> BinderList a
drop k (BinderList n l) = BinderList (n - k) (dropExact k l)

tail' :: BinderList a -> BinderList a
tail' = snd . fromJust . uncons
tail :: BinderList a -> BinderList a
tail = snd . fromJust . uncons

uncons :: BinderList a -> Maybe (a, BinderList a)
uncons l = second helper <$> Prelude.uncons (l ^. blMap)
Expand Down Expand Up @@ -60,7 +60,7 @@ lookupsSortedRev bl = go [] 0 bl
(v : vs) ->
let skipped = v ^. varIndex - off
off' = off + skipped
ctx' = drop' skipped ctx
ctx' = drop skipped ctx
in go ((v, head' ctx') : acc) off' ctx' vs
head' :: BinderList a -> a
head' = lookup 0
Expand Down
16 changes: 12 additions & 4 deletions src/Juvix/Compiler/Core/Data/InfoTable.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
module Juvix.Compiler.Core.Data.InfoTable where
module Juvix.Compiler.Core.Data.InfoTable
( module Juvix.Compiler.Core.Data.InfoTable,
module Juvix.Compiler.Concrete.Data.Builtins,
)
where

import Juvix.Compiler.Concrete.Data.Builtins
import Juvix.Compiler.Core.Language

type IdentContext = HashMap Symbol Node
Expand Down Expand Up @@ -43,7 +48,8 @@ data IdentifierInfo = IdentifierInfo
_identifierType :: Type,
_identifierArgsNum :: Int,
_identifierArgsInfo :: [ArgumentInfo],
_identifierIsExported :: Bool
_identifierIsExported :: Bool,
_identifierBuiltin :: Maybe BuiltinFunction
}

data ArgumentInfo = ArgumentInfo
Expand All @@ -60,7 +66,8 @@ data InductiveInfo = InductiveInfo
_inductiveKind :: Type,
_inductiveConstructors :: [ConstructorInfo],
_inductiveParams :: [ParameterInfo],
_inductivePositive :: Bool
_inductivePositive :: Bool,
_inductiveBuiltin :: Maybe BuiltinInductive
}

data ConstructorInfo = ConstructorInfo
Expand All @@ -69,7 +76,8 @@ data ConstructorInfo = ConstructorInfo
_constructorTag :: Tag,
_constructorType :: Type,
_constructorArgsNum :: Int,
_constructorInductive :: Symbol
_constructorInductive :: Symbol,
_constructorBuiltin :: Maybe BuiltinConstructor
}

data ParameterInfo = ParameterInfo
Expand Down
15 changes: 15 additions & 0 deletions src/Juvix/Compiler/Core/Data/InfoTableBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@ getConstructorInfo tag = do
tab <- getInfoTable
return $ fromJust (HashMap.lookup tag (tab ^. infoConstructors))

getInductiveInfo :: Member InfoTableBuilder r => Symbol -> Sem r InductiveInfo
getInductiveInfo sym = do
tab <- getInfoTable
return $ fromJust (HashMap.lookup sym (tab ^. infoInductives))

getIdentifierInfo :: Member InfoTableBuilder r => Symbol -> Sem r IdentifierInfo
getIdentifierInfo sym = do
tab <- getInfoTable
return $ fromJust (HashMap.lookup sym (tab ^. infoIdentifiers))

getBoolSymbol :: Member InfoTableBuilder r => Sem r Symbol
getBoolSymbol = do
ci <- getConstructorInfo (BuiltinTag TagTrue)
return $ ci ^. constructorInductive

checkSymbolDefined :: Member InfoTableBuilder r => Symbol -> Sem r Bool
checkSymbolDefined sym = do
tab <- getInfoTable
Expand Down
1 change: 1 addition & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ data TransformationId
| TopEtaExpand
| RemoveTypeArgs
| MoveApps
| NatToInt
| Identity
deriving stock (Data)
5 changes: 5 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pcompletions = do
Identity -> strIdentity
RemoveTypeArgs -> strRemoveTypeArgs
MoveApps -> strMoveApps
NatToInt -> strNatToInt

lexeme :: MonadParsec e Text m => m a -> m a
lexeme = L.lexeme L.hspace
Expand All @@ -63,6 +64,7 @@ transformation =
<|> symbol strTopEtaExpand $> TopEtaExpand
<|> symbol strRemoveTypeArgs $> RemoveTypeArgs
<|> symbol strMoveApps $> MoveApps
<|> symbol strNatToInt $> NatToInt

allStrings :: [Text]
allStrings =
Expand All @@ -87,3 +89,6 @@ strRemoveTypeArgs = "remove-type-args"

strMoveApps :: Text
strMoveApps = "move-apps"

strNatToInt :: Text
strNatToInt = "nat-to-int"
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Core/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ eval !ctx !env0 = convertRuntimeNodes . eval' env0
let !vs' = map (eval' env' . (^. letItemValue)) (toList vs)
!env' = revAppend vs' env
in foldr GHC.pseq (eval' env' b) vs'
NCase (Case i v bs def) ->
NCase (Case i sym v bs def) ->
case eval' env v of
NCtr (Constr _ tag args) -> branch n env args tag def bs
v' -> evalError "matching on non-data" (substEnv env (mkCase i v' bs def))
v' -> evalError "matching on non-data" (substEnv env (mkCase i sym v' bs def))
NMatch (Match _ vs bs) ->
let !vs' = map' (eval' env) (toList vs)
in match n env vs' bs
Expand Down
21 changes: 12 additions & 9 deletions src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ mkLambda i bi b = NLam (Lambda i bi b)
mkLambda' :: Node -> Node
mkLambda' = mkLambda Info.empty emptyBinder

mkLambdaTy :: Node -> Node -> Node
mkLambdaTy ty = mkLambda Info.empty (Binder "?" Nothing ty)

mkLetItem' :: Node -> LetItem
mkLetItem' = LetItem emptyBinder

Expand All @@ -70,10 +73,10 @@ mkLetRec i vs b = NRec (LetRec i vs b)
mkLetRec' :: NonEmpty Node -> Node -> Node
mkLetRec' = mkLetRec Info.empty . fmap mkLetItem'

mkCase :: Info -> Node -> [CaseBranch] -> Maybe Node -> Node
mkCase i v bs def = NCase (Case i v bs def)
mkCase :: Info -> Symbol -> Node -> [CaseBranch] -> Maybe Node -> Node
mkCase i sym v bs def = NCase (Case i sym v bs def)

mkCase' :: Node -> [CaseBranch] -> Maybe Node -> Node
mkCase' :: Symbol -> Node -> [CaseBranch] -> Maybe Node -> Node
mkCase' = mkCase Info.empty

mkMatch :: Info -> NonEmpty Node -> [MatchBranch] -> Node
Expand All @@ -82,8 +85,8 @@ mkMatch i vs bs = NMatch (Match i vs bs)
mkMatch' :: NonEmpty Node -> [MatchBranch] -> Node
mkMatch' = mkMatch Info.empty

mkIf :: Info -> Node -> Node -> Node -> Node
mkIf i v b1 b2 = mkCase i v [br] (Just b2)
mkIf :: Info -> Symbol -> Node -> Node -> Node -> Node
mkIf i sym v b1 b2 = mkCase i sym v [br] (Just b2)
where
br =
CaseBranch
Expand All @@ -94,7 +97,7 @@ mkIf i v b1 b2 = mkCase i v [br] (Just b2)
_caseBranchBody = b1
}

mkIf' :: Node -> Node -> Node -> Node
mkIf' :: Symbol -> Node -> Node -> Node -> Node
mkIf' = mkIf Info.empty

{------------------------------------------------------------------------}
Expand Down Expand Up @@ -480,7 +483,7 @@ destruct = \case
]
in mkLetRec i' items' (b' ^. childNode)
}
NCase (Case i v brs mdef) ->
NCase (Case i sym v brs mdef) ->
let branchChildren :: [([Binder], NodeChild)]
branchChildren =
[ (binders, manyBinders binders (br ^. caseBranchBody))
Expand Down Expand Up @@ -520,15 +523,15 @@ destruct = \case
_nodeSubinfos = map (^. caseBranchInfo) brs,
_nodeChildren = noBinders v : allNodes,
_nodeReassemble = someChildrenI $ \i' is' (v' :| allNodes') ->
mkCase i' (v' ^. childNode) (mkBranches is' allNodes') Nothing
mkCase i' sym (v' ^. childNode) (mkBranches is' allNodes') Nothing
}
Just def ->
NodeDetails
{ _nodeInfo = i,
_nodeSubinfos = map (^. caseBranchInfo) brs,
_nodeChildren = noBinders v : noBinders def : allNodes,
_nodeReassemble = twoManyChildrenI $ \i' is' v' def' allNodes' ->
mkCase i' (v' ^. childNode) (mkBranches is' allNodes') (Just (def' ^. childNode))
mkCase i' sym (v' ^. childNode) (mkBranches is' allNodes') (Just (def' ^. childNode))
}
NMatch (Match i vs branches) ->
let allNodes :: [NodeChild]
Expand Down
8 changes: 8 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Recursors/Map.hs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ fromRecur c =
fromPair :: Functor g => d -> g (c, Node) -> g (Recur' (c, d))
fromPair d = fmap (\(c, x) -> Recur' ((c, d), x))

fromRecur' :: Functor g => d -> g (Recur' c) -> g (Recur' (c, d))
fromRecur' d =
fmap
( \case
End' x -> End' x
Recur' (c, x) -> Recur' ((c, d), x)
)

nodeMapG' ::
Monad m =>
Sing dir ->
Expand Down
56 changes: 37 additions & 19 deletions src/Juvix/Compiler/Core/Extra/Recursors/Map/Named.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@ import Juvix.Compiler.Core.Extra.Recursors.Map
import Juvix.Compiler.Core.Extra.Recursors.Parameters

dmapLRM :: Monad m => (BinderList Binder -> Node -> m Recur) -> Node -> m Node
dmapLRM f = nodeMapG' STopDown binderInfoCollector (\bi -> fromRecur bi . f bi)
dmapLRM f = dmapLRM' (mempty, f)

dmapLM :: Monad m => (BinderList Binder -> Node -> m Node) -> Node -> m Node
dmapLM f = nodeMapG' STopDown binderInfoCollector (\bi -> fromSimple bi . f bi)
dmapLM f = dmapLM' (mempty, f)

umapLM :: Monad m => (BinderList Binder -> Node -> m Node) -> Node -> m Node
umapLM f = nodeMapG' SBottomUp binderInfoCollector f

dmapNRM :: Monad m => (Index -> Node -> m Recur) -> Node -> m Node
dmapNRM f = nodeMapG' STopDown binderNumCollector (\bi -> fromRecur bi . f bi)
dmapNRM :: Monad m => (Level -> Node -> m Recur) -> Node -> m Node
dmapNRM f = dmapNRM' (0, f)

dmapNM :: Monad m => (Index -> Node -> m Node) -> Node -> m Node
dmapNM f = nodeMapG' STopDown binderNumCollector (\bi -> fromSimple bi . f bi)
dmapNM :: Monad m => (Level -> Node -> m Node) -> Node -> m Node
dmapNM f = dmapNM' (0, f)

umapNM :: Monad m => (Index -> Node -> m Node) -> Node -> m Node
umapNM :: Monad m => (Level -> Node -> m Node) -> Node -> m Node
umapNM f = nodeMapG' SBottomUp binderNumCollector f

dmapRM :: Monad m => (Node -> m Recur) -> Node -> m Node
Expand All @@ -41,13 +41,13 @@ dmapLM' f = nodeMapG' STopDown (binderInfoCollector' (fst f)) (\bi -> fromSimple
umapLM' :: Monad m => (BinderList Binder, BinderList Binder -> Node -> m Node) -> Node -> m Node
umapLM' f = nodeMapG' SBottomUp (binderInfoCollector' (fst f)) (snd f)

dmapNRM' :: Monad m => (Index, Index -> Node -> m Recur) -> Node -> m Node
dmapNRM' :: Monad m => (Level, Level -> Node -> m Recur) -> Node -> m Node
dmapNRM' f = nodeMapG' STopDown (binderNumCollector' (fst f)) (\bi -> fromRecur bi . snd f bi)

dmapNM' :: Monad m => (Index, Index -> Node -> m Node) -> Node -> m Node
dmapNM' :: Monad m => (Level, Level -> Node -> m Node) -> Node -> m Node
dmapNM' f = nodeMapG' STopDown (binderNumCollector' (fst f)) (\bi -> fromSimple bi . snd f bi)

umapNM' :: Monad m => (Index, Index -> Node -> m Node) -> Node -> m Node
umapNM' :: Monad m => (Level, Level -> Node -> m Node) -> Node -> m Node
umapNM' f = nodeMapG' SBottomUp (binderNumCollector' (fst f)) (snd f)

dmapLR :: (BinderList Binder -> Node -> Recur) -> Node -> Node
Expand All @@ -59,13 +59,13 @@ dmapL f = runIdentity . dmapLM (embedIden f)
umapL :: (BinderList Binder -> Node -> Node) -> Node -> Node
umapL f = runIdentity . umapLM (embedIden f)

dmapNR :: (Index -> Node -> Recur) -> Node -> Node
dmapNR :: (Level -> Node -> Recur) -> Node -> Node
dmapNR f = runIdentity . dmapNRM (embedIden f)

dmapN :: (Index -> Node -> Node) -> Node -> Node
dmapN :: (Level -> Node -> Node) -> Node -> Node
dmapN f = runIdentity . dmapNM (embedIden f)

umapN :: (Index -> Node -> Node) -> Node -> Node
umapN :: (Level -> Node -> Node) -> Node -> Node
umapN f = runIdentity . umapNM (embedIden f)

dmapR :: (Node -> Recur) -> Node -> Node
Expand All @@ -86,28 +86,46 @@ dmapL' f = runIdentity . dmapLM' (embedIden f)
umapL' :: (BinderList Binder, BinderList Binder -> Node -> Node) -> Node -> Node
umapL' f = runIdentity . umapLM' (embedIden f)

dmapNR' :: (Index, Index -> Node -> Recur) -> Node -> Node
dmapNR' :: (Level, Level -> Node -> Recur) -> Node -> Node
dmapNR' f = runIdentity . dmapNRM' (embedIden f)

dmapN' :: (Index, Index -> Node -> Node) -> Node -> Node
dmapN' :: (Level, Level -> Node -> Node) -> Node -> Node
dmapN' f = runIdentity . dmapNM' (embedIden f)

umapN' :: (Index, Index -> Node -> Node) -> Node -> Node
umapN' :: (Level, Level -> Node -> Node) -> Node -> Node
umapN' f = runIdentity . umapNM' (embedIden f)

dmapCLM' :: Monad m => (BinderList Binder, c -> BinderList Binder -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCLM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromPair bi . snd f c bi)

dmapCLRM' :: Monad m => (BinderList Binder, c -> BinderList Binder -> Node -> m (Recur' c)) -> c -> Node -> m Node
dmapCLRM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderInfoCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)

dmapCNRM' :: Monad m => (Level, c -> Level -> Node -> m (Recur' c)) -> c -> Node -> m Node
dmapCNRM' f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) (binderNumCollector' (fst f))) (\(c, bi) -> fromRecur' bi . snd f c bi)

dmapCLM :: Monad m => (c -> BinderList Binder -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCLM f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) binderInfoCollector) (\(c, bi) -> fromPair bi . f c bi)
dmapCLM f = dmapCLM' (mempty, f)

dmapCNM :: Monad m => (c -> Index -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCNM :: Monad m => (c -> Level -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCNM f ini = nodeMapG' STopDown (pairCollector (identityCollector ini) binderNumCollector) (\(c, bi) -> fromPair bi . f c bi)

dmapCM :: Monad m => (c -> Node -> m (c, Node)) -> c -> Node -> m Node
dmapCM f ini = nodeMapG' STopDown (identityCollector ini) (\c -> fmap Recur' . f c)

dmapCL' :: (BinderList Binder, c -> BinderList Binder -> Node -> (c, Node)) -> c -> Node -> Node
dmapCL' f ini = runIdentity . dmapCLM' (embedIden f) ini

dmapCLR' :: (BinderList Binder, c -> BinderList Binder -> Node -> Recur' c) -> c -> Node -> Node
dmapCLR' f ini = runIdentity . dmapCLRM' (embedIden f) ini

dmapCNR' :: (Level, c -> Level -> Node -> Recur' c) -> c -> Node -> Node
dmapCNR' f ini = runIdentity . dmapCNRM' (embedIden f) ini

dmapCL :: (c -> BinderList Binder -> Node -> (c, Node)) -> c -> Node -> Node
dmapCL f ini = runIdentity . dmapCLM (embedIden f) ini

dmapCN :: (c -> Index -> Node -> (c, Node)) -> c -> Node -> Node
dmapCN :: (c -> Level -> Node -> (c, Node)) -> c -> Node -> Node
dmapCN f ini = runIdentity . dmapCNM (embedIden f) ini

dmapC :: (c -> Node -> (c, Node)) -> c -> Node -> Node
Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Core/Extra/Stripped/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,5 @@ mkLet i v b = NLet (Let i item b)
mkLet' :: Node -> Node -> Node
mkLet' = mkLet (LetInfo "" Nothing TyDynamic)

mkCase :: CaseInfo -> Node -> [CaseBranch] -> Maybe Node -> Node
mkCase ci v bs def = NCase (Case ci v bs def)
mkCase :: CaseInfo -> Symbol -> Node -> [CaseBranch] -> Maybe Node -> Node
mkCase ci sym v bs def = NCase (Case ci sym v bs def)
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Keywords.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Juvix.Data.Keyword.All
( kwAny,
kwAssign,
kwBind,
kwBuiltin,
kwCase,
kwColon,
kwComma,
Expand Down Expand Up @@ -52,6 +53,7 @@ allKeywordStrings = keywordsStrings allKeywords
allKeywords :: [Keyword]
allKeywords =
[ kwAssign,
kwBuiltin,
kwLet,
kwLetRec,
kwIn,
Expand Down
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Core/Language/Nodes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ data LetRec' i a ty = LetRec
-- branches default`. `Case` is lazy: only the selected branch is evaluated.
data Case' i bi a = Case
{ _caseInfo :: i,
_caseInductive :: Symbol,
_caseValue :: !a,
_caseBranches :: ![CaseBranch' bi a],
_caseDefault :: !(Maybe a)
Expand Down Expand Up @@ -338,7 +339,7 @@ instance Eq a => Eq (Constr' i a) where
(Constr _ tag1 args1) == (Constr _ tag2 args2) = tag1 == tag2 && args1 == args2

instance Eq a => Eq (Case' i bi a) where
(Case _ v1 bs1 def1) == (Case _ v2 bs2 def2) = v1 == v2 && bs1 == bs2 && def1 == def2
(Case _ sym1 v1 bs1 def1) == (Case _ sym2 v2 bs2 def2) = sym1 == sym2 && v1 == v2 && bs1 == bs2 && def1 == def2

instance Eq a => Eq (CaseBranch' i a) where
(==) =
Expand Down
Loading

0 comments on commit 445376e

Please sign in to comment.