Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support letrec lifting without lambda lifting #1794

Merged
merged 4 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module Juvix.Compiler.Core.Data.TransformationId where
import Juvix.Prelude

data TransformationId
= LambdaLifting
= LambdaLetRecLifting
| LetRecLifting
| TopEtaExpand
| RemoveTypeArgs
| MoveApps
Expand Down
9 changes: 7 additions & 2 deletions src/Juvix/Compiler/Core/Data/TransformationId/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ pcompletions = do
in f . Text.intercalate "," . map ppTrans
ppTrans :: TransformationId -> Text
ppTrans = \case
LambdaLifting -> strLifting
LambdaLetRecLifting -> strLifting
LetRecLifting -> strLetRecLifting
TopEtaExpand -> strTopEtaExpand
Identity -> strIdentity
RemoveTypeArgs -> strRemoveTypeArgs
Expand All @@ -62,7 +63,8 @@ symbol = void . lexeme . chunk

transformation :: (MonadParsec e Text m) => m TransformationId
transformation =
symbol strLifting $> LambdaLifting
symbol strLifting $> LambdaLetRecLifting
<|> symbol strLetRecLifting $> LetRecLifting
<|> symbol strIdentity $> Identity
<|> symbol strTopEtaExpand $> TopEtaExpand
<|> symbol strRemoveTypeArgs $> RemoveTypeArgs
Expand All @@ -88,6 +90,9 @@ allStrings =
strLifting :: Text
strLifting = "lifting"

strLetRecLifting :: Text
strLetRecLifting = "letrec-lifting"

strTopEtaExpand :: Text
strTopEtaExpand = "top-eta-expand"

Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Core/Extra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ shift m = umapN go
| v ^. varIndex >= k -> NVar (shiftVar m v)
n -> n

-- | Prism for NRec
_NRec :: SimpleFold Node LetRec
_NRec f = \case
NRec l -> NRec <$> f l
n -> pure n

-- | Prism for NLam
_NLam :: SimpleFold Node Lambda
_NLam f = \case
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import Juvix.Compiler.Core.Data.InfoTable
import Juvix.Compiler.Core.Transformation

toStrippedTransformations :: [TransformationId]
toStrippedTransformations = [NatToInt, ConvertBuiltinTypes, LambdaLifting, MoveApps, TopEtaExpand, RemoveTypeArgs]
toStrippedTransformations = [NatToInt, ConvertBuiltinTypes, LambdaLetRecLifting, MoveApps, TopEtaExpand, RemoveTypeArgs]

-- | Perform transformations on Core necessary before the translation to
-- Core.Stripped
Expand Down
7 changes: 4 additions & 3 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module Juvix.Compiler.Core.Transformation
( module Juvix.Compiler.Core.Transformation.Base,
module Juvix.Compiler.Core.Transformation,
module Juvix.Compiler.Core.Transformation.Eta,
module Juvix.Compiler.Core.Transformation.LambdaLifting,
module Juvix.Compiler.Core.Transformation.LambdaLetRecLifting,
module Juvix.Compiler.Core.Transformation.TopEtaExpand,
module Juvix.Compiler.Core.Data.TransformationId,
)
Expand All @@ -14,7 +14,7 @@ import Juvix.Compiler.Core.Transformation.ComputeTypeInfo
import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes
import Juvix.Compiler.Core.Transformation.Eta
import Juvix.Compiler.Core.Transformation.Identity
import Juvix.Compiler.Core.Transformation.LambdaLifting
import Juvix.Compiler.Core.Transformation.LambdaLetRecLifting
import Juvix.Compiler.Core.Transformation.MoveApps
import Juvix.Compiler.Core.Transformation.NatToInt
import Juvix.Compiler.Core.Transformation.RemoveTypeArgs
Expand All @@ -26,7 +26,8 @@ applyTransformations ts tbl = foldl' (flip appTrans) tbl ts
where
appTrans :: TransformationId -> InfoTable -> InfoTable
appTrans = \case
LambdaLifting -> lambdaLifting
LambdaLetRecLifting -> lambdaLetRecLifting
LetRecLifting -> letRecLifting
Identity -> identity
TopEtaExpand -> topEtaExpand
RemoveTypeArgs -> removeTypeArgs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module Juvix.Compiler.Core.Transformation.LambdaLifting
( module Juvix.Compiler.Core.Transformation.LambdaLifting,
module Juvix.Compiler.Core.Transformation.LambdaLetRecLifting
( module Juvix.Compiler.Core.Transformation.LambdaLetRecLifting,
module Juvix.Compiler.Core.Transformation.Base,
)
where
Expand All @@ -11,10 +11,12 @@ import Juvix.Compiler.Core.Info.NameInfo
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base

lambdaLiftBinder :: (Member InfoTableBuilder r) => BinderList Binder -> Binder -> Sem r Binder
lambdaLiftBinder :: Members '[Reader OnlyLetRec, InfoTableBuilder] r => BinderList Binder -> Binder -> Sem r Binder
lambdaLiftBinder bl = traverseOf binderType (lambdaLiftNode bl)

lambdaLiftNode :: forall r. (Member InfoTableBuilder r) => BinderList Binder -> Node -> Sem r Node
type OnlyLetRec = Bool

lambdaLiftNode :: forall r. Members '[Reader OnlyLetRec, InfoTableBuilder] r => BinderList Binder -> Node -> Sem r Node
lambdaLiftNode aboveBl top =
let topArgs :: [LambdaLhs]
(topArgs, body) = unfoldLambdas top
Expand All @@ -40,30 +42,37 @@ lambdaLiftNode aboveBl top =
m -> return (Recur m)
where
goLambda :: Lambda -> Sem r Recur
goLambda lm = do
l' <- lambdaLiftNode bl (NLam lm)
let (freevarsAssocs, fBody') = captureFreeVarsCtx bl l'
allfreevars :: [Var]
allfreevars = map fst freevarsAssocs
argsInfo :: [ArgumentInfo]
argsInfo = map (argumentInfoFromBinder . (^. lambdaLhsBinder)) (fst (unfoldLambdas fBody'))
f <- freshSymbol
let name = uniqueName "lambda" f
registerIdent
name
IdentifierInfo
{ _identifierSymbol = f,
_identifierName = name,
_identifierLocation = Nothing,
_identifierType = typeFromArgs argsInfo,
_identifierArgsNum = length argsInfo,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False,
_identifierBuiltin = Nothing
}
registerIdentNode f fBody'
let fApp = mkApps' (mkIdent (setInfoName name mempty) f) (map NVar allfreevars)
return (End fApp)
goLambda l = do
onlyLetRec <- ask @OnlyLetRec
if
| onlyLetRec -> return (Recur (NLam l))
| otherwise -> goLambdaGo l
where
goLambdaGo :: Lambda -> Sem r Recur
goLambdaGo lm = do
l' <- lambdaLiftNode bl (NLam lm)
let (freevarsAssocs, fBody') = captureFreeVarsCtx bl l'
allfreevars :: [Var]
allfreevars = map fst freevarsAssocs
argsInfo :: [ArgumentInfo]
argsInfo = map (argumentInfoFromBinder . (^. lambdaLhsBinder)) (fst (unfoldLambdas fBody'))
f <- freshSymbol
let name = uniqueName "lambda" f
registerIdent
name
IdentifierInfo
{ _identifierSymbol = f,
_identifierName = name,
_identifierLocation = Nothing,
_identifierType = typeFromArgs argsInfo,
_identifierArgsNum = length argsInfo,
_identifierArgsInfo = argsInfo,
_identifierIsExported = False,
_identifierBuiltin = Nothing
}
registerIdentNode f fBody'
let fApp = mkApps' (mkIdent (setInfoName name mempty) f) (map NVar allfreevars)
return (End fApp)

goLetRec :: LetRec -> Sem r Recur
goLetRec letr = do
Expand Down Expand Up @@ -148,15 +157,34 @@ lambdaLiftNode aboveBl top =
res = shiftHelper body' (nonEmpty' (zipExact letItems letRecBinders'))
return (Recur res)

lambdaLifting :: InfoTable -> InfoTable
lambdaLifting = run . mapT' (const (lambdaLiftNode mempty))
lifting :: Bool -> InfoTable -> InfoTable
lifting onlyLetRec = run . runReader onlyLetRec . mapT' (const (lambdaLiftNode mempty))

lambdaLetRecLifting :: InfoTable -> InfoTable
lambdaLetRecLifting = lifting False

letRecLifting :: InfoTable -> InfoTable
letRecLifting = lifting True

-- | True if lambdas are only found at the top level
nodeIsLifted :: Node -> Bool
nodeIsLifted = not . hasNestedLambdas
nodeIsLifted = nodeIsLambdaLifted .&&. nodeIsLetRecLifted

-- | True if lambdas are only found at the top level
nodeIsLambdaLifted :: Node -> Bool
nodeIsLambdaLifted = not . hasNestedLambdas
where
hasNestedLambdas :: Node -> Bool
hasNestedLambdas = has (cosmos . _NLam) . snd . unfoldLambdas'

-- | True if there are no letrec nodes
nodeIsLetRecLifted :: Node -> Bool
nodeIsLetRecLifted = not . hasLetRecs
where
hasLetRecs :: Node -> Bool
hasLetRecs = has (cosmos . _NRec)

isLifted :: InfoTable -> Bool
isLifted = all nodeIsLifted . toList . (^. identContext)
isLifted = all nodeIsLifted . (^. identContext)

isLetRecLifted :: InfoTable -> Bool
isLetRecLifted = all nodeIsLetRecLifted . (^. identContext)
25 changes: 21 additions & 4 deletions test/Core/Transformation/Lifting.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import Core.Transformation.Base
import Juvix.Compiler.Core.Transformation

allTests :: TestTree
allTests = testGroup "Lambda lifting" (map liftTest Eval.tests)

pipe :: [TransformationId]
pipe = [LambdaLifting]
allTests =
testGroup
"Lifting"
[ testGroup "Lambda and LetRec lifting" (map liftTest Eval.tests),
testGroup "Only LetRec lifting" (map letRecLiftTest Eval.tests)
]

liftTest :: Eval.PosTest -> TestTree
liftTest _testEval =
Expand All @@ -19,3 +21,18 @@ liftTest _testEval =
_testAssertion = \i -> unless (isLifted i) (error "not lambda lifted"),
_testEval
}
where
pipe :: [TransformationId]
pipe = [LambdaLetRecLifting]

letRecLiftTest :: Eval.PosTest -> TestTree
letRecLiftTest _testEval =
fromTest
Test
{ _testTransformations = pipe,
_testAssertion = \i -> unless (isLetRecLifted i) (error "not letrec lifted"),
_testEval
}
where
pipe :: [TransformationId]
pipe = [LetRecLifting]