Skip to content

Commit

Permalink
Copy propagation in JuvixReg (#2828)
Browse files Browse the repository at this point in the history
* Closes #1614 
* Implements the copy propagation transformation in JuvixReg and adds
tests for it.
* For this optimization to give any improvement, we need to run dead
code elimination afterwards (#2827).
  • Loading branch information
lukaszcz authored Jun 18, 2024
1 parent 7a7c8ce commit 235d88f
Show file tree
Hide file tree
Showing 16 changed files with 302 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ jobs:
- name: Install RISC0 VM
shell: bash
run: |
cargo install cargo-binstall --force
cargo install cargo-binstall@1.6.9 --force
cargo binstall cargo-risczero@1.0.1 --no-confirm --force
cargo risczero install
Expand Down
10 changes: 7 additions & 3 deletions src/Juvix/Compiler/Reg/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ data TransformationId
| Cleanup
| SSA
| InitBranchVars
| CopyPropagation
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -19,14 +20,16 @@ data PipelineId

type TransformationLikeId = TransformationLikeId' TransformationId PipelineId

-- Note: this works only because for now we mark all variables as live. Liveness
-- information needs to be re-computed after copy propagation.
toCTransformations :: [TransformationId]
toCTransformations = [Cleanup]
toCTransformations = [Cleanup, CopyPropagation]

toRustTransformations :: [TransformationId]
toRustTransformations = [Cleanup]
toRustTransformations = [Cleanup, CopyPropagation]

toCasmTransformations :: [TransformationId]
toCasmTransformations = [Cleanup, SSA]
toCasmTransformations = [Cleanup, CopyPropagation, SSA]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
Expand All @@ -35,6 +38,7 @@ instance TransformationId' TransformationId where
Cleanup -> strCleanup
SSA -> strSSA
InitBranchVars -> strInitBranchVars
CopyPropagation -> strCopyPropagation

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ strSSA = "ssa"

strInitBranchVars :: Text
strInitBranchVars = "init-branch-vars"

strCopyPropagation :: Text
strCopyPropagation = "copy-propagation"
10 changes: 10 additions & 0 deletions src/Juvix/Compiler/Reg/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,13 @@ overValueRefs f = \case

goBlock :: InstrBlock -> InstrBlock
goBlock x = x

updateLiveVars' :: (VarRef -> Maybe VarRef) -> Instruction -> Instruction
updateLiveVars' f = \case
Prealloc x -> Prealloc $ over instrPreallocLiveVars (mapMaybe f) x
Call x -> Call $ over instrCallLiveVars (mapMaybe f) x
CallClosures x -> CallClosures $ over instrCallClosuresLiveVars (mapMaybe f) x
instr -> instr

updateLiveVars :: (VarRef -> VarRef) -> Instruction -> Instruction
updateLiveVars f = updateLiveVars' (Just . f)
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Reg/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ where
import Juvix.Compiler.Reg.Data.TransformationId
import Juvix.Compiler.Reg.Transformation.Base
import Juvix.Compiler.Reg.Transformation.Cleanup
import Juvix.Compiler.Reg.Transformation.CopyPropagation
import Juvix.Compiler.Reg.Transformation.IdentityTrans
import Juvix.Compiler.Reg.Transformation.InitBranchVars
import Juvix.Compiler.Reg.Transformation.SSA
Expand All @@ -21,3 +22,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
Cleanup -> return . cleanup
SSA -> return . computeSSA
InitBranchVars -> return . initBranchVars
CopyPropagation -> return . copyPropagate
56 changes: 56 additions & 0 deletions src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
module Juvix.Compiler.Reg.Transformation.CopyPropagation where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Reg.Extra
import Juvix.Compiler.Reg.Transformation.Base

type VarMap = HashMap VarRef VarRef

copyPropagateFunction :: Code -> Code
copyPropagateFunction =
snd
. runIdentity
. recurseF
ForwardRecursorSig
{ _forwardFun = \i acc -> return (go i acc),
_forwardCombine = combine
}
mempty
where
go :: Instruction -> VarMap -> (VarMap, Instruction)
go instr mpv = case instr' of
Assign InstrAssign {..}
| VRef v <- _instrAssignValue ->
(HashMap.insert _instrAssignResult v mpv', instr')
_ ->
(mpv', instr')
where
instr' = overValueRefs (adjustVarRef mpv) instr
mpv' = maybe mpv (filterOutVars mpv) (getResultVar instr)

filterOutVars :: VarMap -> VarRef -> VarMap
filterOutVars mpv v = HashMap.delete v $ HashMap.filter (/= v) mpv

adjustVarRef :: VarMap -> VarRef -> VarRef
adjustVarRef mpv vref@VarRef {..} = case _varRefGroup of
VarGroupArgs -> vref
VarGroupLocal -> fromMaybe vref $ HashMap.lookup vref mpv

combine :: Instruction -> NonEmpty VarMap -> (VarMap, Instruction)
combine instr mpvs = (mpv, instr')
where
mpv' :| mpvs' = fmap HashMap.toList mpvs
mpv =
HashMap.fromList
. HashSet.toList
. foldr (HashSet.intersection . HashSet.fromList) (HashSet.fromList mpv')
$ mpvs'

instr' = case instr of
Branch x -> Branch $ over instrBranchOutVar (fmap (adjustVarRef mpv)) x
Case x -> Case $ over instrCaseOutVar (fmap (adjustVarRef mpv)) x
_ -> impossible

copyPropagate :: InfoTable -> InfoTable
copyPropagate = mapT (const copyPropagateFunction)
11 changes: 2 additions & 9 deletions src/Juvix/Compiler/Reg/Transformation/SSA.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,13 @@ computeFunctionSSA =
where
go :: Instruction -> IndexMap VarRef -> (IndexMap VarRef, Instruction)
go instr mp = case getResultVar instr' of
Just vref -> (mp', updateLiveVars mp' (setResultVar instr' (mkVarRef VarGroupLocal idx)))
Just vref -> (mp', updateLiveVars' (adjustVarRef' mp') (setResultVar instr' (mkVarRef VarGroupLocal idx)))
where
(idx, mp') = IndexMap.assign mp vref
Nothing -> (mp, updateLiveVars mp instr')
Nothing -> (mp, updateLiveVars' (adjustVarRef' mp) instr')
where
instr' = overValueRefs (adjustVarRef mp) instr

updateLiveVars :: IndexMap VarRef -> Instruction -> Instruction
updateLiveVars mp = \case
Prealloc x -> Prealloc $ over instrPreallocLiveVars (mapMaybe (adjustVarRef' mp)) x
Call x -> Call $ over instrCallLiveVars (mapMaybe (adjustVarRef' mp)) x
CallClosures x -> CallClosures $ over instrCallClosuresLiveVars (mapMaybe (adjustVarRef' mp)) x
instr -> instr

-- For branches, when necessary we insert assignments unifying the renamed
-- output variables into a single output variable for both branches.
combine :: Instruction -> NonEmpty (IndexMap VarRef) -> (IndexMap VarRef, Instruction)
Expand Down
15 changes: 14 additions & 1 deletion test/Reg/Parse/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ testDescr PosTest {..} =
filterTests :: [String] -> [PosTest] -> [PosTest]
filterTests incl = filter (\PosTest {..} -> _name `elem` incl)

filterOutTests :: [String] -> [PosTest] -> [PosTest]
filterOutTests excl = filter (\PosTest {..} -> _name `notElem` excl)

allTests :: TestTree
allTests =
testGroup
Expand Down Expand Up @@ -223,5 +226,15 @@ tests =
"Test038: Apply & argsnum"
$(mkRelDir ".")
$(mkRelFile "test038.jvr")
$(mkRelFile "out/test038.out")
$(mkRelFile "out/test038.out"),
PosTest
"Test039: Copy & constant propagation"
$(mkRelDir ".")
$(mkRelFile "test039.jvr")
$(mkRelFile "out/test039.out"),
PosTest
"Test040: Copy & constant propagation with branches"
$(mkRelDir ".")
$(mkRelFile "test040.jvr")
$(mkRelFile "out/test040.out")
]
4 changes: 3 additions & 1 deletion test/Reg/Transformation.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Reg.Transformation where

import Base
import Reg.Transformation.CopyPropagation qualified as CopyPropagation
import Reg.Transformation.IdentityTrans qualified as IdentityTrans
import Reg.Transformation.InitBranchVars qualified as InitBranchVars
import Reg.Transformation.SSA qualified as SSA
Expand All @@ -11,5 +12,6 @@ allTests =
"JuvixReg transformations"
[ IdentityTrans.allTests,
SSA.allTests,
InitBranchVars.allTests
InitBranchVars.allTests,
CopyPropagation.allTests
]
21 changes: 21 additions & 0 deletions test/Reg/Transformation/CopyPropagation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module Reg.Transformation.CopyPropagation where

import Base
import Juvix.Compiler.Reg.Transformation
import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base

allTests :: TestTree
allTests = testGroup "Copy Propagation" (map liftTest Parse.tests)

pipe :: [TransformationId]
pipe = [CopyPropagation]

liftTest :: Parse.PosTest -> TestTree
liftTest _testRun =
fromTest
Test
{ _testTransformations = pipe,
_testAssertion = const (return ()),
_testRun
}
2 changes: 1 addition & 1 deletion test/Reg/Transformation/InitBranchVars.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base

allTests :: TestTree
allTests = testGroup "InitBranchVars" (map liftTest Parse.tests)
allTests = testGroup "InitBranchVars" (map liftTest $ Parse.filterOutTests ["Test039: Copy & constant propagation"] Parse.tests)

pipe :: [TransformationId]
pipe = [SSA, InitBranchVars]
Expand Down
2 changes: 1 addition & 1 deletion test/Reg/Transformation/SSA.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base

allTests :: TestTree
allTests = testGroup "SSA" (map liftTest Parse.tests)
allTests = testGroup "SSA" (map liftTest $ Parse.filterOutTests ["Test039: Copy & constant propagation"] Parse.tests)

pipe :: [TransformationId]
pipe = [SSA]
Expand Down
1 change: 1 addition & 0 deletions tests/Reg/positive/out/test039.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
82
1 change: 1 addition & 0 deletions tests/Reg/positive/out/test040.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
79
84 changes: 84 additions & 0 deletions tests/Reg/positive/test039.jvr
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
-- Copy & constant propagation

type either {
left : integer -> either;
right : bool -> either;
}

function main() : * {
tmp[0] = 7;
tmp[1] = tmp[0];
tmp[0] = tmp[1];
tmp[2] = tmp[0];
-- tmp[2] = 7

tmp[1] = tmp[0];
tmp[0] = add tmp[1] 1;
tmp[2] = add tmp[2] tmp[1];
-- tmp[2] = 14

tmp[0] = 19;
tmp[1] = tmp[0];
tmp[0] = add tmp[1] 1;
tmp[3] = add tmp[0] tmp[1];
tmp[4] = tmp[3];
tmp[2] = add tmp[4] tmp[2];
-- tmp[2] = 53

tmp[1] = eq tmp[2] 54;
tmp[0] = 4;
tmp[3] = 3;
tmp[4] = tmp[0];
tmp[5] = 4;
tmp[6] = tmp[5];
br tmp[1] {
true: {
tmp[4] = 7;
};
false: {
tmp[3] = tmp[6];
};
};
tmp[2] = add tmp[2] tmp[4];
tmp[2] = add tmp[2] tmp[3];
-- tmp[2] = 61

tmp[0] = alloc left (3);
tmp[1] = 17;
tmp[3] = tmp[1];
case[either] tmp[0] {
left: {
tmp[4] = tmp[0].left[0];
tmp[1] = tmp[4];
tmp[3] = tmp[1];
};
right: {
nop;
};
};
tmp[2] = add tmp[2] tmp[3];
-- tmp[2] = 64

tmp[0] = alloc right (true);
tmp[1] = 17;
tmp[3] = tmp[1];
case[either] tmp[0] {
left: {
tmp[1] = tmp[0].left[0];
};
right: {
br tmp[0].right[0] {
true: {
tmp[1] = add tmp[3] 1;
};
false: {
nop;
};
};
};
};
tmp[2] = add tmp[2] tmp[1];
-- tmp[2] = 82

ret tmp[2];
}
Loading

0 comments on commit 235d88f

Please sign in to comment.