Skip to content

Commit

Permalink
Fix findSCCNodes' as well as BranchMuxToMerge example
Browse files Browse the repository at this point in the history
  • Loading branch information
ymherklotz committed Nov 8, 2024
1 parent 98cc9a3 commit 7c19a04
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 23 deletions.
5 changes: 4 additions & 1 deletion DataflowRewriter/Rewriter.lean
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ def findSCCNodes' (succ : Std.HashMap String (Array String)) (startN endN : Stri
else
let nextNodes ← succ[x]?.map (·.toList)
if "_leaf_" ∈ nextNodes then none
if startN ∈ nextNodes then none
let nextNodes' := nextNodes.filter (· ∉ visited')
go w visited' (nextNodes' ++ q)

Expand All @@ -354,6 +355,8 @@ Find all nodes in between two nodes by performing a DFS that checks that one has
never reached an output node.
-/
def findSCCNodes (g : ExprHigh String) (startN endN : String) : Option (List String) := do
findSCCNodes' (← fullCalcSucc g) startN endN
let l ← findSCCNodes' (← fullCalcSucc g) startN endN
let l' ← findSCCNodes' (← fullCalcSucc g.invert) startN endN
return l.union l'

end DataflowRewriter
45 changes: 23 additions & 22 deletions DataflowRewriter/Rewrites/BranchMuxToMerge.lean
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ def matchModLeft (g : ExprHigh String) : RewriteResult (List String) := do
unless branch_nn.typ = "branch" && branch_nn.inputPort = "cond" do return none
let (.some mux_nn) := followOutput g inst "out1" | return none
unless mux_nn.typ = "mux" && mux_nn.inputPort = "cond" do return none
let (.some prev_mux_nn) := followInput g inst "inp0" | return none
let (.some after_branch_nn) := followOutput g inst "true" | return none
let (.some prev_mux_nn) := followInput g mux_nn.inst "inp0" | return none
let (.some after_branch_nn) := followOutput g branch_nn.inst "true" | return none
let (.some scc) := findSCCNodes g after_branch_nn.inst prev_mux_nn.inst | return none
return some (inst :: scc)
return some scc
) none | throw .done
return list

Expand All @@ -78,10 +78,10 @@ def matchModRight (g : ExprHigh String) : RewriteResult (List String) := do
unless branch_nn.typ = "branch" && branch_nn.inputPort = "cond" do return none
let (.some mux_nn) := followOutput g inst "out1" | return none
unless mux_nn.typ = "mux" && mux_nn.inputPort = "cond" do return none
let (.some prev_mux_nn) := followInput g inst "inp1" | return none
let (.some after_branch_nn) := followOutput g inst "false" | return none
let (.some prev_mux_nn) := followInput g mux_nn.inst "inp1" | return none
let (.some after_branch_nn) := followOutput g branch_nn.inst "false" | return none
let (.some scc) := findSCCNodes g after_branch_nn.inst prev_mux_nn.inst | return none
return some (inst :: scc)
return some scc
) none | throw .done
return list

Expand All @@ -103,15 +103,15 @@ def matcher (g : ExprHigh String) : RewriteResult (List String) := do
return list

def lhs' : ExprHigh String := [graph|
i_branch [mod = "io"];
i_cond [mod = "io"];
o_out [mod = "io"];
i_branch [type = "io"];
i_cond [type = "io"];
o_out [type = "io"];

branch [mod = "branch"];
m_left [mod = "mod_left"];
m_right [mod = "mod_right"];
mux [mod = "mux"];
fork [mod = "fork"];
branch [type = "branch"];
m_left [type = "mod_left"];
m_right [type = "mod_right"];
mux [type = "mux"];
fork [type = "fork"];

i_branch -> branch [inp = "val"];
i_cond -> fork [inp = "inp0"];
Expand All @@ -127,6 +127,7 @@ def lhs' : ExprHigh String := [graph|
m_right -> mux [out = "m_out", inp = "inp1"];
]

#eval matchModRight lhs'
#eval IO.print lhs'

def lhs := lhs'.extract ["fork", "m_left", "mux", "m_right", "branch"] |>.get rfl
Expand All @@ -136,14 +137,14 @@ theorem double_check_empty_snd : lhs.snd = ExprHigh.mk ∅ ∅ := by rfl
def lhsLower := lhs.fst.lower.get rfl

def rhs : ExprHigh String := [graph|
i_branch [mod = "io"];
i_cond [mod = "io"];
o_out [mod = "io"];

branch [mod = "branch"];
m_left [mod = "mod_left"];
m_right [mod = "mod_right"];
merge [mod = "merge"];
i_branch [type = "io"];
i_cond [type = "io"];
o_out [type = "io"];

branch [type = "branch"];
m_left [type = "mod_left"];
m_right [type = "mod_right"];
merge [type = "merge"];

i_branch -> branch [inp = "val"];
i_cond -> branch [inp = "cond"];
Expand Down

0 comments on commit 7c19a04

Please sign in to comment.