Skip to content

Commit

Permalink
Add new components with names
Browse files Browse the repository at this point in the history
  • Loading branch information
ymherklotz committed Nov 4, 2024
1 parent 9a4702b commit b6d73c6
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 43 deletions.
122 changes: 82 additions & 40 deletions DataflowRewriter/Component.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,64 +38,89 @@ namespace DataflowRewriter.NatModule
inputs := mod.inputs,
internals := mod.internals }

@[drunfold] def merge T : NatModule (List T) :=
{ inputs := [(0, ⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩),
(1, ⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)].toAssocList,
outputs := [(0, ⟨ T, λ oldList oldElement newList =>
∃ i, newList = oldList.remove i
∧ oldElement = oldList.get i ⟩)].toAssocList,
internals := []
}

@[drunfold] def fork T : NatModule (List T) :=
{ inputs := [(0, ⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)].toAssocList,
outputs := [ (0, ⟨ T, λ oldList oldElement newList => ∃ i, newList = oldList.remove i ∧ oldElement = oldList.get i ⟩)
, (1, ⟨ T, λ oldList oldElement newList => ∃ i, newList = oldList.remove i ∧ oldElement = oldList.get i ⟩)
].toAssocList,
internals := []
}
@[drunfold] def merge T (n : Nat) : NatModule (List T) :=
{ inputs := List.range n |>.map (Prod.mk ↑· (⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)) |>.toAssocList,
outputs := [(0, ⟨ T, λ oldList oldElement newList =>
∃ i, newList = oldList.remove i
∧ oldElement = oldList.get i ⟩)].toAssocList,
internals := []
}

@[drunfold] def fork T (n : Nat) : NatModule (List T) :=
{ inputs := [(0, ⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)].toAssocList,
outputs := List.range n |>.map (Prod.mk ↑· ⟨ T, λ oldList oldElement newList => ∃ i, newList = oldList.remove i ∧ oldElement = oldList.get i ⟩) |>.toAssocList,
internals := []
}

@[drunfold] def queue T : NatModule (List T) :=
{ inputs := [( ⟨ .top, 0⟩ ,⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)].toAssocList,
outputs := [(⟨ .top, 0⟩, ⟨ T, λ oldList oldElement newList => newList ++ [oldElement] = oldList ⟩)].toAssocList,
internals := []
{ inputs := [(⟨ .top, 0 ⟩, ⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)].toAssocList,
outputs := [(⟨ .top, 0 ⟩, ⟨ T, λ oldList oldElement newList => newList ++ [oldElement] = oldList ⟩)].toAssocList,
internals := []
}

@[drunfold] def join T T' : NatModule (List T × List T') :=
{ inputs := [ (0, ⟨ T, λ (oldListL, oldListR) newElement (newListL, newListR) =>
newListL = newElement :: oldListL ∧ newListR = oldListR⟩)
, (1, ⟨ T', λ (oldListL, oldListR) newElement (newListL, newListR) =>
newListR = newElement :: oldListR ∧ newListL = oldListL⟩)].toAssocList,
outputs := [(0, ⟨ T × T', λ (oldListL, oldListR) (oldElementL, oldElementR) (newListL, newListR) =>
oldListL = oldElementL :: newListL ∧
oldListR = oldElementR :: newListR ⟩)].toAssocList,
internals := []
}

@[drunfold] def branch T : NatModule (List T × List Bool) :=
{ inputs := [ (0, ⟨ T, λ (oldValList, oldBoolList) val (newValList, newBoolList) =>
newValList = val :: oldValList ∧ newBoolList = oldBoolList ⟩)
, (1, ⟨ Bool, λ (oldValList, oldBoolList) b (newValList, newBoolList) =>
newValList = oldValList ∧ newBoolList = b :: oldBoolList ⟩)
].toAssocList
outputs := [ (0, ⟨ T, λ (oldValList, oldBoolList) val (newValList, newBoolList) =>
val :: newValList = oldValList ∧ true :: newBoolList = oldBoolList ⟩)
, (1, ⟨ T, λ (oldValList, oldBoolList) val (newValList, newBoolList) =>
val :: newValList = oldValList ∧ false :: newBoolList = oldBoolList ⟩)
].toAssocList
internals := []
}

@[drunfold] def mux T : NatModule (List T × List T × List Bool) :=
{ inputs := [ (0, ⟨ T, λ (oldTrueList, oldFalseList, oldBoolList) val (newTrueList, newFalseList, newBoolList) =>
newTrueList = val :: oldTrueList ∧ newFalseList = oldFalseList ∧ newBoolList = oldBoolList ⟩)
, (1, ⟨ T, λ (oldTrueList, oldFalseList, oldBoolList) val (newTrueList, newFalseList, newBoolList) =>
newTrueList = oldTrueList ∧ newFalseList = val :: oldFalseList ∧ newBoolList = oldBoolList ⟩)
, (2, ⟨ Bool, λ (oldTrueList, oldFalseList, oldBoolList) b (newTrueList, newFalseList, newBoolList) =>
newTrueList = oldTrueList ∧ newFalseList = oldFalseList ∧ newBoolList = b :: oldBoolList ⟩)
].toAssocList
outputs := [ (0, ⟨ T, λ (oldTrueList, oldFalseList, oldBoolList) val (newTrueList, newFalseList, newBoolList) =>
∃ b, b :: newBoolList = oldBoolList
if b then val :: newTrueList = oldTrueList ∧ newFalseList = oldFalseList
else newTrueList = oldTrueList ∧ val :: newFalseList = oldFalseList ⟩)
].toAssocList
internals := []
}
@[drunfold] def queueS T : StringModule (List T) :=
queue T |>.mapIdent (λ x => "enq") (λ x => "deq")

@[drunfold] def join T T' : NatModule (List T× List T') :=
{ inputs := [ (0, ⟨ T, λ (oldListL,oldListR) newElement (newListL,newListR) =>
newListL = newElement :: oldListL ∧ newListR = oldListR⟩)
, (1, ⟨ T', λ (oldListL,oldListR) newElement (newListL,newListR) =>
newListR = newElement :: oldListR ∧ newListL = oldListL⟩)].toAssocList,
outputs := [(0, ⟨ T × T', λ (oldListL,oldListR) (oldElementL, oldElementR) (newListL,newListR) =>
oldListL = oldElementL :: newListL ∧
oldListR = oldElementR :: newListR ⟩)].toAssocList,
internals := []
}

@[drunfold] def bag T : NatModule (List T) :=
{ inputs := [(0,⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)].toAssocList,
outputs := [(0,⟨ T, λ oldList oldElement newList => ∃ i, newList = oldList.remove i ∧ oldElement = oldList.get i ⟩)].toAssocList,
internals := []}
{ inputs := [(0,⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)].toAssocList,
outputs := [(0,⟨ T, λ oldList oldElement newList => ∃ i, newList = oldList.remove i ∧ oldElement = oldList.get i ⟩)].toAssocList,
internals := []}

@[drunfold] def tag_complete_spec (TagT : Type 0) [_i: DecidableEq TagT] (T : Type 0) : NatModule (List TagT × (TagT → Option T)) :=
@[drunfold] def tagger (TagT : Type 0) [_i: DecidableEq TagT] (T : Type 0) : NatModule (List TagT × (TagT → Option T)) :=
{ inputs := [
-- Complete computation
(0,⟨ TagT × T, λ (oldOrder, oldMap) (tag,el) (newOrder, newMap) =>
(0, ⟨ TagT × T, λ (oldOrder, oldMap) (tag,el) (newOrder, newMap) =>
-- Tag must be used, but no value ready, otherwise block:
(List.elem tag oldOrder ∧ oldMap tag = none) ∧
newMap = (λ idx => if idx == tag then some el else oldMap idx) ∧ newOrder = oldOrder⟩)
].toAssocList,
outputs := [
-- Allocate fresh tag
(0,⟨ TagT, λ (oldOrder, oldMap) (tag) (newOrder, newMap) =>
(0, ⟨ TagT, λ (oldOrder, oldMap) (tag) (newOrder, newMap) =>
-- Tag must be unused otherwise block (alternatively we
-- could an implication to say undefined behavior):
(!List.elem tag oldOrder ∧ oldMap tag = none) ∧
newMap = oldMap ∧ newOrder = tag :: oldOrder⟩),
-- Dequeue + free tag
(1,⟨ T, λ (oldorder, oldmap) el (neworder, newmap) =>
(1, ⟨ T, λ (oldorder, oldmap) el (neworder, newmap) =>
-- tag must be used otherwise, but no value brought, undefined behavior:
∃ l tag , oldorder = l ++ [tag] ∧ oldmap tag = some el ∧
newmap = (λ idx => if idx == tag then none else oldmap idx) ∧ neworder = l ⟩),
Expand All @@ -107,7 +132,24 @@ end DataflowRewriter.NatModule

namespace DataflowRewriter.StringModule

@[drunfold] def bagS T : StringModule (List T) :=
@[drunfold] def bag T : StringModule (List T) :=
NatModule.bag T |>.mapIdent (λ x => "enq") (λ x => "deq")

@[drunfold] def merge T n := NatModule.merge T n |>.stringify

@[drunfold] def fork T n := NatModule.fork T n |>.stringify

@[drunfold] def queue T : StringModule (List T) :=
NatModule.queue T |>.mapIdent (λ x => "enq") (λ x => "deq")

@[drunfold] def join T T' := NatModule.join T T' |>.stringify

@[drunfold] def branch T := NatModule.branch T
|>.mapIdent (λ | 0 => "val" | _ => "cond") (λ | 0 => "true" | _ => "false")

@[drunfold] def mux T := NatModule.mux T
|>.mapIdent (λ | 0 => "true" | 1 => "false" | _ => "cond") (λ _ => "out")

@[drunfold] def tagger TagT [DecidableEq TagT] T := NatModule.tagger TagT T

end DataflowRewriter.StringModule
12 changes: 9 additions & 3 deletions DataflowRewriter/Rewrites/MergeRewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,16 @@ def mergeHigh : ExprHigh String :=

/--
info: ok: digraph {
rw0_2 [mod = "merge3", label = "rw0_2: merge3"]
rw0_1 [mod = "fork", label = "rw0_1: fork"]
rw0_0 [mod = "fork", label = "rw0_0: fork"]
snk0 [mod = "io", label = "snk0: io"];
src0 [mod = "io", label = "src0: io"];
rw0_2 [mod = "merge3", label = "rw0_2: merge3"];
rw0_1 [mod = "fork", label = "rw0_1: fork"];
rw0_0 [mod = "fork", label = "rw0_0: fork"];
rw0_2 -> snk0 [out = "out0", taillabel = "out0"];
src0 -> rw0_0 [inp = "inp0", headlabel = "inp0"];
rw0_0 -> rw0_1 [out = "out0", inp = "inp0", taillabel = "out0", headlabel = "inp0",];
rw0_0 -> rw0_2 [out = "out1", inp = "inp0", taillabel = "out1", headlabel = "inp0",];
Expand Down

0 comments on commit b6d73c6

Please sign in to comment.