Skip to content

Commit

Permalink
Add components
Browse files Browse the repository at this point in the history
  • Loading branch information
ymherklotz committed Nov 5, 2024
1 parent ea0bb83 commit 2dce1d6
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 6 deletions.
156 changes: 151 additions & 5 deletions DataflowRewriter/Component.lean
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,31 @@ namespace DataflowRewriter.NatModule
internals := []
}

@[drunfold] def cntrl_merge T : NatModule (List T × List Bool) :=
{ inputs := [ (0, ⟨ T, λ (oldListL, oldListR) newElement (newListL, newListR) =>
newListL = newElement :: oldListL ∧ newListR = true :: oldListR ⟩)
, (1, ⟨ T, λ (oldListL, oldListR) newElement (newListL, newListR) =>
newListL = newElement :: oldListL ∧ newListR = false :: oldListR ⟩)
].toAssocList,
outputs := [ (0, ⟨ T, λ (oldListL, oldListR) oldElement (newListL, newListR) =>
newListL.concat oldElement = oldListL ∧ newListR = oldListR ⟩)
, (1, ⟨ Bool, λ (oldListL, oldListR) oldElement (newListL, newListR) =>
newListR.concat oldElement = oldListR ∧ newListL = oldListL ⟩)
].toAssocList,
internals := []
}

@[drunfold] def cntrl_merge_n T (n : Nat) : NatModule (List T × List Nat) :=
{ inputs := List.range n |>.map (Prod.mk ↑· (⟨ T, λ (oldListL, oldListR) newElement (newListL, newListR) =>
newListL = newElement :: oldListL ∧ newListR = n :: oldListR ⟩)) |>.toAssocList,
outputs := [ (0, ⟨ T, λ (oldListL, oldListR) oldElement (newListL, newListR) =>
newListL.concat oldElement = oldListL ∧ newListR = oldListR ⟩)
, (1, ⟨ Nat, λ (oldListL, oldListR) oldElement (newListL, newListR) =>
newListR.concat oldElement = oldListR ∧ newListL = oldListL ⟩)
].toAssocList,
internals := []
}

@[drunfold] def fork T (n : Nat) : NatModule (List T) :=
{ inputs := [(0, ⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)].toAssocList,
outputs := List.range n |>.map (Prod.mk ↑· ⟨ T, λ oldList oldElement newList => ∃ i, newList = oldList.remove i ∧ oldElement = oldList.get i ⟩) |>.toAssocList,
Expand All @@ -72,6 +97,16 @@ namespace DataflowRewriter.NatModule
internals := []
}

@[drunfold] def split T T' : NatModule (List T × List T') :=
{ inputs := [ (0, ⟨ T × T', λ (oldListL, oldListR) (newElementL, newElementR) (newListL, newListR) =>
newListL = newElementL :: oldListL ∧ newListR = newElementR :: oldListR⟩)].toAssocList,
outputs := [(0, ⟨ T, λ (oldListL, oldListR) oldElementL (newListL, newListR) =>
oldListL = newListL.concat oldElementL ∧ oldListR = oldListR ⟩),
(1, ⟨ T', λ (oldListL, oldListR) oldElementR (newListL, newListR) =>
oldListR = newListR.concat oldElementR ∧ oldListL = oldListL ⟩)].toAssocList,
internals := []
}

@[drunfold] def branch T : NatModule (List T × List Bool) :=
{ inputs := [ (0, ⟨ T, λ (oldValList, oldBoolList) val (newValList, newBoolList) =>
newValList = val :: oldValList ∧ newBoolList = oldBoolList ⟩)
Expand Down Expand Up @@ -178,28 +213,139 @@ Essentially tagger + join without internal rule
internals := []
}

@[drunfold] def unary_op {α R} (f : α → R): NatModule (List α) :=
{ inputs := [
(0, ⟨ α, λ oldList newElement newList => newList = newElement :: oldList ⟩)
].toAssocList,
outputs := [
(0, ⟨ R, λ oldList oldElement newList => ∃ a, newList.concat a = oldList ∧ oldElement = f a ⟩)
].toAssocList,
internals := []
}

@[drunfold] def binary_op {α β R} (f : α → β → R): NatModule (List α × List β) :=
{ inputs := [
(0, ⟨ α, λ (oldListL, oldListR) newElement (newListL, newListR) => newListL = newElement :: oldListL ⟩),
(1, ⟨ β, λ (oldListL, oldListR) newElement (newListL, newListR) => newListR = newElement :: oldListR ⟩)
].toAssocList,
outputs := [
(0, ⟨ R, λ (oldListL, oldListR) oldElement (newListL, newListR) =>
∃ a b, newListL.concat a = oldListL
∧ newListR.concat b = oldListR
∧ oldElement = f a b ⟩)
].toAssocList,
internals := []
}

@[drunfold] def constant {T} (t : T) : NatModule (List Unit) :=
{ inputs := [
(0, ⟨ Unit, λ oldList newElement newList => newList = newElement :: oldList ⟩)
].toAssocList,
outputs := [
(0, ⟨ T, λ oldList oldElement newList => ∃ a, newList.concat a = oldList ∧ oldElement = t ⟩)
].toAssocList,
internals := []
}

end DataflowRewriter.NatModule

namespace DataflowRewriter.StringModule

@[drunfold] def bag T : StringModule (List T) :=
NatModule.bag T |>.mapIdent (λ x => "enq") (λ x => "deq")
@[drunfold] def bag T : StringModule (List T) := NatModule.bag T
|>.stringify
-- |>.mapIdent (λ x => "enq") (λ x => "deq")

@[drunfold] def merge T n := NatModule.merge T n |>.stringify

@[drunfold] def fork T n := NatModule.fork T n |>.stringify

@[drunfold] def queue T : StringModule (List T) :=
NatModule.queue T |>.mapIdent (λ x => "enq") (λ x => "deq")
@[drunfold] def cntrl_merge T := NatModule.cntrl_merge T |>.stringify

@[drunfold] def queue T : StringModule (List T) := NatModule.queue T
|>.stringify
-- |>.mapIdent (λ x => "enq") (λ x => "deq")

@[drunfold] def join T T' := NatModule.join T T' |>.stringify

@[drunfold] def split T T' := NatModule.split T T' |>.stringify

@[drunfold] def branch T := NatModule.branch T
|>.mapIdent (λ | 0 => "val" | _ => "cond") (λ | 0 => "true" | _ => "false")
|>.stringify
-- |>.mapIdent (λ | 0 => "val" | _ => "cond") (λ | 0 => "true" | _ => "false")

@[drunfold] def mux T := NatModule.mux T
|>.stringify

@[drunfold] def tagger TagT [DecidableEq TagT] T := NatModule.tagger TagT T
|>.stringify

@[drunfold] def unary_op {α R} (f : α → R) := NatModule.unary_op f
|>.stringify

@[drunfold] def binary_op {α β R} (f : α → β → R) := NatModule.binary_op f
|>.stringify

@[drunfold] def constant {T} (t : T) := NatModule.constant t |>.stringify

opaque polymorphic_add {T} [Inhabited T] : T → T → T
opaque polymorphic_sub {T} [Inhabited T] : T → T → T
opaque polymorphic_mult {T} [Inhabited T] : T → T → T
opaque polymorphic_div {T} [Inhabited T] : T → T → T
opaque polymorphic_shift_left {T} [Inhabited T] : T → T → T

opaque constant_a {T} [Inhabited T] : T
opaque constant_b {T} [Inhabited T] : T
opaque constant_c {T} [Inhabited T] : T
opaque constant_d {T} [Inhabited T] : T
opaque constant_e {T} [Inhabited T] : T
opaque constant_f {T} [Inhabited T] : T
opaque constant_g {T} [Inhabited T] : T

@[drunfold] def tagger_untagger_val TagT [DecidableEq TagT] T :=
NatModule.tagger_untagger_val TagT T |>.stringify

def ε (Tag : Type) [DecidableEq Tag] (T : Type) [Inhabited T] : IdentMap String (TModule String) :=
[ ("Join", ⟨_, StringModule.join T T⟩)
, ("TaggedJoin", ⟨_, StringModule.join Tag T⟩)

, ("Split", ⟨_, StringModule.split T T⟩)
, ("TaggedSplit", ⟨_, StringModule.split Tag T⟩)

, ("Merge", ⟨_, StringModule.merge T 2⟩)
, ("TagggedMerge", ⟨_, StringModule.merge (Tag × T) 2⟩)

, ("Fork", ⟨_, StringModule.fork T 2⟩)
, ("TagggedFork", ⟨_, StringModule.fork (Tag × T) 2⟩)

, ("CntrlMerge", ⟨_, StringModule.cntrl_merge T⟩)
, ("TagggedCntrlMerge", ⟨_, StringModule.cntrl_merge (Tag × T)⟩)

, ("Branch", ⟨_, StringModule.branch T⟩)
, ("TagggedBranch", ⟨_, StringModule.branch (Tag × T)⟩)

, ("Mux", ⟨_, StringModule.mux T⟩)
, ("TagggedMux", ⟨_, StringModule.mux (Tag × T)⟩)

, ("Buffer", ⟨_, StringModule.queue T⟩)
, ("TagggedBuffer", ⟨_, StringModule.queue (Tag × T)⟩)

, ("Bag", ⟨_, StringModule.bag (Tag × T)⟩)

, ("TaggerCntrlAligner", ⟨_, StringModule.tagger_untagger_val Tag T⟩)

, ("ConstantA", ⟨_, StringModule.constant (@constant_a T)⟩)
, ("ConstantB", ⟨_, StringModule.constant (@constant_b T)⟩)
, ("ConstantC", ⟨_, StringModule.constant (@constant_c T)⟩)
, ("ConstantD", ⟨_, StringModule.constant (@constant_d T)⟩)
, ("ConstantE", ⟨_, StringModule.constant (@constant_e T)⟩)
, ("ConstantF", ⟨_, StringModule.constant (@constant_f T)⟩)
, ("ConstantG", ⟨_, StringModule.constant (@constant_g T)⟩)

, ("Add", ⟨_, StringModule.binary_op (@polymorphic_add T _)⟩)
, ("Mul", ⟨_, StringModule.binary_op (@polymorphic_mult T _)⟩)
, ("Div", ⟨_, StringModule.binary_op (@polymorphic_div T _)⟩)
, ("Shl", ⟨_, StringModule.binary_op (@polymorphic_shift_left T _)⟩)
, ("Sub", ⟨_, StringModule.binary_op (@polymorphic_sub T _)⟩)
].toAssocList

end DataflowRewriter.StringModule
2 changes: 1 addition & 1 deletion DataflowRewriter/Rewrites/MuxTaggedRewriteCorrect.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/-
Copyright (c) 2024 VCA Lab, EPFL. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Yann Herklotz
Authors: Martina Camaioni
-/

import Lean
Expand Down

0 comments on commit 2dce1d6

Please sign in to comment.