Skip to content

Commit

Permalink
Adding better elaboration of graphs, as well as the loop rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
ymherklotz committed Dec 8, 2024
1 parent 8fd1234 commit 4b45181
Show file tree
Hide file tree
Showing 14 changed files with 624 additions and 59 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1 @@
/.lake
.lake
2 changes: 1 addition & 1 deletion DataflowRewriter/AssocList/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ theorem contains_eraseAll {α β} [DecidableEq α] {a : AssocList α β} {i i'}
simp only [←contains_find?_iff]; intro ⟨_, _⟩; solve_by_elim [find?_eraseAll]

@[simp] theorem any_map {α β} {f : α → β} {l : List α} {p : β → Bool} : (l.map f).any p = l.any (p ∘ f) := by
induction l with simp | cons _ _ ih => rw [ih]
induction l <;> simp

theorem keysInMap {α β} [DecidableEq α] {m : AssocList α β} {k} : m.contains k → k ∈ m.keysList := by
unfold Batteries.AssocList.contains Batteries.AssocList.keysList
Expand Down
29 changes: 21 additions & 8 deletions DataflowRewriter/Component.lean
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,19 @@ namespace DataflowRewriter.NatModule
}

@[drunfold] def queue T : NatModule (List T) :=
{ inputs := [(⟨ .top, 0 ⟩, ⟨ T, λ oldList newElement newList => newList = newElement :: oldList ⟩)].toAssocList,
outputs := [(⟨ .top, 0 ⟩, ⟨ T, λ oldList oldElement newList => newList.concat oldElement = oldList ⟩)].toAssocList,
{ inputs := [(⟨ .top, 0 ⟩, ⟨ T, λ oldList newElement newList => newList = oldList.concat newElement ⟩)].toAssocList,
outputs := [(⟨ .top, 0 ⟩, ⟨ T, λ oldList oldElement newList => oldElement :: newList = oldList ⟩)].toAssocList,
internals := []
}

@[drunfold] def init T (default : T) : NatModule (List T × Bool) :=
{ inputs := [(⟨ .top, 0 ⟩, ⟨ T, λ (oldList, oldState) newElement (newList, newState) =>
newList = oldList.concat newElement ∧ oldState = newState ⟩)].toAssocList,
outputs := [(⟨ .top, 0 ⟩, ⟨ T, λ (oldList, oldState) oldElement (newList, newState) =>
if oldState then
oldElement :: newList = oldList ∧ oldState = newState
else
newList = oldList ∧ newState = true ∧ oldElement = default ⟩)].toAssocList,
internals := []
}

Expand Down Expand Up @@ -202,10 +213,10 @@ namespace DataflowRewriter.NatModule
/--
Essentially tagger + join without internal rule
-/
@[drunfold] def tagger_untagger_val (TagT : Type 0) [_i: DecidableEq TagT] (T : Type 0) : NatModule (List TagT × AssocList TagT T × List T) :=
@[drunfold] def tagger_untagger_val (TagT : Type 0) [_i: DecidableEq TagT] (T T' : Type 0) : NatModule (List TagT × AssocList TagT T' × List T) :=
{ inputs := [
-- Complete computation
(0, ⟨ TagT × T, λ (oldOrder, oldMap, oldVal) (tag,el) (newOrder, newMap, newVal) =>
(0, ⟨ TagT × T', λ (oldOrder, oldMap, oldVal) (tag,el) (newOrder, newMap, newVal) =>
-- Tag must be used, but no value ready, otherwise block:
(tag ∈ oldOrder ∧ oldMap.find? tag = none) ∧
newMap = oldMap.cons tag el ∧ newOrder = oldOrder ∧ newVal = oldVal ⟩),
Expand All @@ -221,7 +232,7 @@ Essentially tagger + join without internal rule
(tag ∉ oldOrder ∧ oldMap.find? tag = none) ∧
newMap = oldMap ∧ newOrder = tag :: oldOrder ∧ newVal.cons v = oldVal⟩),
-- Dequeue + free tag
(1, ⟨ T, λ (oldorder, oldmap, oldVal) el (neworder, newmap, newVal) =>
(1, ⟨ T', λ (oldorder, oldmap, oldVal) el (neworder, newmap, newVal) =>
-- tag must be used otherwise, but no value brought, undefined behavior:
∃ tag , oldorder = neworder.concat tag ∧ oldmap.find? tag = some el ∧
newmap = oldmap.eraseAll tag ∧ newVal = oldVal ⟩),
Expand Down Expand Up @@ -376,6 +387,8 @@ namespace DataflowRewriter.StringModule
@[drunfold] def binary_op {α β R} (f : α → β → R) := NatModule.binary_op f
|>.stringify

@[drunfold] def init T default := NatModule.init T default |>.stringify

@[drunfold] def constant {T} (t : T) := NatModule.constant t |>.stringify

opaque polymorphic_add {T} [Inhabited T] : T → T → T
Expand All @@ -392,8 +405,8 @@ opaque constant_e {T} [Inhabited T] : T
opaque constant_f {T} [Inhabited T] : T
opaque constant_g {T} [Inhabited T] : T

@[drunfold] def tagger_untagger_val TagT [DecidableEq TagT] T :=
NatModule.tagger_untagger_val TagT T |>.stringify
@[drunfold] def tagger_untagger_val TagT [DecidableEq TagT] T T' :=
NatModule.tagger_untagger_val TagT T T' |>.stringify

def ε (Tag : Type) [DecidableEq Tag] (T : Type) [Inhabited T] : IdentMap String (TModule String) :=
[ ("Join", ⟨_, StringModule.join T T⟩)
Expand Down Expand Up @@ -434,7 +447,7 @@ def ε (Tag : Type) [DecidableEq Tag] (T : Type) [Inhabited T] : IdentMap String
, ("Bag", ⟨_, StringModule.bag (Tag × T)⟩)

, ("Aligner", ⟨_, StringModule.aligner Tag T⟩)
, ("TaggerCntrlAligner", ⟨_, StringModule.tagger_untagger_val Tag T⟩)
, ("TaggerCntrlAligner", ⟨_, StringModule.tagger_untagger_val Tag T T⟩)

, ("ConstantA", ⟨_, StringModule.constant (@constant_a T)⟩)
, ("ConstantB", ⟨_, StringModule.constant (@constant_b T)⟩)
Expand Down
211 changes: 191 additions & 20 deletions DataflowRewriter/ExprHighElaborator.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Lean
import Qq

import DataflowRewriter.ExprHigh
import DataflowRewriter.Component

namespace DataflowRewriter

Expand All @@ -19,6 +20,7 @@ declare_syntax_cat dot_attr
syntax str : dot_value
syntax num : dot_value
syntax ident : dot_value
syntax (name := dot_value_term) "$(" term ")" : dot_value

syntax ident " = " dot_value : dot_attr
syntax (dot_attr),* : dot_attr_list
Expand All @@ -31,6 +33,7 @@ syntax dot_stmnt_list := (dot_stmnt "; ")*
syntax dot_input_list := ("(" ident ", " num ")"),*

syntax (name := dot_graph) "[graph| " dot_stmnt_list " ]" : term
syntax (name := dot_graphEnv) "[graphEnv| " dot_stmnt_list " ]" : term

open Lean.Meta Lean.Elab Term Lean.Syntax

Expand All @@ -42,7 +45,27 @@ def findStx (n : Name) (stx : Array Syntax) : Option Nat := do
out := some (TSyntax.mk pair[2][0]).getNat
out

#check true
open Lean Qq in
@[term_elab dot_value_term]
def dotValueTermElab : TermElab
| `(dot_value| $( $a:term )), t => elabTerm a t
| _, _ => throwError "Could not match syntax"

open Lean in
def hasStxElement (n : Name) (stx : Array Syntax) : Bool := Id.run do
let mut out := false
for pair in stx do
if pair[0].getId = n then
out := true
out

open Lean in
def getListElement (n : Name) (stx : Array Syntax) : MetaM Syntax := do
let mut out : Option Syntax := .none
for pair in stx do
if pair[0].getId = n then
out := .some pair
return out.getD (← getRef)

open Lean in
def findStxBool (n : Name) (stx : Array Syntax) : Option Bool := do
Expand All @@ -65,6 +88,18 @@ def findStxStr (n : Name) (stx : Array Syntax) : MetaM (Option String) := do
out := some out'
return out

open Lean Qq in
def findStxTerm (n : Name) (stx : Array Syntax) : TermElabM (Option (Expr × String)) := do
let mut out := none
for pair in stx do
if pair[0].getId = n ∧ pair[2][0].isStrLit?.isNone then
-- let content : TSyntax `term := ⟨pair[2][1]⟩
-- let depPair ← `(term| Sigma.mk _ $content)
let term ← elabTermEnsuringType pair[2] <| .some q(TModule1 String)
let str ← ppTerm {env := ← getEnv } ⟨pair[2][1]⟩
out := some (term, Format.pretty str)
return out

def toInstIdent {α} [Inhabited α] (n : String) (h : Std.HashMap String α) : InstIdent α :=
match n with
| "io" => .top
Expand Down Expand Up @@ -113,7 +148,7 @@ def dotGraphElab : TermElab := λ stx _typ? => do
match low_stmnt with
| `(dot_stmnt| $i:ident $[[$[$el:dot_attr],*]]? ) =>
let some el := el
| throwErrorAt i "No `type` attribute found at node"
| throwErrorAt i "Element list is not present"
let some modId ← findStxStr `type el
| throwErrorAt i "No `type` attribute found at node"
let mut modCluster : Bool := findStxBool `cluster el |>.getD false
Expand Down Expand Up @@ -151,31 +186,167 @@ def dotGraphElab : TermElab := λ stx _typ? => do
let modListMap : Q(IdentMap String (PortMapping String × String)) := q(List.toAssocList $modList)
return q(ExprHigh.mk $modListMap $connExpr)

-- open Lean.PrettyPrinter Delaborator SubExpr
-- open Qq in
-- instance {α : Q(Type)} : Lean.ToExpr Q($α) where
-- toExpr := id
-- toTypeExpr := α
#check Lean.getRef

open Lean Qq in
def checkInputPresent (envMap : Std.HashMap String Expr) (inInst : Q(TModule1 String)) (inP : Option String)
: TermElabM Unit := do
match inP with
| .some inP =>
let inputS : Q(Type) := q(($inInst).fst)
let inputMap : Q(PortMap String (Σ T : Type, ($inputS → T → $inputS → Prop))) := q(($inInst).snd.inputs)
let expr : Q(Bool) := q(Batteries.AssocList.contains (InternalPort.mk .top $inP) $inputMap)
unless ← isDefEq expr q(true) do
throwError "Input not present in Module"
| .none => return ()

open Lean Qq in
def checkOutputPresent (envMap : Std.HashMap String Expr) (outInst : Q(TModule1 String)) (outP : Option String)
: TermElabM Unit := do
match outP with
| .some outP =>
let outputS : Q(Type) := q(($outInst).fst)
let outputMap : Q(PortMap String (Σ T : Type, ($outputS → T → $outputS → Prop))) := q(($outInst).snd.outputs)
let expr : Q(Bool) := q(Batteries.AssocList.contains (InternalPort.mk .top $outP) $outputMap)
unless ← isDefEq expr q(true) do
throwError "Output not present in Module"
| .none => return ()

open Lean Qq in
def checkTypeErrors (envMap : Std.HashMap String Expr) (maps : InstMaps) (conns : List (Connection String))
(outInst inInst : String) (outP inP : Option String)
: TermElabM Unit := do
let `(dot_stmnt| $a:ident -> $b:ident $[[$[$el:dot_attr],*]]? ) ← getRef
| throwError "Failed to parse reference"
-- logInfo <| repr <| envMap.toList.map (·.fst)
let mut outInstExpr : Q(TModule1 String) := default
let mut inInstExpr : Q(TModule1 String) := default
if outP.isSome then
let .some outTypeInfo := maps.instTypeMap[outInst]?
| throwErrorAt a "Could not find output in type map"
let .some outInstExpr' := envMap[outTypeInfo.2]?
| throwErrorAt b "Could not find output in envMap"
outInstExpr := outInstExpr'
if inP.isSome then
let .some inTypeInfo := maps.instTypeMap[inInst]?
| throwErrorAt b "Could not find output in type map"
let .some inInstExpr' := envMap[inTypeInfo.2]?
| throwErrorAt a "Could not find input in envMap"
inInstExpr := inInstExpr'
-- The best would be to highlight the actual items in the list instead of the
-- two sides of the arrow.
let some el := el | return ()
withRef (← getListElement `inp el) <| checkInputPresent envMap inInstExpr inP
withRef (← getListElement `out el) <| checkOutputPresent envMap outInstExpr outP

-- Only check types if we are not assigning an IO port
let (.some outP, .some inP) := (outP, inP) | return ()

let outputMap : Q(PortMap String (Σ T : Type, (($outInstExpr).fst → T → ($outInstExpr).fst → Prop))) := q(($outInstExpr).snd.outputs)
let inputMap : Q(PortMap String (Σ T : Type, (($inInstExpr).fst → T → ($inInstExpr).fst → Prop))) := q(($inInstExpr).snd.inputs)

let inputType : Q(Type) := q(($inputMap).getIO (InternalPort.mk .top $inP) |>.fst)
let outputType : Q(Type) := q(($outputMap).getIO (InternalPort.mk .top $outP) |>.fst)

unless ← isDefEq inputType outputType do
throwError "Types of input and output port do not match (output ≠ input):\n {← whnf outputType} ≠ {← whnf inputType}"
return ()

open Lean Qq in
@[term_elab dot_graphEnv]
def dotGraphElab' : TermElab := λ stx _typ? => do
let mut instMap : Std.HashMap String (InstIdent String × Bool) := ∅
let mut instTypeMap : Std.HashMap String (PortMapping String × String) := ∅
let mut conns : List (Connection String) := []
let mut envMap : Std.HashMap String Expr := ∅
for stmnt in stx[1][0].getArgs do
let low_stmnt := stmnt.getArgs[0]!
match low_stmnt with
| `(dot_stmnt| $i:ident $[[$[$el:dot_attr],*]]? ) =>
let some el := el
| throwErrorAt i "Element list is not present"
let mut modId := ""
match ← findStxTerm `type el with
| .some modIdImp =>
modId := modIdImp.snd
envMap := envMap.insert modId modIdImp.fst
| .none =>
let some modId' ← findStxStr `type el
| throwErrorAt i "No `type` attribute found at node"
modId := modId'
let mut modCluster : Bool := findStxBool `cluster el |>.getD false
match updateNodeMaps ⟨instMap, instTypeMap⟩ i.getId.toString modId modCluster with
| .ok ⟨a, b⟩ =>
instMap := a
instTypeMap := b
| .error s =>
throwErrorAt i s
| ref@`(dot_stmnt| $a:ident -> $b:ident $[[$[$el:dot_attr],*]]? ) =>
-- Error checking to report it early if the instance is not present in the
-- hashmap.
let some el := el
| throwErrorAt (mkListNode #[a, b]) "No `type` attribute found at node"
let mut out ← (findStxStr `out el)
let mut inp ← (findStxStr `inp el)
withRef ref <|
checkTypeErrors envMap ⟨instMap, instTypeMap⟩ conns a.getId.toString b.getId.toString out inp
match updateConnMaps ⟨instMap, instTypeMap⟩ conns a.getId.toString b.getId.toString out inp with
| .ok (⟨_, b⟩, c) =>
conns := c
instTypeMap := b
| .error (.outInstError s) => throwErrorAt a s
| .error (.inInstError s) => throwErrorAt b s
| .error (.portError s) => throwErrorAt (mkListNode el) s
| _ => pure ()
let connExpr : Q(List (Connection String)) ←
mkListLit q(Connection String) (← conns.mapM (λ ⟨ a, b ⟩ => do
mkAppM ``Connection.mk #[reifyInternalPort a, reifyInternalPort b]))
let modList : Q(List (String × (PortMapping String × String))) ←
mkListLit q(String × (PortMapping String × String))
(instTypeMap.toList.map (fun (a, (p, b)) =>
let a' : Q(String) := .lit (.strVal a)
let b' : Q(String) := .lit (.strVal b)
let p' : Q(PortMapping String) := mkPortMapping <| p.map (.strVal · |> .lit)
q(($a', ($p', $b')))))
let envList : Q(List (String × TModule1 String)) ←
mkListLit q(String × TModule1 String)
(envMap.toList.map (fun (a, m) =>
let m' : Q(TModule1 String) := m
let a' : Q(String) := toExpr a
q(($a', $m'))
))
let modListMap : Q(IdentMap String (PortMapping String × String)) := q(List.toAssocList $modList)
return q(Prod.mk (ExprHigh.mk $modListMap $connExpr) (List.toAssocList $envList))

-- namespace mergemod
-- open Lean.PrettyPrinter Delaborator SubExpr

-- def mergeHigh : ExprHigh String :=
-- [graph|
-- src0 [mod="src"];
-- snk0 [mod="snk"];
namespace mergemod

-- fork1 [mod="fork"];
-- fork2 [mod="fork"];
-- merge1 [mod="merge"];
-- merge2 [mod="merge"];
open StringModule in
def mergeHigh (T : Type _) : ExprHigh String × (IdentMap String (TModule1 String)) :=
([graph|
src0 [type="io"];
snk0 [type="io"];

-- src0 -> fork1 [out="0",inp="0"];
fork1 [type="sten"];
], [("⟨_, fork T 2⟩", ⟨List T, fork T 2⟩)].toAssocList)

-- fork1 -> fork2 [out="0",inp="0"];
open StringModule in
def mergeHigh2 (T : Type _) : ExprHigh String × (IdentMap String (TModule1 String)) :=
[graphEnv|
src0 [type="io"];
snk0 [type="io"];

-- fork1 -> merge1 [out="1",inp="0"];
-- fork2 -> merge1 [out="0",inp="1"];
-- fork2 -> merge2 [out="1",inp="0"];
fork1 [type=$(⟨_, fork T 2⟩)];
fork2 [type=$(⟨_, fork T 3⟩)];
]

-- merge1 -> merge2 [out="0",inp="1"];
#print mergeHigh2

-- merge2 -> snk0 [out="0",inp="0"];
-- ]
end mergemod

end DataflowRewriter
1 change: 1 addition & 0 deletions DataflowRewriter/ExprLowLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ theorem build_module_type_rename' {e : ExprLow Ident} {f g} :
induction e with
| base map typ =>
simp [drunfold, -AssocList.find?_eq]
sorry
| connect o i e ih =>
dsimp [drunfold, -AssocList.find?_eq]
cases h : build_module' ε e
Expand Down
1 change: 1 addition & 0 deletions DataflowRewriter/Module.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1240,5 +1240,6 @@ def IdentMap.toInterface {Ident} (i : IdentMap Ident (Σ T, Module Ident T))
i.mapVal (λ _ x => x.snd |>.toInterface)

abbrev TModule Ident := Σ T, Module Ident T
abbrev TModule1 Ident := Σ T : Type, Module Ident T

end DataflowRewriter
Loading

0 comments on commit 4b45181

Please sign in to comment.