From 7c19a044a06fdf83e787660add06697ec0c391ad Mon Sep 17 00:00:00 2001 From: Yann Herklotz Date: Fri, 8 Nov 2024 14:48:27 +0100 Subject: [PATCH] Fix findSCCNodes' as well as BranchMuxToMerge example --- DataflowRewriter/Rewriter.lean | 5 ++- .../Rewrites/BranchMuxToMerge.lean | 45 ++++++++++--------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/DataflowRewriter/Rewriter.lean b/DataflowRewriter/Rewriter.lean index 00d5c13..fcd7183 100644 --- a/DataflowRewriter/Rewriter.lean +++ b/DataflowRewriter/Rewriter.lean @@ -346,6 +346,7 @@ def findSCCNodes' (succ : Std.HashMap String (Array String)) (startN endN : Stri else let nextNodes ← succ[x]?.map (·.toList) if "_leaf_" ∈ nextNodes then none + if startN ∈ nextNodes then none let nextNodes' := nextNodes.filter (· ∉ visited') go w visited' (nextNodes' ++ q) @@ -354,6 +355,8 @@ Find all nodes in between two nodes by performing a DFS that checks that one has never reached an output node. -/ def findSCCNodes (g : ExprHigh String) (startN endN : String) : Option (List String) := do - findSCCNodes' (← fullCalcSucc g) startN endN + let l ← findSCCNodes' (← fullCalcSucc g) startN endN + let l' ← findSCCNodes' (← fullCalcSucc g.invert) startN endN + return l.union l' end DataflowRewriter diff --git a/DataflowRewriter/Rewrites/BranchMuxToMerge.lean b/DataflowRewriter/Rewrites/BranchMuxToMerge.lean index 79859a5..cb0f8db 100644 --- a/DataflowRewriter/Rewrites/BranchMuxToMerge.lean +++ b/DataflowRewriter/Rewrites/BranchMuxToMerge.lean @@ -60,10 +60,10 @@ def matchModLeft (g : ExprHigh String) : RewriteResult (List String) := do unless branch_nn.typ = "branch" && branch_nn.inputPort = "cond" do return none let (.some mux_nn) := followOutput g inst "out1" | return none unless mux_nn.typ = "mux" && mux_nn.inputPort = "cond" do return none - let (.some prev_mux_nn) := followInput g inst "inp0" | return none - let (.some after_branch_nn) := followOutput g inst "true" | return none + let (.some prev_mux_nn) := followInput g mux_nn.inst "inp0" | return none + let (.some after_branch_nn) := followOutput g branch_nn.inst "true" | return none let (.some scc) := findSCCNodes g after_branch_nn.inst prev_mux_nn.inst | return none - return some (inst :: scc) + return some scc ) none | throw .done return list @@ -78,10 +78,10 @@ def matchModRight (g : ExprHigh String) : RewriteResult (List String) := do unless branch_nn.typ = "branch" && branch_nn.inputPort = "cond" do return none let (.some mux_nn) := followOutput g inst "out1" | return none unless mux_nn.typ = "mux" && mux_nn.inputPort = "cond" do return none - let (.some prev_mux_nn) := followInput g inst "inp1" | return none - let (.some after_branch_nn) := followOutput g inst "false" | return none + let (.some prev_mux_nn) := followInput g mux_nn.inst "inp1" | return none + let (.some after_branch_nn) := followOutput g branch_nn.inst "false" | return none let (.some scc) := findSCCNodes g after_branch_nn.inst prev_mux_nn.inst | return none - return some (inst :: scc) + return some scc ) none | throw .done return list @@ -103,15 +103,15 @@ def matcher (g : ExprHigh String) : RewriteResult (List String) := do return list def lhs' : ExprHigh String := [graph| - i_branch [mod = "io"]; - i_cond [mod = "io"]; - o_out [mod = "io"]; + i_branch [type = "io"]; + i_cond [type = "io"]; + o_out [type = "io"]; - branch [mod = "branch"]; - m_left [mod = "mod_left"]; - m_right [mod = "mod_right"]; - mux [mod = "mux"]; - fork [mod = "fork"]; + branch [type = "branch"]; + m_left [type = "mod_left"]; + m_right [type = "mod_right"]; + mux [type = "mux"]; + fork [type = "fork"]; i_branch -> branch [inp = "val"]; i_cond -> fork [inp = "inp0"]; @@ -127,6 +127,7 @@ def lhs' : ExprHigh String := [graph| m_right -> mux [out = "m_out", inp = "inp1"]; ] +#eval matchModRight lhs' #eval IO.print lhs' def lhs := lhs'.extract ["fork", "m_left", "mux", "m_right", "branch"] |>.get rfl @@ -136,14 +137,14 @@ theorem double_check_empty_snd : lhs.snd = ExprHigh.mk ∅ ∅ := by rfl def lhsLower := lhs.fst.lower.get rfl def rhs : ExprHigh String := [graph| - i_branch [mod = "io"]; - i_cond [mod = "io"]; - o_out [mod = "io"]; - - branch [mod = "branch"]; - m_left [mod = "mod_left"]; - m_right [mod = "mod_right"]; - merge [mod = "merge"]; + i_branch [type = "io"]; + i_cond [type = "io"]; + o_out [type = "io"]; + + branch [type = "branch"]; + m_left [type = "mod_left"]; + m_right [type = "mod_right"]; + merge [type = "merge"]; i_branch -> branch [inp = "val"]; i_cond -> branch [inp = "cond"];