Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't fold lets if the let-bound variable occurs under a lambda-abstraction #3029

Merged
merged 8 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 45 additions & 45 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -177,52 +177,52 @@ isFalseConstr = \case
NCtr Constr {..} | _constrTag == BuiltinTag TagFalse -> True
_ -> False

isDebugOp :: Node -> Bool
isDebugOp = \case
NBlt BuiltinApp {..} ->
case _builtinAppOp of
OpTrace -> True
OpFail -> True
OpSeq -> True
OpAssert -> False
OpAnomaByteArrayFromAnomaContents -> False
OpAnomaByteArrayToAnomaContents -> False
OpAnomaDecode -> False
OpAnomaEncode -> False
OpAnomaGet -> False
OpAnomaSign -> False
OpAnomaSignDetached -> False
OpAnomaVerifyDetached -> False
OpAnomaVerifyWithMessage -> False
OpEc -> False
OpFieldAdd -> False
OpFieldDiv -> False
OpFieldFromInt -> False
OpFieldMul -> False
OpFieldSub -> False
OpPoseidonHash -> False
OpRandomEcPoint -> False
OpStrConcat -> False
OpStrToInt -> False
OpUInt8FromInt -> False
OpUInt8ToInt -> False
OpByteArrayFromListByte -> False
OpByteArrayLength -> False
OpEq -> False
OpIntAdd -> False
OpIntDiv -> False
OpIntLe -> False
OpIntLt -> False
OpIntMod -> False
OpIntMul -> False
OpIntSub -> False
OpFieldToInt -> False
OpShow -> False
_ -> False

-- | Check if the node contains `trace`, `fail` or `seq` (`>->`).
containsDebugOperations :: Node -> Bool
containsDebugOperations = ufold (\x xs -> x || or xs) isDebugOp
where
isDebugOp :: Node -> Bool
isDebugOp = \case
NBlt BuiltinApp {..} ->
case _builtinAppOp of
OpTrace -> True
OpFail -> True
OpSeq -> True
OpAssert -> False
OpAnomaByteArrayFromAnomaContents -> False
OpAnomaByteArrayToAnomaContents -> False
OpAnomaDecode -> False
OpAnomaEncode -> False
OpAnomaGet -> False
OpAnomaSign -> False
OpAnomaSignDetached -> False
OpAnomaVerifyDetached -> False
OpAnomaVerifyWithMessage -> False
OpEc -> False
OpFieldAdd -> False
OpFieldDiv -> False
OpFieldFromInt -> False
OpFieldMul -> False
OpFieldSub -> False
OpPoseidonHash -> False
OpRandomEcPoint -> False
OpStrConcat -> False
OpStrToInt -> False
OpUInt8FromInt -> False
OpUInt8ToInt -> False
OpByteArrayFromListByte -> False
OpByteArrayLength -> False
OpEq -> False
OpIntAdd -> False
OpIntDiv -> False
OpIntLe -> False
OpIntLt -> False
OpIntMod -> False
OpIntMul -> False
OpIntSub -> False
OpFieldToInt -> False
OpShow -> False
_ -> False
containsDebugOps :: Node -> Bool
containsDebugOps = ufold (\x xs -> x || or xs) isDebugOp

freeVarsSortedMany :: [Node] -> Set Var
freeVarsSortedMany n = Set.fromList (n ^.. each . freeVars)
Expand Down
25 changes: 20 additions & 5 deletions src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,26 @@ makeLenses ''FreeVarsInfo
-- | Computes free variable info for each subnode. Assumption: no subnode is a
-- closure.
computeFreeVarsInfo :: Node -> Node
computeFreeVarsInfo = umap go
computeFreeVarsInfo = computeFreeVarsInfo' 1

-- | `lambdaMultiplier` specifies how much to multiply the free variable count
-- for variables under lambdas
computeFreeVarsInfo' :: Int -> Node -> Node
computeFreeVarsInfo' lambdaMultiplier = umap go
where
go :: Node -> Node
go node = case node of
NVar Var {..} ->
mkVar (Info.insert fvi _varInfo) _varIndex
where
fvi = FreeVarsInfo (Map.singleton _varIndex 1)
NLam Lambda {..} ->
modifyInfo (Info.insert fvi) node
where
fvi =
FreeVarsInfo
. fmap (* lambdaMultiplier)
$ getFreeVars 1 _lambdaBody
_ ->
modifyInfo (Info.insert fvi) node
where
Expand All @@ -35,14 +47,17 @@ computeFreeVarsInfo = umap go
foldr
( \NodeChild {..} acc ->
Map.unionWith (+) acc $
Map.mapKeysMonotonic (\idx -> idx - _childBindersNum) $
Map.filterWithKey
(\idx _ -> idx >= _childBindersNum)
(getFreeVarsInfo _childNode ^. infoFreeVars)
getFreeVars _childBindersNum _childNode
)
mempty
(children node)

getFreeVars :: Int -> Node -> Map Index Int
getFreeVars bindersNum node =
Map.mapKeysMonotonic (\idx -> idx - bindersNum)
. Map.filterWithKey (\idx _ -> idx >= bindersNum)
$ getFreeVarsInfo node ^. infoFreeVars

getFreeVarsInfo :: Node -> FreeVarsInfo
getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo

Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Transformation/MoveApps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,4 @@ convertNode = dmap go
-- - https://github.com/anoma/juvix/issues/1654
-- - https://github.com/anoma/juvix/pull/1659
moveApps :: Module -> Module
moveApps = mapT (const convertNode)
moveApps = mapAllNodes convertNode
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ convertNode inlineDepth nonRecSyms md = dmapL go
NIdt Ident {..} -> case pi of
Just InlineCase ->
NCase cs {_caseValue = mkApps def args}
Just InlineNever ->
node
Nothing
| HashSet.member _identSymbol nonRecSyms
&& isConstructorApp def
Expand Down
11 changes: 8 additions & 3 deletions src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
-- An optimizing transformation that folds lets whose values are immediate,
-- i.e., they don't require evaluation or memory allocation (variables or
-- constants), or when the bound variable occurs at most once in the body.
-- constants), or when the bound variable occurs at most once in the body but
-- not under a lambda-abstraction.
--
-- For example, transforms
-- ```
Expand All @@ -27,7 +28,7 @@ convertNode isFoldable md = rmapL go
|| Info.freeVarOccurrences 0 _letBody <= 1
|| isFoldable md bl (_letItem ^. letItemValue)
)
&& not (containsDebugOperations _letBody) ->
&& not (containsDebugOps _letBody) ->
go (recur . (mkBCRemove b val' :)) (BL.cons b bl) _letBody
where
val' = go recur bl (_letItem ^. letItemValue)
Expand All @@ -40,7 +41,11 @@ letFolding' isFoldable tab =
mapAllNodes
( removeInfo kFreeVarsInfo
. convertNode isFoldable tab
. computeFreeVarsInfo
. computeFreeVarsInfo' 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'd add a comment here explaining how this 2 prevents folding lets that bind variables in lambdas

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

-- 2 is the lambda multiplier factor which guarantees that every free
-- variable under a lambda is counted at least twice, preventing let
-- folding for let-bound variables (with non-immediate values) that
-- occur under lambdas
)
tab

Expand Down
16 changes: 13 additions & 3 deletions test/Compilation/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ tests =
$(mkRelDir ".")
$(mkRelFile "test075.juvix")
$(mkRelFile "out/test075.out"),
posTestEval
posTest
"Test076: Builtin Maybe"
$(mkRelDir ".")
$(mkRelFile "test076.juvix")
Expand All @@ -466,9 +466,19 @@ tests =
$(mkRelDir ".")
$(mkRelFile "test078.juvix")
$(mkRelFile "out/test078.out"),
posTestEval
posTest
"Test079: Let / LetRec type inference (during lambda lifting) in Core"
$(mkRelDir ".")
$(mkRelFile "test079.juvix")
$(mkRelFile "out/test079.out")
$(mkRelFile "out/test079.out"),
posTestEval -- TODO: this test is not compiling
"Test080: Do notation"
$(mkRelDir ".")
$(mkRelFile "test080.juvix")
$(mkRelFile "out/test080.out"),
posTest
"Test081: Non-duplication in let-folding"
$(mkRelDir ".")
$(mkRelFile "test081.juvix")
$(mkRelFile "out/test081.out")
]
2 changes: 2 additions & 0 deletions tests/Compilation/positive/out/test080.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
nothing
just 1
1 change: 1 addition & 0 deletions tests/Compilation/positive/out/test081.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
6 changes: 3 additions & 3 deletions tests/Compilation/positive/test059.juvix
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
-- builtin list
module test059;

import Stdlib.Prelude open hiding {head};
import Stdlib.Prelude open;

mylist : List Nat := [1; 2; 3 + 1];

mylist2 : List (List Nat) := [[10]; [2]; 3 + 1 :: nil];

head : {a : Type} -> a -> List a -> a
head' : {a : Type} -> a -> List a -> a
| a [] := a
| a [x; _] := x
| _ (h :: _) := h;

main : Nat := head 50 mylist + head 50 (head [] mylist2);
main : Nat := head' 50 mylist + head' 50 (head' [] mylist2);
18 changes: 18 additions & 0 deletions tests/Compilation/positive/test081.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- Non-duplication in let-folding
module test081;

import Stdlib.Prelude open;

{-# inline: false #-}
g (h : Nat -> Nat) : Nat := h 0 * h 0;

terminating
f (n : Nat) : Nat :=
if
| n == 0 := 0
| else :=
let terminating x := f (sub n 1)
in
g \{_ := x};

main : Nat := f 10000;
Loading