From 235d88f3037abbe12b980d9994c6dc742e7b590c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Czajka?= <62751+lukaszcz@users.noreply.github.com> Date: Tue, 18 Jun 2024 21:38:02 +0200 Subject: [PATCH] Copy propagation in JuvixReg (#2828) * Closes #1614 * Implements the copy propagation transformation in JuvixReg and adds tests for it. * For this optimization to give any improvement, we need to run dead code elimination afterwards (#2827). --- .github/workflows/ci.yml | 2 +- .../Compiler/Reg/Data/TransformationId.hs | 10 +- .../Reg/Data/TransformationId/Strings.hs | 3 + src/Juvix/Compiler/Reg/Extra/Base.hs | 10 ++ src/Juvix/Compiler/Reg/Transformation.hs | 2 + .../Reg/Transformation/CopyPropagation.hs | 56 +++++++++++ src/Juvix/Compiler/Reg/Transformation/SSA.hs | 11 +-- test/Reg/Parse/Positive.hs | 15 ++- test/Reg/Transformation.hs | 4 +- test/Reg/Transformation/CopyPropagation.hs | 21 ++++ test/Reg/Transformation/InitBranchVars.hs | 2 +- test/Reg/Transformation/SSA.hs | 2 +- tests/Reg/positive/out/test039.out | 1 + tests/Reg/positive/out/test040.out | 1 + tests/Reg/positive/test039.jvr | 84 ++++++++++++++++ tests/Reg/positive/test040.jvr | 95 +++++++++++++++++++ 16 files changed, 302 insertions(+), 17 deletions(-) create mode 100644 src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs create mode 100644 test/Reg/Transformation/CopyPropagation.hs create mode 100644 tests/Reg/positive/out/test039.out create mode 100644 tests/Reg/positive/out/test040.out create mode 100644 tests/Reg/positive/test039.jvr create mode 100644 tests/Reg/positive/test040.jvr diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 079255940d..15c23d8816 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -144,7 +144,7 @@ jobs: - name: Install RISC0 VM shell: bash run: | - cargo install cargo-binstall --force + cargo install cargo-binstall@1.6.9 --force cargo binstall cargo-risczero@1.0.1 --no-confirm --force cargo risczero install diff --git a/src/Juvix/Compiler/Reg/Data/TransformationId.hs b/src/Juvix/Compiler/Reg/Data/TransformationId.hs index f93bde484a..6ec735f736 100644 --- a/src/Juvix/Compiler/Reg/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Reg/Data/TransformationId.hs @@ -9,6 +9,7 @@ data TransformationId | Cleanup | SSA | InitBranchVars + | CopyPropagation deriving stock (Data, Bounded, Enum, Show) data PipelineId @@ -19,14 +20,16 @@ data PipelineId type TransformationLikeId = TransformationLikeId' TransformationId PipelineId +-- Note: this works only because for now we mark all variables as live. Liveness +-- information needs to be re-computed after copy propagation. toCTransformations :: [TransformationId] -toCTransformations = [Cleanup] +toCTransformations = [Cleanup, CopyPropagation] toRustTransformations :: [TransformationId] -toRustTransformations = [Cleanup] +toRustTransformations = [Cleanup, CopyPropagation] toCasmTransformations :: [TransformationId] -toCasmTransformations = [Cleanup, SSA] +toCasmTransformations = [Cleanup, CopyPropagation, SSA] instance TransformationId' TransformationId where transformationText :: TransformationId -> Text @@ -35,6 +38,7 @@ instance TransformationId' TransformationId where Cleanup -> strCleanup SSA -> strSSA InitBranchVars -> strInitBranchVars + CopyPropagation -> strCopyPropagation instance PipelineId' TransformationId PipelineId where pipelineText :: PipelineId -> Text diff --git a/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs index a38e22620a..c300c5ba1a 100644 --- a/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs @@ -22,3 +22,6 @@ strSSA = "ssa" strInitBranchVars :: Text strInitBranchVars = "init-branch-vars" + +strCopyPropagation :: Text +strCopyPropagation = "copy-propagation" diff --git a/src/Juvix/Compiler/Reg/Extra/Base.hs b/src/Juvix/Compiler/Reg/Extra/Base.hs index ac036508f7..f42c5d3c54 100644 --- a/src/Juvix/Compiler/Reg/Extra/Base.hs +++ b/src/Juvix/Compiler/Reg/Extra/Base.hs @@ -148,3 +148,13 @@ overValueRefs f = \case goBlock :: InstrBlock -> InstrBlock goBlock x = x + +updateLiveVars' :: (VarRef -> Maybe VarRef) -> Instruction -> Instruction +updateLiveVars' f = \case + Prealloc x -> Prealloc $ over instrPreallocLiveVars (mapMaybe f) x + Call x -> Call $ over instrCallLiveVars (mapMaybe f) x + CallClosures x -> CallClosures $ over instrCallClosuresLiveVars (mapMaybe f) x + instr -> instr + +updateLiveVars :: (VarRef -> VarRef) -> Instruction -> Instruction +updateLiveVars f = updateLiveVars' (Just . f) diff --git a/src/Juvix/Compiler/Reg/Transformation.hs b/src/Juvix/Compiler/Reg/Transformation.hs index 315e634bb0..76ee9cb20a 100644 --- a/src/Juvix/Compiler/Reg/Transformation.hs +++ b/src/Juvix/Compiler/Reg/Transformation.hs @@ -8,6 +8,7 @@ where import Juvix.Compiler.Reg.Data.TransformationId import Juvix.Compiler.Reg.Transformation.Base import Juvix.Compiler.Reg.Transformation.Cleanup +import Juvix.Compiler.Reg.Transformation.CopyPropagation import Juvix.Compiler.Reg.Transformation.IdentityTrans import Juvix.Compiler.Reg.Transformation.InitBranchVars import Juvix.Compiler.Reg.Transformation.SSA @@ -21,3 +22,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts Cleanup -> return . cleanup SSA -> return . computeSSA InitBranchVars -> return . initBranchVars + CopyPropagation -> return . copyPropagate diff --git a/src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs b/src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs new file mode 100644 index 0000000000..a296535a2b --- /dev/null +++ b/src/Juvix/Compiler/Reg/Transformation/CopyPropagation.hs @@ -0,0 +1,56 @@ +module Juvix.Compiler.Reg.Transformation.CopyPropagation where + +import Data.HashMap.Strict qualified as HashMap +import Data.HashSet qualified as HashSet +import Juvix.Compiler.Reg.Extra +import Juvix.Compiler.Reg.Transformation.Base + +type VarMap = HashMap VarRef VarRef + +copyPropagateFunction :: Code -> Code +copyPropagateFunction = + snd + . runIdentity + . recurseF + ForwardRecursorSig + { _forwardFun = \i acc -> return (go i acc), + _forwardCombine = combine + } + mempty + where + go :: Instruction -> VarMap -> (VarMap, Instruction) + go instr mpv = case instr' of + Assign InstrAssign {..} + | VRef v <- _instrAssignValue -> + (HashMap.insert _instrAssignResult v mpv', instr') + _ -> + (mpv', instr') + where + instr' = overValueRefs (adjustVarRef mpv) instr + mpv' = maybe mpv (filterOutVars mpv) (getResultVar instr) + + filterOutVars :: VarMap -> VarRef -> VarMap + filterOutVars mpv v = HashMap.delete v $ HashMap.filter (/= v) mpv + + adjustVarRef :: VarMap -> VarRef -> VarRef + adjustVarRef mpv vref@VarRef {..} = case _varRefGroup of + VarGroupArgs -> vref + VarGroupLocal -> fromMaybe vref $ HashMap.lookup vref mpv + + combine :: Instruction -> NonEmpty VarMap -> (VarMap, Instruction) + combine instr mpvs = (mpv, instr') + where + mpv' :| mpvs' = fmap HashMap.toList mpvs + mpv = + HashMap.fromList + . HashSet.toList + . foldr (HashSet.intersection . HashSet.fromList) (HashSet.fromList mpv') + $ mpvs' + + instr' = case instr of + Branch x -> Branch $ over instrBranchOutVar (fmap (adjustVarRef mpv)) x + Case x -> Case $ over instrCaseOutVar (fmap (adjustVarRef mpv)) x + _ -> impossible + +copyPropagate :: InfoTable -> InfoTable +copyPropagate = mapT (const copyPropagateFunction) diff --git a/src/Juvix/Compiler/Reg/Transformation/SSA.hs b/src/Juvix/Compiler/Reg/Transformation/SSA.hs index 4b6604958f..fce911dd86 100644 --- a/src/Juvix/Compiler/Reg/Transformation/SSA.hs +++ b/src/Juvix/Compiler/Reg/Transformation/SSA.hs @@ -22,20 +22,13 @@ computeFunctionSSA = where go :: Instruction -> IndexMap VarRef -> (IndexMap VarRef, Instruction) go instr mp = case getResultVar instr' of - Just vref -> (mp', updateLiveVars mp' (setResultVar instr' (mkVarRef VarGroupLocal idx))) + Just vref -> (mp', updateLiveVars' (adjustVarRef' mp') (setResultVar instr' (mkVarRef VarGroupLocal idx))) where (idx, mp') = IndexMap.assign mp vref - Nothing -> (mp, updateLiveVars mp instr') + Nothing -> (mp, updateLiveVars' (adjustVarRef' mp) instr') where instr' = overValueRefs (adjustVarRef mp) instr - updateLiveVars :: IndexMap VarRef -> Instruction -> Instruction - updateLiveVars mp = \case - Prealloc x -> Prealloc $ over instrPreallocLiveVars (mapMaybe (adjustVarRef' mp)) x - Call x -> Call $ over instrCallLiveVars (mapMaybe (adjustVarRef' mp)) x - CallClosures x -> CallClosures $ over instrCallClosuresLiveVars (mapMaybe (adjustVarRef' mp)) x - instr -> instr - -- For branches, when necessary we insert assignments unifying the renamed -- output variables into a single output variable for both branches. combine :: Instruction -> NonEmpty (IndexMap VarRef) -> (IndexMap VarRef, Instruction) diff --git a/test/Reg/Parse/Positive.hs b/test/Reg/Parse/Positive.hs index 444d99c7c5..bb2811deaf 100644 --- a/test/Reg/Parse/Positive.hs +++ b/test/Reg/Parse/Positive.hs @@ -26,6 +26,9 @@ testDescr PosTest {..} = filterTests :: [String] -> [PosTest] -> [PosTest] filterTests incl = filter (\PosTest {..} -> _name `elem` incl) +filterOutTests :: [String] -> [PosTest] -> [PosTest] +filterOutTests excl = filter (\PosTest {..} -> _name `notElem` excl) + allTests :: TestTree allTests = testGroup @@ -223,5 +226,15 @@ tests = "Test038: Apply & argsnum" $(mkRelDir ".") $(mkRelFile "test038.jvr") - $(mkRelFile "out/test038.out") + $(mkRelFile "out/test038.out"), + PosTest + "Test039: Copy & constant propagation" + $(mkRelDir ".") + $(mkRelFile "test039.jvr") + $(mkRelFile "out/test039.out"), + PosTest + "Test040: Copy & constant propagation with branches" + $(mkRelDir ".") + $(mkRelFile "test040.jvr") + $(mkRelFile "out/test040.out") ] diff --git a/test/Reg/Transformation.hs b/test/Reg/Transformation.hs index 7590f08080..12f3077b96 100644 --- a/test/Reg/Transformation.hs +++ b/test/Reg/Transformation.hs @@ -1,6 +1,7 @@ module Reg.Transformation where import Base +import Reg.Transformation.CopyPropagation qualified as CopyPropagation import Reg.Transformation.IdentityTrans qualified as IdentityTrans import Reg.Transformation.InitBranchVars qualified as InitBranchVars import Reg.Transformation.SSA qualified as SSA @@ -11,5 +12,6 @@ allTests = "JuvixReg transformations" [ IdentityTrans.allTests, SSA.allTests, - InitBranchVars.allTests + InitBranchVars.allTests, + CopyPropagation.allTests ] diff --git a/test/Reg/Transformation/CopyPropagation.hs b/test/Reg/Transformation/CopyPropagation.hs new file mode 100644 index 0000000000..a65bb3f05e --- /dev/null +++ b/test/Reg/Transformation/CopyPropagation.hs @@ -0,0 +1,21 @@ +module Reg.Transformation.CopyPropagation where + +import Base +import Juvix.Compiler.Reg.Transformation +import Reg.Parse.Positive qualified as Parse +import Reg.Transformation.Base + +allTests :: TestTree +allTests = testGroup "Copy Propagation" (map liftTest Parse.tests) + +pipe :: [TransformationId] +pipe = [CopyPropagation] + +liftTest :: Parse.PosTest -> TestTree +liftTest _testRun = + fromTest + Test + { _testTransformations = pipe, + _testAssertion = const (return ()), + _testRun + } diff --git a/test/Reg/Transformation/InitBranchVars.hs b/test/Reg/Transformation/InitBranchVars.hs index 5a56347c15..34089ba4cd 100644 --- a/test/Reg/Transformation/InitBranchVars.hs +++ b/test/Reg/Transformation/InitBranchVars.hs @@ -8,7 +8,7 @@ import Reg.Parse.Positive qualified as Parse import Reg.Transformation.Base allTests :: TestTree -allTests = testGroup "InitBranchVars" (map liftTest Parse.tests) +allTests = testGroup "InitBranchVars" (map liftTest $ Parse.filterOutTests ["Test039: Copy & constant propagation"] Parse.tests) pipe :: [TransformationId] pipe = [SSA, InitBranchVars] diff --git a/test/Reg/Transformation/SSA.hs b/test/Reg/Transformation/SSA.hs index b78d8aae03..f0c066c163 100644 --- a/test/Reg/Transformation/SSA.hs +++ b/test/Reg/Transformation/SSA.hs @@ -7,7 +7,7 @@ import Reg.Parse.Positive qualified as Parse import Reg.Transformation.Base allTests :: TestTree -allTests = testGroup "SSA" (map liftTest Parse.tests) +allTests = testGroup "SSA" (map liftTest $ Parse.filterOutTests ["Test039: Copy & constant propagation"] Parse.tests) pipe :: [TransformationId] pipe = [SSA] diff --git a/tests/Reg/positive/out/test039.out b/tests/Reg/positive/out/test039.out new file mode 100644 index 0000000000..dde92ddc1a --- /dev/null +++ b/tests/Reg/positive/out/test039.out @@ -0,0 +1 @@ +82 diff --git a/tests/Reg/positive/out/test040.out b/tests/Reg/positive/out/test040.out new file mode 100644 index 0000000000..85322d0b54 --- /dev/null +++ b/tests/Reg/positive/out/test040.out @@ -0,0 +1 @@ +79 diff --git a/tests/Reg/positive/test039.jvr b/tests/Reg/positive/test039.jvr new file mode 100644 index 0000000000..d952203eb2 --- /dev/null +++ b/tests/Reg/positive/test039.jvr @@ -0,0 +1,84 @@ +-- Copy & constant propagation + +type either { + left : integer -> either; + right : bool -> either; +} + +function main() : * { + tmp[0] = 7; + tmp[1] = tmp[0]; + tmp[0] = tmp[1]; + tmp[2] = tmp[0]; + -- tmp[2] = 7 + + tmp[1] = tmp[0]; + tmp[0] = add tmp[1] 1; + tmp[2] = add tmp[2] tmp[1]; + -- tmp[2] = 14 + + tmp[0] = 19; + tmp[1] = tmp[0]; + tmp[0] = add tmp[1] 1; + tmp[3] = add tmp[0] tmp[1]; + tmp[4] = tmp[3]; + tmp[2] = add tmp[4] tmp[2]; + -- tmp[2] = 53 + + tmp[1] = eq tmp[2] 54; + tmp[0] = 4; + tmp[3] = 3; + tmp[4] = tmp[0]; + tmp[5] = 4; + tmp[6] = tmp[5]; + br tmp[1] { + true: { + tmp[4] = 7; + }; + false: { + tmp[3] = tmp[6]; + }; + }; + tmp[2] = add tmp[2] tmp[4]; + tmp[2] = add tmp[2] tmp[3]; + -- tmp[2] = 61 + + tmp[0] = alloc left (3); + tmp[1] = 17; + tmp[3] = tmp[1]; + case[either] tmp[0] { + left: { + tmp[4] = tmp[0].left[0]; + tmp[1] = tmp[4]; + tmp[3] = tmp[1]; + }; + right: { + nop; + }; + }; + tmp[2] = add tmp[2] tmp[3]; + -- tmp[2] = 64 + + tmp[0] = alloc right (true); + tmp[1] = 17; + tmp[3] = tmp[1]; + case[either] tmp[0] { + left: { + tmp[1] = tmp[0].left[0]; + }; + right: { + br tmp[0].right[0] { + true: { + tmp[1] = add tmp[3] 1; + }; + false: { + nop; + }; + }; + }; + }; + tmp[2] = add tmp[2] tmp[1]; + -- tmp[2] = 82 + + ret tmp[2]; +} diff --git a/tests/Reg/positive/test040.jvr b/tests/Reg/positive/test040.jvr new file mode 100644 index 0000000000..0b17343945 --- /dev/null +++ b/tests/Reg/positive/test040.jvr @@ -0,0 +1,95 @@ +-- Copy & constant propagation with branches + +type either { + left : integer -> either; + right : bool -> either; +} + +function main() : * { + tmp[0] = 7; + tmp[1] = tmp[0]; + tmp[0] = tmp[1]; + tmp[2] = tmp[0]; + -- tmp[2] = 7 + + tmp[1] = tmp[0]; + tmp[0] = add tmp[1] 1; + tmp[2] = add tmp[2] tmp[1]; + -- tmp[2] = 14 + + tmp[0] = 19; + tmp[1] = tmp[0]; + tmp[0] = add tmp[1] 1; + tmp[3] = add tmp[0] tmp[1]; + tmp[4] = tmp[3]; + tmp[2] = add tmp[4] tmp[2]; + -- tmp[2] = 53 + + tmp[1] = eq tmp[2] 54; + tmp[3] = 3; + tmp[5] = 4; + tmp[6] = tmp[5]; + br tmp[1], out: tmp[3] { + true: { + nop; + }; + false: { + tmp[3] = tmp[6]; + }; + }; + tmp[2] = add tmp[2] tmp[3]; + -- tmp[2] = 57 + + tmp[0] = 1; + br tmp[1], out: tmp[3] { + true: { + tmp[0] = 4; + tmp[3] = tmp[0]; + }; + false: { + tmp[3] = tmp[0]; + }; + }; + tmp[2] = add tmp[2] tmp[3]; + -- tmp[2] = 58 + + tmp[0] = alloc left (3); + tmp[1] = 17; + tmp[3] = tmp[1]; + case[either] tmp[0], out: tmp[3] { + left: { + tmp[4] = tmp[0].left[0]; + tmp[1] = tmp[4]; + tmp[3] = tmp[1]; + }; + right: { + nop; + }; + }; + tmp[2] = add tmp[2] tmp[3]; + -- tmp[2] = 61 + + tmp[0] = alloc right (true); + tmp[1] = 17; + tmp[3] = tmp[1]; + case[either] tmp[0], out: tmp[3] { + left: { + nop; + }; + right: { + br tmp[0].right[0], out: tmp[3] { + true: { + tmp[1] = add tmp[3] 1; + tmp[3] = tmp[1]; + }; + false: { + nop; + }; + }; + }; + }; + tmp[2] = add tmp[2] tmp[3]; + -- tmp[2] = 79 + + ret tmp[2]; +}