Skip to content

Commit

Permalink
Fix the BranchMuxToMerge example by adding a bag
Browse files Browse the repository at this point in the history
  • Loading branch information
ymherklotz committed Nov 9, 2024
1 parent 42a147f commit bde4cfd
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 19 deletions.
13 changes: 12 additions & 1 deletion DataflowRewriter/Component.lean
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ Essentially tagger + join without internal rule
internals := []
}

@[drunfold] def sink (T : Type _) : NatModule Unit :=
{ inputs := [(0, ⟨ T, λ _ _ _ => True ⟩)].toAssocList,
outputs := ∅,
internals := []
}

@[drunfold] def unary_op {α R} (f : α → R): NatModule (List α) :=
{ inputs := [
(0, ⟨ α, λ oldList newElement newList => newList = newElement :: oldList ⟩)
Expand Down Expand Up @@ -285,6 +291,8 @@ namespace DataflowRewriter.StringModule

@[drunfold] def split T T' := NatModule.split T T' |>.stringify

@[drunfold] def sink T := NatModule.sink T |>.stringify

@[drunfold] def branch T := NatModule.branch T
|>.stringify
-- |>.mapIdent (λ | 0 => "val" | _ => "cond") (λ | 0 => "true" | _ => "false")
Expand Down Expand Up @@ -331,7 +339,10 @@ def ε (Tag : Type) [DecidableEq Tag] (T : Type) [Inhabited T] : IdentMap String
, ("TaggedSplit", ⟨_, StringModule.split Tag T⟩)

, ("Merge", ⟨_, StringModule.merge T 2⟩)
, ("TagggedMerge", ⟨_, StringModule.merge (Tag × T) 2⟩)
, ("TaggedMerge", ⟨_, StringModule.merge (Tag × T) 2⟩)

, ("Sink", ⟨_, StringModule.sink T⟩)
, ("TaggedSink", ⟨_, StringModule.sink Tag⟩)

, ("Fork", ⟨_, StringModule.fork T 2⟩)
, ("Fork3", ⟨_, StringModule.fork T 3⟩)
Expand Down
43 changes: 25 additions & 18 deletions DataflowRewriter/Rewrites/BranchMuxToMerge.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ Match the left module so that it can be abstracted.
def matchModLeft (g : ExprHigh String) : RewriteResult (List String) := do
let (.some list) ← g.modules.foldlM (λ s inst (pmap, typ) => do
if s.isSome then return s
unless typ = "fork" do return none
unless typ = "TaggedFork" do return none
let (.some branch_nn) := followOutput g inst "out0" | return none
unless branch_nn.typ = "branch" && branch_nn.inputPort = "cond" do return none
unless branch_nn.typ = "TaggedBranch" && branch_nn.inputPort = "cond" do return none
let (.some mux_nn) := followOutput g inst "out1" | return none
unless mux_nn.typ = "mux" && mux_nn.inputPort = "cond" do return none
unless mux_nn.typ = "TaggedMux" && mux_nn.inputPort = "cond" do return none
let (.some bag_nn) := followOutput g mux_nn.inst "out0" | return none
unless bag_nn.typ = "Bag" do return none
let (.some prev_mux_nn) := followInput g mux_nn.inst "inp0" | return none
let (.some after_branch_nn) := followOutput g branch_nn.inst "true" | return none
let (.some scc) := findSCCNodes g after_branch_nn.inst prev_mux_nn.inst | return none
Expand All @@ -73,11 +75,13 @@ Match the right module so that it can be abstracted.
def matchModRight (g : ExprHigh String) : RewriteResult (List String) := do
let (.some list) ← g.modules.foldlM (λ s inst (pmap, typ) => do
if s.isSome then return s
unless typ = "fork" do return none
unless typ = "TaggedFork" do return none
let (.some branch_nn) := followOutput g inst "out0" | return none
unless branch_nn.typ = "branch" && branch_nn.inputPort = "cond" do return none
unless branch_nn.typ = "TaggedBranch" && branch_nn.inputPort = "cond" do return none
let (.some mux_nn) := followOutput g inst "out1" | return none
unless mux_nn.typ = "mux" && mux_nn.inputPort = "cond" do return none
unless mux_nn.typ = "TaggedMux" && mux_nn.inputPort = "cond" do return none
let (.some bag_nn) := followOutput g mux_nn.inst "out0" | return none
unless bag_nn.typ = "Bag" do return none
let (.some prev_mux_nn) := followInput g mux_nn.inst "inp1" | return none
let (.some after_branch_nn) := followOutput g branch_nn.inst "false" | return none
let (.some scc) := findSCCNodes g after_branch_nn.inst prev_mux_nn.inst | return none
Expand All @@ -92,11 +96,13 @@ to match the graph.
def matcher (g : ExprHigh String) : RewriteResult (List String) := do
let (.some list) ← g.modules.foldlM (λ s inst (pmap, typ) => do
if s.isSome then return s
unless typ = "fork" do return none
unless typ = "TaggedFork" do return none
let (.some branch_nn) := followOutput g inst "out0" | return none
unless branch_nn.typ = "branch" && branch_nn.inputPort = "cond" do return none
unless branch_nn.typ = "TaggedBranch" && branch_nn.inputPort = "cond" do return none
let (.some mux_nn) := followOutput g inst "out1" | return none
unless mux_nn.typ = "mux" && mux_nn.inputPort = "cond" do return none
unless mux_nn.typ = "TaggedMux" && mux_nn.inputPort = "cond" do return none
let (.some bag_nn) := followOutput g mux_nn.inst "out0" | return none
unless bag_nn.typ = "Bag" do return none
let (.some after_branch_nn0) := followOutput g branch_nn.inst "true" | return none
let (.some after_branch_nn1) := followOutput g branch_nn.inst "false" | return none
-- Now that we go in two directions, we need two calls to findSCCNodes,
Expand All @@ -105,7 +111,7 @@ def matcher (g : ExprHigh String) : RewriteResult (List String) := do
-- However, now that we have already abstracted the modules, we don't need to search anymore.
-- let (.some scc0) := findSCCNodes g after_branch_nn0.inst prev_mux_nn0.inst | return none
-- let (.some scc1) := findSCCNodes g after_branch_nn1.inst prev_mux_nn1.inst | return none
return some [inst, after_branch_nn0.inst, mux_nn.inst, after_branch_nn1.inst, branch_nn.inst]
return some [inst, after_branch_nn0.inst, mux_nn.inst, after_branch_nn1.inst, branch_nn.inst, bag_nn.inst]
) none | throw .done
return list

Expand All @@ -114,19 +120,20 @@ def lhs' : ExprHigh String := [graph|
i_cond [type = "io"];
o_out [type = "io"];

branch [type = "branch"];
branch [type = "TaggedBranch"];
m_left [type = "mod_left"];
m_right [type = "mod_right"];
mux [type = "mux"];
bag [type = "bag"];
fork [type = "fork"];
mux [type = "TaggedMux"];
fork [type = "TaggedFork"];
bag [type = "Bag"];

i_branch -> branch [inp = "val"];
i_cond -> fork [inp = "inp0"];
fork -> branch [out = "out0", inp = "cond"];
fork -> mux [out = "out1", inp = "cond"];

mux -> o_out [out = "out0"];
mux -> bag [out = "out0", inp = "inp0"];
bag -> o_out [out = "out0"];

branch -> m_left [out = "true", inp = "m_in"];
branch -> m_right [out = "false", inp = "m_in"];
Expand All @@ -140,7 +147,7 @@ def lhs' : ExprHigh String := [graph|
#eval matcher lhs'
#eval IO.print lhs'

def lhs := lhs'.extract ["fork", "m_left", "mux", "m_right", "branch"] |>.get rfl
def lhs := lhs'.extract ["fork", "m_left", "mux", "m_right", "branch", "bag"] |>.get rfl

theorem double_check_empty_snd : lhs.snd = ExprHigh.mk ∅ ∅ := by rfl

Expand All @@ -151,10 +158,10 @@ def rhs : ExprHigh String := [graph|
i_cond [type = "io"];
o_out [type = "io"];

branch [type = "branch"];
branch [type = "TaggedBranch"];
m_left [type = "mod_left"];
m_right [type = "mod_right"];
merge [type = "merge"];
merge [type = "TaggedMerge"];

i_branch -> branch [inp = "val"];
i_cond -> branch [inp = "cond"];
Expand Down

0 comments on commit bde4cfd

Please sign in to comment.