Skip to content

Commit

Permalink
Merge branch 'aya-rewrites'
Browse files Browse the repository at this point in the history
  • Loading branch information
ymherklotz committed Nov 6, 2024
2 parents c31a7d1 + e29a124 commit 9a96e24
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 55 deletions.
2 changes: 1 addition & 1 deletion DataflowRewriter/ExprHigh.lean
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class FreshIdent (Ident : Type _) where
next : Nat → Ident

instance : FreshIdent String where
next n := "mod" ++ toString n
next n := "type" ++ toString n

instance : FreshIdent Nat where
next := id
Expand Down
42 changes: 21 additions & 21 deletions DataflowRewriter/ExprHighElaborator.lean
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def findStxStr (n : Name) (stx : Array Syntax) : MetaM (Option String) := do
for pair in stx do
if pair[0].getId = n then
let some out' := pair[2][0].isStrLit?
| throwErrorAt pair[2][0] "`mod` attribute is not a string"
| throwErrorAt pair[2][0] "`type` attribute is not a string"
out := some out'
return out

Expand Down Expand Up @@ -113,9 +113,9 @@ def dotGraphElab : TermElab := λ stx _typ? => do
match low_stmnt with
| `(dot_stmnt| $i:ident $[[$[$el:dot_attr],*]]? ) =>
let some el := el
| throwErrorAt i "No `mod` attribute found at node"
let some modId ← findStxStr `mod el
| throwErrorAt i "No `mod` attribute found at node"
| throwErrorAt i "No `type` attribute found at node"
let some modId ← findStxStr `type el
| throwErrorAt i "No `type` attribute found at node"
let mut modCluster : Bool := findStxBool `cluster el |>.getD false
match updateNodeMaps ⟨instMap, instTypeMap⟩ i.getId.toString modId modCluster with
| .ok ⟨a, b⟩ =>
Expand All @@ -127,7 +127,7 @@ def dotGraphElab : TermElab := λ stx _typ? => do
-- Error checking to report it early if the instance is not present in the
-- hashmap.
let some el := el
| throwErrorAt (mkListNode #[a, b]) "No `mod` attribute found at node"
| throwErrorAt (mkListNode #[a, b]) "No `type` attribute found at node"
let mut out ← (findStxStr `out el)
let mut inp ← (findStxStr `inp el)
match updateConnMaps ⟨instMap, instTypeMap⟩ conns a.getId.toString b.getId.toString out inp with
Expand Down Expand Up @@ -155,27 +155,27 @@ def dotGraphElab : TermElab := λ stx _typ? => do

-- namespace mergemod

def mergeHigh : ExprHigh String :=
[graph|
src0 [mod="src"];
snk0 [mod="snk"];
-- def mergeHigh : ExprHigh String :=
-- [graph|
-- src0 [mod="src"];
-- snk0 [mod="snk"];

fork1 [mod="fork"];
fork2 [mod="fork"];
merge1 [mod="merge"];
merge2 [mod="merge"];
-- fork1 [mod="fork"];
-- fork2 [mod="fork"];
-- merge1 [mod="merge"];
-- merge2 [mod="merge"];

src0 -> fork1 [out="0",inp="0"];
-- src0 -> fork1 [out="0",inp="0"];

fork1 -> fork2 [out="0",inp="0"];
-- fork1 -> fork2 [out="0",inp="0"];

fork1 -> merge1 [out="1",inp="0"];
fork2 -> merge1 [out="0",inp="1"];
fork2 -> merge2 [out="1",inp="0"];
-- fork1 -> merge1 [out="1",inp="0"];
-- fork2 -> merge1 [out="0",inp="1"];
-- fork2 -> merge2 [out="1",inp="0"];

merge1 -> merge2 [out="0",inp="1"];
-- merge1 -> merge2 [out="0",inp="1"];

merge2 -> snk0 [out="0",inp="0"];
]
-- merge2 -> snk0 [out="0",inp="0"];
-- ]

end DataflowRewriter
115 changes: 115 additions & 0 deletions DataflowRewriter/Rewrites/ForkRewrite.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/-
Copyright (c) 2024 VCA Lab, EPFL. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Yann Herklotz
-/

import DataflowRewriter.Rewriter
import DataflowRewriter.ExprHighElaborator

namespace DataflowRewriter.ForkRewrite

/--
The matcher takes in a dot graph and should return the cluster of nodes that
form the subgraph as a list of instance names.
-/
def matcher (g : ExprHigh String) : RewriteResult (List String) := do
let (.some list) ← g.modules.foldlM (λ nodes inst (pmap, typ) => do
if nodes.isSome then return nodes
unless typ = "Fork" do return none
let (.some nn) := followOutput g inst "out1" | return none
unless nn.typ = "Fork" && nn.inputPort = "inp0" do return none
return some [inst, nn.inst]
) none | throw .done
return list

@[drunfold] def Lhs : ExprHigh String := [graph|
out0 [type = "io"];
out1 [type = "io"];
out2 [type = "io"];
inp0 [type = "io"];

fork1 [type = "Fork"];
fork2 [type = "Fork"];

inp0 -> fork1 [inp = "inp0"];

fork1 -> out0 [out = "out0"];
fork1 -> fork2 [out = "out1", inp = "inp0"];

fork2 -> out1 [out = "out0"];
fork2 -> out2 [out = "out1"];
]

/--
To get instances in a predictable order, it's a good idea to extract the whole
graph once with the nodes in the order that you want to provide them in the
pattern-matcher. In this case we want fork1 to be listed before fork2.
-/
def LhsOrdered := Lhs.extract ["fork1", "fork2"] |>.get rfl

/--
Graph extraction gives back two graphs, the subgraph and the rest of the graph.
Here we just double check that the rest of the graph is empty, implying we
extracted the whole graph. The proof of `rfl` should always work for this.
-/
theorem double_check_empty_snd : LhsOrdered.snd = ExprHigh.mk ∅ ∅ := by rfl

/--
We then use the extracted graph to lower to ExprLow, which ensures the right
ordering of instances.
-/
def LhsLower := LhsOrdered.fst.lower.get rfl

@[drunfold] def Rhs : ExprHigh String := [graph|
out0 [type = "io"];
out1 [type = "io"];
out2 [type = "io"];
inp0 [type = "io"];

fork3 [type = "Fork3"];

inp0 -> fork3 [inp = "inp0"];

fork3 -> out0 [out = "out0"];
fork3 -> out1 [out = "out1"];
fork3 -> out2 [out = "out2"];
]

def RhsLower := Rhs.lower.get rfl

def rewrite : Rewrite String :=
{ pattern := matcher,
input_expr := LhsLower,
output_expr := RhsLower }

namespace TestRewriter

def fullCircuit : ExprHigh String :=
[graph|
src0 [type="io"];
snk0 [type="io"];

fork1 [type="Fork"];
fork2 [type="Fork"];
merge2 [type="merge"];
merge1 [type="merge"];

src0 -> fork1 [inp="inp0"];

fork1 -> fork2 [out="out1",inp="inp0"];

fork1 -> merge1 [out="out0",inp="inp0"];
fork2 -> merge1 [out="out0",inp="inp1"];
fork2 -> merge2 [out="out1",inp="inp1"];

merge1 -> merge2 [out="out0",inp="inp0"];

merge2 -> snk0 [out="out0"];
]

#eval DataflowRewriter.rewrite "rw0_" fullCircuit rewrite |> toString |> IO.print

end TestRewriter

end DataflowRewriter.ForkRewrite
38 changes: 19 additions & 19 deletions DataflowRewriter/Rewrites/MergeRewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@ form the subgraph as a list of instance names.
def matcher (g : ExprHigh String) : RewriteResult (List String) := do
let (.some list) ← g.modules.foldlM (λ nodes inst (pmap, typ) => do
if nodes.isSome then return nodes
unless typ = "merge" do return none
unless typ = "Merge" do return none
let (.some nn) := followOutput g inst "out0" | return none
unless nn.typ = "merge" && nn.inputPort = "inp0" do return none
unless nn.typ = "Merge" && nn.inputPort = "inp0" do return none
return some [inst, nn.inst]
) none | throw .done
return list

@[drunfold] def mergeLhs : ExprHigh String := [graph|
out0 [mod = "io"];
inp0 [mod = "io"];
inp1 [mod = "io"];
inp2 [mod = "io"];
out0 [type = "io"];
inp0 [type = "io"];
inp1 [type = "io"];
inp2 [type = "io"];

merge1 [mod = "merge"];
merge2 [mod = "merge"];
merge1 [type = "Merge"];
merge2 [type = "Merge"];

inp0 -> merge1 [inp = "inp0"];
inp1 -> merge1 [inp = "inp1"];
Expand Down Expand Up @@ -64,12 +64,12 @@ ordering of instances.
def mergeLhsLower := mergeLhsOrdered.fst.lower.get rfl

@[drunfold] def mergeRhs : ExprHigh String := [graph|
out0 [mod = "io"];
inp0 [mod = "io"];
inp1 [mod = "io"];
inp2 [mod = "io"];
out0 [type = "io"];
inp0 [type = "io"];
inp1 [type = "io"];
inp2 [type = "io"];

merge3 [mod = "merge3"];
merge3 [type = "Merge3"];

inp0 -> merge3 [inp = "inp0"];
inp1 -> merge3 [inp = "inp1"];
Expand All @@ -91,13 +91,13 @@ namespace TestRewriter

def mergeHigh : ExprHigh String :=
[graph|
src0 [mod="io"];
snk0 [mod="io"];
src0 [type="io"];
snk0 [type="io"];

fork1 [mod="fork"];
fork2 [mod="fork"];
merge2 [mod="merge"];
merge1 [mod="merge"];
fork1 [type="Fork"];
fork2 [type="Fork"];
merge2 [type="Merge"];
merge1 [type="Merge"];

src0 -> fork1 [inp="inp0"];

Expand Down
3 changes: 2 additions & 1 deletion Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import DataflowRewriter.DotParser
import DataflowRewriter.Rewriter
import DataflowRewriter.Rewrites.MergeRewrite
import DataflowRewriter.DynamaticPrinter
import DataflowRewriter.Rewrites.ForkRewrite

open Batteries (AssocList)

Expand Down Expand Up @@ -62,7 +63,7 @@ def main (args : List String) : IO Unit := do
let fileContents ← IO.FS.readFile parsed.inputFile.get!
let (exprHigh, assoc) ← IO.ofExcept fileContents.toExprHigh
let rewrittenExprHigh ← IO.ofExcept <|
rewrite_loop "rw" exprHigh [MergeRewrite.rewrite] 100
rewrite_loop "rw" exprHigh [MergeRewrite.rewrite, ForkRewrite.rewrite] 100
let some l := dynamaticString rewrittenExprHigh assoc.inverse
| IO.eprintln s!"Failed to print ExprHigh: {rewrittenExprHigh}"
match parsed.outputFile with
Expand Down
48 changes: 48 additions & 0 deletions tests/dynamatic-no-control-flow.dot
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
Digraph G {
splines=spline;
//DHLS version: 0.1.1" [shape = "none" pos = "20,20!"]

subgraph cluster_0 {
color = "darkgreen";
label = "block1";
"arg" [type = "Entry", bbID= 1, in = "in1:32", out = "out1:32", tagged=false, taggers_num=0, tagger_id=-1];
"mul_0" [type = "Operator", bbID= 1, op = "mul_op", in = "in1:32 in2:32 ", out = "out1:32 ", delay=0.000, latency=4, II=1, tagged=false, taggers_num=0, tagger_id=-1];
"cst_0" [type = "Constant", bbID= 1, in = "in1:32", out = "out1:32", value = "0x00000002", tagged=false, taggers_num=0, tagger_id=-1];
"add_1" [type = "Operator", bbID= 1, op = "add_op", in = "in1:32 in2:32 ", out = "out1:32 ", delay=1.693, latency=0, II=1, tagged=false, taggers_num=0, tagger_id=-1];
"cst_1" [type = "Constant", bbID= 1, in = "in1:32", out = "out1:32", value = "0x00000002", tagged=false, taggers_num=0, tagger_id=-1];
"shl_2" [type = "Operator", bbID= 1, op = "shl_op", in = "in1:32 in2:32 ", out = "out1:32 ", delay=0.000, latency=0, II=1, tagged=false, taggers_num=0, tagger_id=-1];
"add_3" [type = "Operator", bbID= 1, op = "add_op", in = "in1:32 in2:32 ", out = "out1:32 ", delay=1.693, latency=0, II=1, tagged=false, taggers_num=0, tagger_id=-1];
"cst_2" [type = "Constant", bbID= 1, in = "in1:32", out = "out1:32", value = "0x0000000A", tagged=false, taggers_num=0, tagger_id=-1];
"mul_4" [type = "Operator", bbID= 1, op = "mul_op", in = "in1:32 in2:32 ", out = "out1:32 ", delay=0.000, latency=4, II=1, tagged=false, taggers_num=0, tagger_id=-1];
"mul_5" [type = "Operator", bbID= 1, op = "mul_op", in = "in1:32 in2:32 ", out = "out1:32 ", delay=0.000, latency=4, II=1, tagged=false, taggers_num=0, tagger_id=-1];
"ret_0" [type = "Operator", bbID= 1, op = "ret_op", in = "in1:32 ", out = "out1:32 ", delay=0.000, latency=0, II=1, tagged=false, taggers_num=0, tagger_id=-1];
"start_0" [type = "Entry", control= "true", bbID= 1, in = "in1:0", out = "out1:0", tagged=false, taggers_num=0, tagger_id=-1];
"fork_0" [type = "Fork", bbID= 1, in = "in1:32", out = "out1:32 out2:32 out3:32 ", tagged=false, taggers_num=0, tagger_id=-1];
"fork_1" [type = "Fork", bbID= 1, in = "in1:32", out = "out1:32 out2:32 ", tagged=false, taggers_num=0, tagger_id=-1];
"forkC_1" [type = "Fork", bbID= 1, in = "in1:0", out = "out1:0 out2:0 out3:0 ", tagged=false, taggers_num=0, tagger_id=-1];

}
"end_0" [type = "Exit", bbID= 0, in = " in1:32 ", out = "out1:32" ];

"arg" -> "fork_0" [color = "red", from = "out1", to = "in1"];
"mul_0" -> "add_1" [color = "red", from = "out1", to = "in1"];
"cst_0" -> "add_1" [color = "red", from = "out1", to = "in2"];
"add_1" -> "add_3" [color = "red", from = "out1", to = "in1"];
"cst_1" -> "shl_2" [color = "red", from = "out1", to = "in2"];
"shl_2" -> "add_3" [color = "red", from = "out1", to = "in2"];
"add_3" -> "mul_4" [color = "red", from = "out1", to = "in1"];
"cst_2" -> "mul_4" [color = "red", from = "out1", to = "in2"];
"mul_4" -> "mul_5" [color = "red", from = "out1", to = "in1"];
"mul_5" -> "ret_0" [color = "red", from = "out1", to = "in1"];
"ret_0" -> "end_0" [color = "red", from = "out1", to = "in1"];
"start_0" -> "forkC_1" [color = "gold3", from = "out1", to = "in1"];
"fork_0" -> "mul_0" [color = "red", from = "out1", to = "in1"];
"fork_0" -> "mul_0" [color = "red", from = "out2", to = "in2"];
"fork_0" -> "fork_1" [color = "red", from = "out3", to = "in1"];
"fork_1" -> "shl_2" [color = "red", from = "out1", to = "in1"];
"fork_1" -> "mul_5" [color = "red", from = "out2", to = "in2"];
"forkC_1" -> "cst_0" [color = "gold3", from = "out1", to = "in1"];
"forkC_1" -> "cst_1" [color = "gold3", from = "out2", to = "in1"];
"forkC_1" -> "cst_2" [color = "gold3", from = "out3", to = "in1"];

}
26 changes: 13 additions & 13 deletions tests/fork_merge.dot
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
digraph {
src0 [mod="io"];
snk0 [mod="io"];
"src0" [type="io"];
"snk0" [type="io"];

fork1 [mod="fork"];
fork2 [mod="fork"];
merge2 [mod="merge"];
merge1 [mod="merge"];
"fork1" [type="Fork"];
"fork2" [type="Fork"];
"merge2" [type="Merge"];
"merge1" [type="Merge"];

src0 -> fork1 [inp="inp0"];
"src0" -> "fork1" [inp="inp0"];

fork1 -> fork2 [out="out1",inp="inp0"];
"fork1" -> "fork2" [out="out1",inp="inp0"];

fork1 -> merge1 [out="out0",inp="inp0"];
fork2 -> merge1 [out="out0",inp="inp1"];
fork2 -> merge2 [out="out1",inp="inp1"];
"fork1" -> "merge1" [out="out0",inp="inp0"];
"fork2" -> "merge1" [out="out0",inp="inp1"];
"fork2" -> "merge2" [out="out1",inp="inp1"];

merge1 -> merge2 [out="out0",inp="inp0"];
"merge1" -> "merge2" [out="out0",inp="inp0"];

merge2 -> snk0 [out="out0"];
"merge2" -> "snk0" [out="out0"];
}

0 comments on commit 9a96e24

Please sign in to comment.