Skip to content

Commit

Permalink
Fix rewriting pattern matcher and try to run on example
Browse files Browse the repository at this point in the history
  • Loading branch information
ymherklotz committed Nov 8, 2024
1 parent 633bcf4 commit 8564903
Showing 1 changed file with 45 additions and 3 deletions.
48 changes: 45 additions & 3 deletions DataflowRewriter/Rewrites/BranchMuxToMerge.lean
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,15 @@ def matcher (g : ExprHigh String) : RewriteResult (List String) := do
unless branch_nn.typ = "branch" && branch_nn.inputPort = "cond" do return none
let (.some mux_nn) := followOutput g inst "out1" | return none
unless mux_nn.typ = "mux" && mux_nn.inputPort = "cond" do return none
let (.some scc) := findSCCNodes g branch_nn.inst mux_nn.inst | return none
return some (inst :: scc)
let (.some after_branch_nn0) := followOutput g branch_nn.inst "true" | return none
let (.some after_branch_nn1) := followOutput g branch_nn.inst "false" | return none
-- Now that we go in two directions, we need two calls to findSCCNodes,
-- otherwise it will follow the fork and reach an output.
--
-- However, now that we have already abstracted the modules, we don't need to search anymore.
-- let (.some scc0) := findSCCNodes g after_branch_nn0.inst prev_mux_nn0.inst | return none
-- let (.some scc1) := findSCCNodes g after_branch_nn1.inst prev_mux_nn1.inst | return none
return some [inst, after_branch_nn0.inst, mux_nn.inst, after_branch_nn1.inst, branch_nn.inst]
) none | throw .done
return list

Expand Down Expand Up @@ -127,7 +134,9 @@ def lhs' : ExprHigh String := [graph|
m_right -> mux [out = "m_out", inp = "inp1"];
]

#eval matchModRight lhs'
-- #eval IO.print lhs'
-- #eval IO.print lhs'.invert
#eval matcher lhs'
#eval IO.print lhs'

def lhs := lhs'.extract ["fork", "m_left", "mux", "m_right", "branch"] |>.get rfl
Expand Down Expand Up @@ -173,4 +182,37 @@ def rewrite : Rewrite String :=
input_expr := lhsLower,
output_expr := rhsLower }

namespace TEST

def lhs' : ExprHigh String := [graph|
i_branch [type = "io"];
i_cond [type = "io"];
o_out [type = "io"];

branch [type = "branch"];
m_left1 [type = "mod_left1"];
m_left2 [type = "mod_left2"];
m_right [type = "mod_right"];
mux [type = "mux"];
fork [type = "fork"];

i_branch -> branch [inp = "val"];
i_cond -> fork [inp = "inp0"];
fork -> branch [out = "out0", inp = "cond"];
fork -> mux [out = "out1", inp = "cond"];
m_left1 -> m_left2 [out = "out0", inp = "inp0"];

mux -> o_out [out = "out0"];

branch -> m_left1 [out = "true", inp = "m_in"];
branch -> m_right [out = "false", inp = "m_in"];

m_left2 -> mux [out = "m_out", inp = "inp0"];
m_right -> mux [out = "m_out", inp = "inp1"];
]

#eval rewrite.run "rw0_" lhs'

end TEST

end DataflowRewriter.BranchMuxToMerge

0 comments on commit 8564903

Please sign in to comment.