Skip to content

Commit

Permalink
Filter out unreachable functions in JuvixAsm (#2575)
Browse files Browse the repository at this point in the history
Adds a JuvixAsm transformation to filter out unreachable functions. This
will make the generated nock/cairo code smaller.
  • Loading branch information
lukaszcz authored Jan 12, 2024
1 parent a9995b8 commit 3269c8f
Show file tree
Hide file tree
Showing 12 changed files with 247 additions and 6 deletions.
53 changes: 53 additions & 0 deletions src/Juvix/Compiler/Asm/Data/CallGraph.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module Juvix.Compiler.Asm.Data.CallGraph where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Asm.Data.InfoTable
import Juvix.Compiler.Asm.Extra
import Juvix.Compiler.Asm.Language

-- | Call graph type
type CallGraph = DependencyInfo Symbol

-- | Compute the call graph
createCallGraph :: (Member (Error AsmError) r) => InfoTable -> Sem r CallGraph
createCallGraph tab = do
graph <- createCallGraphMap tab
return $ createDependencyInfo graph startVertices
where
startVertices :: HashSet Symbol
startVertices = HashSet.fromList syms

syms :: [Symbol]
syms = maybe [] singleton (tab ^. infoMainFunction)

createCallGraphMap :: (Member (Error AsmError) r) => InfoTable -> Sem r (HashMap Symbol (HashSet Symbol))
createCallGraphMap tab =
mapM
(\FunctionInfo {..} -> getFunSymbols tab _functionCode)
(tab ^. infoFunctions)

getFunSymbols :: (Member (Error AsmError) r) => InfoTable -> Code -> Sem r (HashSet Symbol)
getFunSymbols tab code = foldS sig code mempty
where
sig :: FoldSig StackInfo r (HashSet Symbol)
sig =
FoldSig
{ _foldInfoTable = tab,
_foldAdjust = const mempty,
_foldInstr = \_ CmdInstr {..} acc -> return $ goInstr acc _cmdInstrInstruction,
_foldBranch = \_ _ a1 a2 a3 -> return $ a1 <> a2 <> a3,
_foldCase = \_ _ as ma a -> return $ mconcat as <> fromMaybe mempty ma <> a,
_foldSave = \_ _ a1 a2 -> return $ a1 <> a2
}

goInstr :: HashSet Symbol -> Instruction -> HashSet Symbol
goInstr syms = \case
AllocClosure InstrAllocClosure {..} -> HashSet.insert _allocClosureFunSymbol syms
Call InstrCall {..} -> goCallType syms _callType
TailCall InstrCall {..} -> goCallType syms _callType
_ -> syms

goCallType :: HashSet Symbol -> CallType -> HashSet Symbol
goCallType syms = \case
CallFun sym -> HashSet.insert sym syms
CallClosure -> syms
13 changes: 12 additions & 1 deletion src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,18 @@ recurseS' sig = go
dropTempStack :: StackInfo -> StackInfo
dropTempStack si = si {_stackInfoTempStackHeight = 0}

-- | Fold signature. Contains read-only fold parameters.
-- | Fold signature. Contains read-only fold parameters. A fold (`foldS`) goes
-- through the code from right to left (from end to beginning) accumulating
-- values. The `a` is the type of the accumulated values. The last argument to
-- the `_fold*` functions below is the accumulator. The `_foldAdjust` function
-- adjusts the accumulator when entering a block (in `CmdBranch`, `CmdCase`,
-- `CmdSave`). For example, for `save { P1 }; P2` let `a2` be the accumulator
-- value after folding `P2`. Then `P1` is folded with the initial accumulator
-- `_foldAdjust a2`. However, `_foldSave` is called with `_foldSave m c a1 a2`,
-- i.e., with the original `a2`, where `a1` is the result of folding `P1` with
-- initial accumulator `_foldAdjust a2`. In most simple cases, one can set
-- `_foldAdjust` to `const empty` where `empty` is the empty accumulator value
-- (e.g. `mempty` for a monoid).
data FoldSig m r a = FoldSig
{ _foldInfoTable :: InfoTable,
_foldAdjust :: a -> a,
Expand Down
8 changes: 5 additions & 3 deletions src/Juvix/Compiler/Asm/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import Juvix.Compiler.Core.Language.Base

-- In what follows, when referring to the stack we mean the current local value
-- stack, unless otherwise stated. By stack[n] we denote the n-th cell from the
-- top in the value stack (0-based).

-- * top* in the value stack (0-based).

-- | Offset of a data field or an argument
type Offset = Int
Expand Down Expand Up @@ -46,8 +47,9 @@ data DirectRef
| -- | ArgRef references an argument in the argument area (0-based offsets).
-- JVA code: 'arg[<offset>]'.
ArgRef OffsetRef
| -- | TempRef references a value in the temporary area (0-based offsets). JVA
-- code: 'tmp[<offset>]'.
| -- | TempRef references a value in the temporary stack (0-based offsets,
-- counted from the *bottom* of the temporary stack). JVA code:
-- 'tmp[<offset>]'.
TempRef OffsetRef

data OffsetRef = OffsetRef
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Asm/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import Juvix.Compiler.Pipeline.EntryPoint
-- | Perform transformations on JuvixAsm necessary before the translation to
-- JuvixReg
toReg' :: (Members '[Error AsmError, Reader Options] r) => InfoTable -> Sem r InfoTable
toReg' = validate >=> computeStackUsage >=> computePrealloc
toReg' = validate >=> filterUnreachable >=> computeStackUsage >=> computePrealloc

toReg :: (Members '[Error JuvixError, Reader EntryPoint] r) => InfoTable -> Sem r InfoTable
toReg = mapReader fromEntryPoint . mapError (JuvixError @AsmError) . toReg'
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Asm/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ module Juvix.Compiler.Asm.Transformation
module Juvix.Compiler.Asm.Transformation.Prealloc,
module Juvix.Compiler.Asm.Transformation.Validate,
module Juvix.Compiler.Asm.Transformation.Apply,
module Juvix.Compiler.Asm.Transformation.FilterUnreachable,
)
where

import Juvix.Compiler.Asm.Transformation.Apply
import Juvix.Compiler.Asm.Transformation.FilterUnreachable
import Juvix.Compiler.Asm.Transformation.Prealloc
import Juvix.Compiler.Asm.Transformation.StackUsage
import Juvix.Compiler.Asm.Transformation.Validate
12 changes: 12 additions & 0 deletions src/Juvix/Compiler/Asm/Transformation/FilterUnreachable.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module Juvix.Compiler.Asm.Transformation.FilterUnreachable where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Asm.Data.CallGraph
import Juvix.Compiler.Asm.Data.InfoTable
import Juvix.Compiler.Asm.Error
import Juvix.Compiler.Asm.Language

filterUnreachable :: (Member (Error AsmError) r) => InfoTable -> Sem r InfoTable
filterUnreachable tab = do
graph <- createCallGraph tab
return $ over infoFunctions (HashMap.filterWithKey (const . isReachable graph)) tab
3 changes: 2 additions & 1 deletion test/Asm/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module Asm.Transformation where

import Asm.Transformation.Apply qualified as Apply
import Asm.Transformation.Prealloc qualified as Prealloc
import Asm.Transformation.Reachability qualified as Reachability
import Base

allTests :: TestTree
allTests = testGroup "JuvixAsm transformations" [Prealloc.allTests, Apply.allTests]
allTests = testGroup "JuvixAsm transformations" [Prealloc.allTests, Apply.allTests, Reachability.allTests]
56 changes: 56 additions & 0 deletions test/Asm/Transformation/Reachability.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
module Asm.Transformation.Reachability (allTests) where

import Asm.Run.Positive qualified as Run
import Asm.Transformation.Base
import Base
import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Asm.Options
import Juvix.Compiler.Asm.Transformation
import Juvix.Compiler.Asm.Transformation.Base

data ReachabilityTest = ReachabilityTest
{ _reachabilityTestReachable :: [Text],
_reachabilityTestEval :: Run.PosTest
}

allTests :: TestTree
allTests =
testGroup "Reachability" $
map liftTest rtests

rtests :: [ReachabilityTest]
rtests =
[ ReachabilityTest
{ _reachabilityTestReachable = ["f", "f'", "g'", "h", "h'", "main"],
_reachabilityTestEval =
Run.PosTest
"Test001: Reachability"
$(mkRelDir "reachability")
$(mkRelFile "test001.jva")
$(mkRelFile "out/test001.out")
},
ReachabilityTest
{ _reachabilityTestReachable = ["f", "g", "id", "sum", "main"],
_reachabilityTestEval =
Run.PosTest
"Test002: Reachability with loops & closures"
$(mkRelDir "reachability")
$(mkRelFile "test002.jva")
$(mkRelFile "out/test002.out")
}
]

liftTest :: ReachabilityTest -> TestTree
liftTest ReachabilityTest {..} =
fromTest
Test
{ _testTransformation = runTransformation (runReader opts . filterUnreachable),
_testAssertion = \tab -> unless (nubSort (map (^. functionName) (HashMap.elems (tab ^. infoFunctions))) == nubSort _reachabilityTestReachable) (error "check reachable"),
_testEval = _reachabilityTestEval
}
where
opts =
Options
{ _optDebug = True,
_optLimits = getLimits TargetCWasm32Wasi True
}
1 change: 1 addition & 0 deletions tests/Asm/positive/reachability/out/test001.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9
1 change: 1 addition & 0 deletions tests/Asm/positive/reachability/out/test002.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5051
44 changes: 44 additions & 0 deletions tests/Asm/positive/reachability/test001.jva
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
function h(integer) : integer {
push arg[0];
ret;
}

function h'(integer) : integer {
push arg[0];
ret;
}

function f(integer) : integer {
push 1;
push arg[0];
call h;
add;
ret;
}

function f'(integer) : integer {
push 1;
push arg[0];
call h';
add;
ret;
}

function g(integer) : integer {
push 2;
push arg[0];
call f;
add;
ret;
}

function g'(integer) : integer {
push arg[0];
tcall f';
}

function main() : integer {
push 7;
call f;
tcall g';
}
58 changes: 58 additions & 0 deletions tests/Asm/positive/reachability/test002.jva
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
function f(*, integer) : integer {
push arg[1];
push arg[0];
tcall $ 1;
}

function id(integer) : integer {
push arg[0];
ret;
}

function g(integer) : integer {
push 1;
push arg[0];
calloc id 0;
call f;
add;
ret;
}

function sum(integer) : integer {
push arg[0];
push 0;
eq;
br {
true: {
push 0;
tcall g;
};
false: {
push 1;
push arg[0];
sub;
call sum;
push arg[0];
add;
ret;
};
};
}

function g'(integer) : integer {
push 2;
push arg[0];
call id;
add;
ret;
}

function g''(integer) : integer {
push arg[0];
tcall sum;
}

function main() : integer {
push 100;
tcall sum;
}

0 comments on commit 3269c8f

Please sign in to comment.