Skip to content

Commit

Permalink
Add more proofs
Browse files Browse the repository at this point in the history
  • Loading branch information
ymherklotz committed Nov 11, 2024
1 parent 0958ab1 commit 350f975
Show file tree
Hide file tree
Showing 6 changed files with 495 additions and 115 deletions.
16 changes: 16 additions & 0 deletions DataflowRewriter/AssocList/Lemmas.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import DataflowRewriter.AssocList.Basic
import Mathlib.Logic.Function.Basic

namespace Batteries.AssocList

Expand Down Expand Up @@ -224,4 +225,19 @@ theorem disjoint_keys_mapVal_both {α β γ μ η} [DecidableEq α] {a : AssocLi
a.disjoint_keys b → (a.mapVal g).disjoint_keys (b.mapVal f) := by
intros; solve_by_elim [disjoint_keys_mapVal, disjoint_keys_symm]

theorem mapKey_find? {α β γ} [DecidableEq α] [DecidableEq γ] {a : AssocList α β} {f : α → γ} {i} (hinj : Function.Injective f) :
(a.mapKey f).find? (f i) = a.find? i := by
induction a with
| nil => simp
| cons k v xs ih =>
dsimp
by_cases h : f k = f i
· have h' := hinj h; rw [h']; simp
· have h' := hinj.ne_iff.mp h;
rw [Batteries.AssocList.find?.eq_2]
rw [Batteries.AssocList.find?.eq_2]; rw [ih]
have t1 : (f k == f i) = false := by simp [*]
have t2 : (k == i) = false := by simp [*]
rw [t1, t2]

end Batteries.AssocList
4 changes: 2 additions & 2 deletions DataflowRewriter/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ instance set to `top`.
(l : PortMap Ident (Σ T : Type u₂, (S → T → S → Prop)))
(n : InternalPort Ident)
: Σ T : Type u₂, (S → T → S → Prop) :=
l.find? n |>.getD (⟨ PUnit, λ _ _ _ => True ⟩)
l.find? n |>.getD (⟨ PUnit, λ _ _ _ => False ⟩)

theorem getIO_none {S} (m : PortMap Ident ((T : Type) × (S → T → S → Prop)))
(ident : InternalPort Ident) :
m.find? ident = none ->
m.getIO ident = ⟨ PUnit, λ _ _ _ => True ⟩ := by
m.getIO ident = ⟨ PUnit, λ _ _ _ => False ⟩ := by
intros H; simp only [PortMap.getIO, H]; simp

theorem getIO_some {S} (m : PortMap Ident ((T : Type) × (S → T → S → Prop)))
Expand Down
33 changes: 33 additions & 0 deletions DataflowRewriter/ExprLow.lean
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,39 @@ def findBase (typ : Ident) : ExprLow Ident → Option (PortMapping Ident)
| some port => port
| none => e₂.findBase typ

def mapInputPorts (f : InternalPort Ident → InternalPort Ident) : ExprLow Ident → ExprLow Ident
| .base map typ' => .base ⟨map.input.mapVal (λ _ => f), map.output⟩ typ'
| .connect o i e => e.mapInputPorts f |> .connect o (f i)
| .product e₁ e₂ => .product (e₁.mapInputPorts f) (e₂.mapInputPorts f)

def mapOutputPorts (f : InternalPort Ident → InternalPort Ident) : ExprLow Ident → ExprLow Ident
| .base map typ' => .base ⟨map.input, map.output.mapVal (λ _ => f)⟩ typ'
| .connect o i e => e.mapOutputPorts f |> .connect (f o) i
| .product e₁ e₂ => .product (e₁.mapOutputPorts f) (e₂.mapOutputPorts f)

def mapPorts2 (f g : InternalPort Ident → InternalPort Ident) (e : ExprLow Ident) : ExprLow Ident :=
e.mapInputPorts f |>.mapOutputPorts g

def filterId (p : PortMapping Ident) : PortMapping Ident :=
⟨p.input.filter (λ a b => a ≠ b), p.output.filter (λ a b => a ≠ b)⟩

def invertible {α} [DecidableEq α] (p : Batteries.AssocList α α) : Bool :=
p.keysList.inter p.inverse.keysList = ∅ ∧ p.keysList.Nodup ∧ p.inverse.keysList.Nodup

def bijectivePortRenaming (p : PortMap Ident (InternalPort Ident)) (i: InternalPort Ident) : InternalPort Ident :=
let p' := p.inverse
if p.keysList.inter p'.keysList = ∅ && p.keysList.Nodup && p'.keysList.Nodup then
let map := p.append p.inverse
map.find? i |>.getD i
else i

theorem invertibleMap {α} [DecidableEq α] {p : Batteries.AssocList α α} {a b} :
invertible p →
(p.append p.inverse).find? a = some b → (p.append p.inverse).find? b = some a := by sorry

def renamePorts (m : ExprLow Ident) (p : PortMapping Ident) : ExprLow Ident :=
m.mapPorts2 (bijectivePortRenaming p.input) (bijectivePortRenaming p.output)

/--
Assume that the input is currently not mapped.
-/
Expand Down
116 changes: 104 additions & 12 deletions DataflowRewriter/ExprLowLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace ExprLow
variable {Ident}
variable [DecidableEq Ident]

variable (ε : IdentMap Ident (Σ T : Type _, Module Ident T))
variable (ε : IdentMap Ident (Σ T : Type, Module Ident T))

@[drunfold] def get_types (i : Ident) : Type _ :=
(ε.find? i) |>.map Sigma.fst |>.getD PUnit
Expand Down Expand Up @@ -64,11 +64,13 @@ theorem build_moduleD.dep_rewrite {instIdent} : ∀ {modIdent : Ident} {ε a} (H
a
(Eq.refl a)) := by intro a b c d; cases d; rfl

theorem filterId_empty : filterId (Ident := Ident) ∅ = ∅ := by rfl

@[drunfold] def build_module'
: (e : ExprLow Ident) → Option (Σ T, Module Ident T)
| .base i e => do
let mod ← ε.find? e
return ⟨ _, mod.2.renamePorts i
return ⟨ _, mod.2.renamePorts (filterId i)
| .connect o i e' => do
let e ← e'.build_module'
return ⟨ _, e.2.connect' o i ⟩
Expand Down Expand Up @@ -146,7 +148,7 @@ theorem wf_modify_expression {e : ExprLow Ident} {i i'}:

theorem build_base_in_env {T inst i mod} :
ε.find? i = some ⟨ T, mod ⟩ →
build_module' ε (base inst i) = some ⟨ T, mod.renamePorts inst ⟩ := by
build_module' ε (base inst i) = some ⟨ T, mod.renamePorts (filterId inst) ⟩ := by
intro h; dsimp [drunfold]; rw [h]; rfl

theorem wf_replace {e e_pat e'} : wf ε e → wf ε e' → wf ε (e.replace e_pat e') := by
Expand All @@ -160,9 +162,100 @@ theorem wf_abstract {e e_pat a b} : wf ε e → ε.contains b → wf ε (e.abstr

theorem build_module_unfold_1 {m r i} :
ε.find? i = some m →
build_module ε (.base r i) = ⟨ m.fst, m.snd.renamePorts r ⟩ := by
build_module ε (.base r i) = ⟨ m.fst, m.snd.renamePorts (filterId r) ⟩ := by
intro h; simp only [drunfold]; rw [h]; simp

theorem build_module_type_rename' {e : ExprLow Ident} {f g} :
(e.mapPorts2 f g |>.build_module' ε).isSome = (e.build_module' ε).isSome := by
induction e with
| base map typ =>
simp [drunfold, -AssocList.find?_eq]
| connect o i e ih =>
dsimp [drunfold, -AssocList.find?_eq]
cases h : build_module' ε e
· rw [h] at ih; simp [mapPorts2] at ih; simp [ih]
· rw [h] at ih; simp at ih; rw [Option.isSome_iff_exists] at ih; rcases ih with ⟨_, ih⟩
unfold mapPorts2 at *; rw [ih]; rfl
| product e₁ e₂ ihe₁ ihe₂ =>
dsimp [drunfold]
cases h : (build_module' ε e₁)
· rw [h] at ihe₁; simp [mapPorts2] at ihe₁; simp [ihe₁]
· cases h2 : (build_module' ε e₂)
· rw [h2] at ihe₂; simp [mapPorts2] at ihe₂; simp [ihe₂]
· rw [h] at ihe₁; simp at ihe₁; rw [Option.isSome_iff_exists] at ihe₁; rcases ihe₁ with ⟨_, ihe₁⟩
unfold mapPorts2 at *; rw [ihe₁];
rw [h2] at ihe₂; simp at ihe₂; rw [Option.isSome_iff_exists] at ihe₂; rcases ihe₂ with ⟨_, ihe₂⟩
rw [ihe₂]; rfl

theorem build_module_type_rename {e f g} :
([T| e.mapPorts2 f g, ε]) = ([T| e, ε ]) := by
induction e with
| base map typ =>
simp [drunfold, -AssocList.find?_eq]
cases h : (AssocList.find? typ ε) <;> rfl
| connect o i e ie =>
simp [drunfold, -AssocList.find?_eq]
cases h : build_module' ε e
· have : (build_module' ε (mapOutputPorts g (mapInputPorts f e))) = none := by
have := build_module_type_rename' (ε := ε) (e := e) (f := f) (g := g)
rw [h] at this; simp_all [mapPorts2]
rw [this]; rfl
· have := build_module_type_rename' (ε := ε) (e := e) (f := f) (g := g)
rw [h] at this; dsimp at this; rw [Option.isSome_iff_exists] at this
rcases this with ⟨a, this⟩
dsimp [mapPorts2] at this; rw [this]
unfold build_module_type build_module at *
unfold mapPorts2 at *
rw [this] at ie; rw [h] at ie
dsimp at ie; assumption
| product e₁ e₂ he₁ he₂ =>
simp [drunfold, -AssocList.find?_eq]
cases h : build_module' ε e₁
· have : (build_module' ε (mapOutputPorts g (mapInputPorts f e₁))) = none := by
have := build_module_type_rename' (ε := ε) (e := e₁) (f := f) (g := g)
rw [h] at this; simp_all [mapPorts2]
rw [this]; rfl
· have this := build_module_type_rename' (ε := ε) (e := e₁) (f := f) (g := g)
have this2 := build_module_type_rename' (ε := ε) (e := e₂) (f := f) (g := g)
rw [h] at this; dsimp at this; rw [Option.isSome_iff_exists] at this; rcases this with ⟨ a, this ⟩
cases h' : build_module' ε e₂
· have this3 : (build_module' ε (mapOutputPorts g (mapInputPorts f e₂))) = none := by
rw [h'] at this2; simp_all [mapPorts2]
rw [this3]; unfold mapPorts2 at *; rw [this]; rfl
· rw [h'] at this2; dsimp at this2; rw [Option.isSome_iff_exists] at this2
rcases this2 with ⟨a, this2⟩
dsimp [mapPorts2] at this this2; rw [this]
unfold build_module_type build_module at *
unfold mapPorts2 at *
dsimp; rw [this2]; dsimp
rw [h,this] at he₁
rw [h',this2] at he₂
simp at *; congr

def cast_module {S T} (h : S = T): Module Ident S = Module Ident T := by
cases h; rfl

theorem rename_build_module {e : ExprLow Ident} {f g} (h : Function.Bijective f) (h' : Function.Bijective g) :
(([e| e.mapPorts2 f g, ε])) = ((cast_module build_module_type_rename).mpr ([e| e, ε ])).mapPorts2 f g := by
-- induction e with
-- | base map typ =>
-- dsimp [drunfold]
-- cases h : (AssocList.find? typ ε).isSome
-- · simp [-AssocList.find?_eq] at h
-- sorry
-- · sorry
-- | connect o i e ih =>
-- dsimp [drunfold]
sorry

theorem rename_build_module2 {e f g} :
Function.Bijective f → Function.Bijective g →
([e| e, ε ]).mapPorts2 f g ⊑ ([e| e.mapPorts2 f g, ε]) := by sorry

theorem rename_build_module3 {e f g} :
Function.Bijective f → Function.Bijective g →
([e| e.mapPorts2 f g, ε]) ⊑ ([e| e, ε ]).mapPorts2 f g := by sorry

section Refinement

universe v w
Expand Down Expand Up @@ -195,8 +288,7 @@ theorem refines_product {e₁ e₂ e₁' e₂'} :
rw [wf1, wf2, wf3, wf4]
rw [wf1, wf3] at ref1
rw [wf2, wf4] at ref2
sorry
-- solve_by_elim [Module.refines_product]
solve_by_elim [Module.refines_product]

theorem refines_connect {e e' o i} :
wf ε e → wf ε e' →
Expand Down Expand Up @@ -301,7 +393,7 @@ theorem abstract_refines {iexpr expr_pat i} :
subst_vars
simp [drunfold, Option.bind, Option.getD, hb]
rw [hb]; simp
rw [Module.renamePorts_empty]; apply Module.refines_reflexive
rw [filterId_empty,Module.renamePorts_empty]; apply Module.refines_reflexive
· have : (if base inst typ = expr_pat then base ∅ i else base inst typ) = base inst typ := by
simp [h]
rw [this]; clear this
Expand All @@ -317,7 +409,7 @@ theorem abstract_refines {iexpr expr_pat i} :
else (e₁.replace expr_pat (base ∅ i)).product (e₂.replace expr_pat (base ∅ i))) = y
split at H <;> subst y <;> rename_i h
· subst_vars
rw [build_module_unfold_1 hfind, Module.renamePorts_empty]
rw [build_module_unfold_1 hfind, filterId_empty, Module.renamePorts_empty]
apply Module.refines_reflexive
· unfold abstract at ihe₁ ihe₂
have : wf ε (e₁.replace expr_pat (base ∅ i)) := by
Expand All @@ -335,7 +427,7 @@ theorem abstract_refines {iexpr expr_pat i} :
intro hwf
generalize h : (if connect x y e = expr_pat then base ∅ i else connect x y (e.replace expr_pat (base ∅ i))) = y
split at h <;> subst_vars
· rw [build_module_unfold_1 hfind, Module.renamePorts_empty]
· rw [build_module_unfold_1 hfind, filterId_empty, Module.renamePorts_empty]
apply Module.refines_reflexive
· have : wf ε (connect x y (e.replace expr_pat (base ∅ i))) := by
simp [wf, all]
Expand Down Expand Up @@ -366,7 +458,7 @@ theorem abstract_refines2 {iexpr expr_pat i} :
subst_vars
simp [drunfold, Option.bind, Option.getD, hb]
rw [hb]; simp
rw [Module.renamePorts_empty]; apply Module.refines_reflexive
rw [filterId_empty, Module.renamePorts_empty]; apply Module.refines_reflexive
· have : (if base inst typ = expr_pat then base ∅ i else base inst typ) = base inst typ := by
simp [h]
rw [this]; clear this
Expand All @@ -382,7 +474,7 @@ theorem abstract_refines2 {iexpr expr_pat i} :
else (e₁.replace expr_pat (base ∅ i)).product (e₂.replace expr_pat (base ∅ i))) = y
split at H <;> subst y <;> rename_i h
· subst_vars
rw [build_module_unfold_1 hfind, Module.renamePorts_empty]
rw [build_module_unfold_1 hfind, filterId_empty, Module.renamePorts_empty]
apply Module.refines_reflexive
· unfold abstract at ihe₁ ihe₂
have : wf ε (e₁.replace expr_pat (base ∅ i)) := by
Expand All @@ -400,7 +492,7 @@ theorem abstract_refines2 {iexpr expr_pat i} :
intro hwf
generalize h : (if connect x y e = expr_pat then base ∅ i else connect x y (e.replace expr_pat (base ∅ i))) = y
split at h <;> subst_vars
· rw [build_module_unfold_1 hfind, Module.renamePorts_empty]
· rw [build_module_unfold_1 hfind, filterId_empty, Module.renamePorts_empty]
apply Module.refines_reflexive
· have : wf ε (connect x y (e.replace expr_pat (base ∅ i))) := by
simp [wf, all]
Expand Down
Loading

0 comments on commit 350f975

Please sign in to comment.