Skip to content

Commit

Permalink
Don't fold lets if the let-bound variable occurs under a lambda-abstr…
Browse files Browse the repository at this point in the history
…action (#3029)

* Closes #3002
  • Loading branch information
lukaszcz authored Sep 13, 2024
1 parent ef0bc6e commit b609e1f
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 60 deletions.
90 changes: 45 additions & 45 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -177,52 +177,52 @@ isFalseConstr = \case
NCtr Constr {..} | _constrTag == BuiltinTag TagFalse -> True
_ -> False

isDebugOp :: Node -> Bool
isDebugOp = \case
NBlt BuiltinApp {..} ->
case _builtinAppOp of
OpTrace -> True
OpFail -> True
OpSeq -> True
OpAssert -> False
OpAnomaByteArrayFromAnomaContents -> False
OpAnomaByteArrayToAnomaContents -> False
OpAnomaDecode -> False
OpAnomaEncode -> False
OpAnomaGet -> False
OpAnomaSign -> False
OpAnomaSignDetached -> False
OpAnomaVerifyDetached -> False
OpAnomaVerifyWithMessage -> False
OpEc -> False
OpFieldAdd -> False
OpFieldDiv -> False
OpFieldFromInt -> False
OpFieldMul -> False
OpFieldSub -> False
OpPoseidonHash -> False
OpRandomEcPoint -> False
OpStrConcat -> False
OpStrToInt -> False
OpUInt8FromInt -> False
OpUInt8ToInt -> False
OpByteArrayFromListByte -> False
OpByteArrayLength -> False
OpEq -> False
OpIntAdd -> False
OpIntDiv -> False
OpIntLe -> False
OpIntLt -> False
OpIntMod -> False
OpIntMul -> False
OpIntSub -> False
OpFieldToInt -> False
OpShow -> False
_ -> False

-- | Check if the node contains `trace`, `fail` or `seq` (`>->`).
containsDebugOperations :: Node -> Bool
containsDebugOperations = ufold (\x xs -> x || or xs) isDebugOp
where
isDebugOp :: Node -> Bool
isDebugOp = \case
NBlt BuiltinApp {..} ->
case _builtinAppOp of
OpTrace -> True
OpFail -> True
OpSeq -> True
OpAssert -> False
OpAnomaByteArrayFromAnomaContents -> False
OpAnomaByteArrayToAnomaContents -> False
OpAnomaDecode -> False
OpAnomaEncode -> False
OpAnomaGet -> False
OpAnomaSign -> False
OpAnomaSignDetached -> False
OpAnomaVerifyDetached -> False
OpAnomaVerifyWithMessage -> False
OpEc -> False
OpFieldAdd -> False
OpFieldDiv -> False
OpFieldFromInt -> False
OpFieldMul -> False
OpFieldSub -> False
OpPoseidonHash -> False
OpRandomEcPoint -> False
OpStrConcat -> False
OpStrToInt -> False
OpUInt8FromInt -> False
OpUInt8ToInt -> False
OpByteArrayFromListByte -> False
OpByteArrayLength -> False
OpEq -> False
OpIntAdd -> False
OpIntDiv -> False
OpIntLe -> False
OpIntLt -> False
OpIntMod -> False
OpIntMul -> False
OpIntSub -> False
OpFieldToInt -> False
OpShow -> False
_ -> False
containsDebugOps :: Node -> Bool
containsDebugOps = ufold (\x xs -> x || or xs) isDebugOp

freeVarsSortedMany :: [Node] -> Set Var
freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars)
Expand Down
25 changes: 20 additions & 5 deletions src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,26 @@ makeLenses ''FreeVarsInfo
-- | Computes free variable info for each subnode. Assumption: no subnode is a
-- closure.
computeFreeVarsInfo :: Node -> Node
computeFreeVarsInfo = umap go
computeFreeVarsInfo = computeFreeVarsInfo' 1

-- | `lambdaMultiplier` specifies how much to multiply the free variable count
-- for variables under lambdas
computeFreeVarsInfo' :: Int -> Node -> Node
computeFreeVarsInfo' lambdaMultiplier = umap go
where
go :: Node -> Node
go node = case node of
NVar Var {..} ->
mkVar (Info.insert fvi _varInfo) _varIndex
where
fvi = FreeVarsInfo (Map.singleton _varIndex 1)
NLam Lambda {..} ->
modifyInfo (Info.insert fvi) node
where
fvi =
FreeVarsInfo
. fmap (* lambdaMultiplier)
$ getFreeVars 1 _lambdaBody
_ ->
modifyInfo (Info.insert fvi) node
where
Expand All @@ -35,14 +47,17 @@ computeFreeVarsInfo = umap go
foldr
( \NodeChild {..} acc ->
Map.unionWith (+) acc $
Map.mapKeysMonotonic (\idx -> idx - _childBindersNum) $
Map.filterWithKey
(\idx _ -> idx >= _childBindersNum)
(getFreeVarsInfo _childNode ^. infoFreeVars)
getFreeVars _childBindersNum _childNode
)
mempty
(children node)

getFreeVars :: Int -> Node -> Map Index Int
getFreeVars bindersNum node =
Map.mapKeysMonotonic (\idx -> idx - bindersNum)
. Map.filterWithKey (\idx _ -> idx >= bindersNum)
$ getFreeVarsInfo node ^. infoFreeVars

getFreeVarsInfo :: Node -> FreeVarsInfo
getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo

Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Transformation/MoveApps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,4 @@ convertNode = dmap go
-- - https://github.com/anoma/juvix/issues/1654
-- - https://github.com/anoma/juvix/pull/1659
moveApps :: Module -> Module
moveApps = mapT (const convertNode)
moveApps = mapAllNodes convertNode
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ convertNode inlineDepth nonRecSyms md = dmapL go
NIdt Ident {..} -> case pi of
Just InlineCase ->
NCase cs {_caseValue = mkApps def args}
Just InlineNever ->
node
Nothing
| HashSet.member _identSymbol nonRecSyms
&& isConstructorApp def
Expand Down
11 changes: 8 additions & 3 deletions src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
-- An optimizing transformation that folds lets whose values are immediate,
-- i.e., they don't require evaluation or memory allocation (variables or
-- constants), or when the bound variable occurs at most once in the body.
-- constants), or when the bound variable occurs at most once in the body but
-- not under a lambda-abstraction.
--
-- For example, transforms
-- ```
Expand All @@ -27,7 +28,7 @@ convertNode isFoldable md = rmapL go
|| Info.freeVarOccurrences 0 _letBody <= 1
|| isFoldable md bl (_letItem ^. letItemValue)
)
&& not (containsDebugOperations _letBody) ->
&& not (containsDebugOps _letBody) ->
go (recur . (mkBCRemove b val' :)) (BL.cons b bl) _letBody
where
val' = go recur bl (_letItem ^. letItemValue)
Expand All @@ -40,7 +41,11 @@ letFolding' isFoldable tab =
mapAllNodes
( removeInfo kFreeVarsInfo
. convertNode isFoldable tab
. computeFreeVarsInfo
. computeFreeVarsInfo' 2
-- 2 is the lambda multiplier factor which guarantees that every free
-- variable under a lambda is counted at least twice, preventing let
-- folding for let-bound variables (with non-immediate values) that
-- occur under lambdas
)
tab

Expand Down
16 changes: 13 additions & 3 deletions test/Compilation/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ tests =
$(mkRelDir ".")
$(mkRelFile "test075.juvix")
$(mkRelFile "out/test075.out"),
posTestEval
posTest
"Test076: Builtin Maybe"
$(mkRelDir ".")
$(mkRelFile "test076.juvix")
Expand All @@ -466,9 +466,19 @@ tests =
$(mkRelDir ".")
$(mkRelFile "test078.juvix")
$(mkRelFile "out/test078.out"),
posTestEval
posTest
"Test079: Let / LetRec type inference (during lambda lifting) in Core"
$(mkRelDir ".")
$(mkRelFile "test079.juvix")
$(mkRelFile "out/test079.out")
$(mkRelFile "out/test079.out"),
posTestEval -- TODO: this test is not compiling
"Test080: Do notation"
$(mkRelDir ".")
$(mkRelFile "test080.juvix")
$(mkRelFile "out/test080.out"),
posTest
"Test081: Non-duplication in let-folding"
$(mkRelDir ".")
$(mkRelFile "test081.juvix")
$(mkRelFile "out/test081.out")
]
2 changes: 2 additions & 0 deletions tests/Compilation/positive/out/test080.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
nothing
just 1
1 change: 1 addition & 0 deletions tests/Compilation/positive/out/test081.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
6 changes: 3 additions & 3 deletions tests/Compilation/positive/test059.juvix
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
-- builtin list
module test059;

import Stdlib.Prelude open hiding {head};
import Stdlib.Prelude open;

mylist : List Nat := [1; 2; 3 + 1];

mylist2 : List (List Nat) := [[10]; [2]; 3 + 1 :: nil];

head : {a : Type} -> a -> List a -> a
head' : {a : Type} -> a -> List a -> a
| a [] := a
| a [x; _] := x
| _ (h :: _) := h;

main : Nat := head 50 mylist + head 50 (head [] mylist2);
main : Nat := head' 50 mylist + head' 50 (head' [] mylist2);
18 changes: 18 additions & 0 deletions tests/Compilation/positive/test081.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- Non-duplication in let-folding
module test081;

import Stdlib.Prelude open;

{-# inline: false #-}
g (h : Nat -> Nat) : Nat := h 0 * h 0;

terminating
f (n : Nat) : Nat :=
if
| n == 0 := 0
| else :=
let terminating x := f (sub n 1)
in
g \{_ := x};

main : Nat := f 10000;

0 comments on commit b609e1f

Please sign in to comment.