Skip to content

Commit

Permalink
Make progress on supporting higher-order divergent functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sonmarcho committed Dec 12, 2023
1 parent c23f317 commit dba35a4
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 53 deletions.
22 changes: 12 additions & 10 deletions backends/lean/Base/Diverge/Base.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1467,26 +1467,28 @@ namespace Ex8
let tl ← map f tl
.ret (hd :: tl)

/- The validity theorem for `map`, generic in `f` -/
/- The validity theorems for `map`, generic in `f` -/

-- This is not the most general lemma, but we keep it to test the `divergence` encoding on a simple case
@[divspec]
theorem map_is_valid
theorem map_is_valid_simple
(i : id) (t : ty i)
{{f : (a i t → Result (b i t)) → (a i t) → Result c}}
(Hfvalid : ∀ k x, is_valid_p k (λ k => f (k i t) x))
(k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t))
(ls : List (a i t)) :
is_valid_p k (λ k => map (f (k i t)) ls) := by
is_valid_p k (λ k => map (k i t) ls) := by
induction ls <;> simp [map]
apply is_valid_p_bind <;> try simp_all
intros
apply is_valid_p_bind <;> try simp_all

@[divspec]
theorem map_is_valid'
(i : id) (t : ty i)
theorem map_is_valid
(d : Type y)
{{f : ((i:id) → (t : ty i) → a i t → Result (b i t)) → d → Result c}}
(k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t))
(ls : List (a i t)) :
is_valid_p k (λ k => map (k i t) ls) := by
(Hfvalid : ∀ x1, is_valid_p k (fun kk1 => f kk1 x1))
(ls : List d) :
is_valid_p k (λ k => map (f k) ls) := by
induction ls <;> simp [map]
apply is_valid_p_bind <;> try simp_all
intros
Expand Down Expand Up @@ -1532,7 +1534,7 @@ namespace Ex9
apply is_valid_p_bind <;> try simp [*]
-- We have to show that `map k tl` is valid
-- Remark: `map_is_valid` doesn't work here, we need the specialized version
apply map_is_valid'
apply map_is_valid_simple

def body (k : (i : Fin 1) → (t : ty i) → (x : input_ty i t) → Result (output_ty i t)) (i: Fin 1) :
(t : ty i) → (x : input_ty i t) → Result (output_ty i t) := get_fun bodies i k
Expand Down
126 changes: 83 additions & 43 deletions backends/lean/Base/Diverge/Elab.lean
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,6 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
trace[Diverge.def] "individual body of {preDef.declName}: {body}"
-- Return the constant
let body := Lean.mkConst name (levelParams.map .param)
-- let body ← mkAppM' body #[kk_var]
trace[Diverge.def] "individual body (after decl): {body}"
pure body

Expand Down Expand Up @@ -665,7 +664,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
proveAppIsValid k_var kk_var e f args

partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Array Expr): MetaM Expr := do
trace[Diverge.def.valid] "proveAppIsValid: {f} {args}"
trace[Diverge.def.valid] "proveAppIsValid: {e}\nDecomposed: {f} {args}"
/- There are several cases: first, check if this is a match/if
Check if the expression is a (dependent) if then else.
We treat the if then else expressions differently from the other matches,
Expand Down Expand Up @@ -821,7 +820,8 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args :
- if no: this is simple
- if yes: we have to lookup theorems in div spec database and continue -/
trace[Diverge.def.valid] "regular app: {e}, f: {f}, args: {args}"
let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
let argsFVars ← args.mapM getFVarIds
let allArgsFVars := argsFVars.foldl (fun hs fvars => hs.insertMany fvars) HashSet.empty
trace[Diverge.def.valid] "allArgsFVars: {allArgsFVars.toList.map mkFVar}"
if ¬ allArgsFVars.contains kk_var.fvarId! then do
-- Simple case
Expand All @@ -837,7 +837,6 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args :

partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr)
(f : Expr) (args : Array Expr) (thms : List Name) : MetaM Expr := do
trace[Diverge.def.valid] "thms: {thms}"
match thms with
| [] => throwError "Could not prove that the following expression is valid: {e}"
| thName :: thms =>
Expand All @@ -849,58 +848,67 @@ partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr)
thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar))
let ulMap : HashMap Name Level := HashMap.ofList ul
let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x)
trace[Diverge.def.valid] "Trying with theorem {thName}: {thTy}"
-- Introduce meta variables for the universally quantified variables
let (mvars, _binders, thTy) ← forallMetaTelescope thTy
-- thTy should now be of the shape: `is_valid_p k (λ kk => ...)`
/-thTy.consumeMData.withApp fun _ args => do
if args.size ≠ 7 then throwError "Invalid number of arguments (expected 7): {thTy}"
let thTermToMatch := args.get! 6 -/
let thTermToMatch := thTy
let (mvars, _binders, thTyBody) ← forallMetaTelescope thTy
let thTermToMatch := thTyBody
trace[Diverge.def.valid] "thTermToMatch: {thTermToMatch}"
-- Create the term: `is_valid_p k (λ kk => e)`
let termToMatch ← mkLambdaFVars #[kk_var] e
let termToMatch ← mkAppM ``FixII.is_valid_p #[k_var, termToMatch]
trace[Diverge.def.valid] "termToMatch: {termToMatch}"
-- Attempt to match
let ok ← isDefEq thTermToMatch termToMatch
trace[Diverge.def.valid] "Matching terms:\n\n{termToMatch}\n\n{thTermToMatch}"
let ok ← isDefEq termToMatch thTermToMatch
if ¬ ok then
-- Failure: attempt with the other theorems
proveAppIsValidApplyThms k_var kk_var e f args thms
else do
-- Success: continue with this theorem
-- Instantiate the meta variables (some of them will not be instantiated:
-- they are new subgoals)
/- Success: continue with this theorem
Instantiate the meta variables (some of them will not be instantiated:
they are new subgoals)
-/
let mvars ← mvars.mapM instantiateMVars
let th ← mkAppOptM thName (Array.map some mvars)
trace[Diverge.def.valid] "Instantiated theorm: {th}\n{← inferType th}"
-- Filter the meta variables between the instantiated ones
for mvar in mvars do
if mvar.isMVar then do
-- Prove the subgoal (i.e., the precondition of the theorem)
let mvarId := mvar.mvarId!
let mvarDecl ← mvarId.getDecl
-- Dive in the type
forallTelescope mvarDecl.type fun forall_vars mvar_e => do
-- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)`
-- We need to retrieve the new `k` variable, and dive into the
-- `λ kk => ...`
mvar_e.consumeMData.withApp fun is_valid args => do
if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 2 then
throwError "Invalid precondition: {mvar_e}"
else do
let k_var := args.get! 0
let e_lam := args.get! 1
lambdaTelescope e_lam.consumeMData fun lvars e => do
if lvars.size ≠ 1 then throwError "Invalid number of lambdas (expected 1): {e_lam}"
let kk_var := lvars.get! 0
-- Continue
let e_valid ← proveExprIsValid k_var kk_var e
let e_valid ← mkForallFVars forall_vars e_valid
-- Assign the meta variable
mvarId.assign e_valid
else
-- Nothing to do
pure ()
trace[Diverge.def.valid] "Instantiated theorem: {th}\n{← inferType th}"
-- Filter the instantiated meta variables
let mvars := mvars.filter (fun v => v.isMVar)
let mvars := mvars.map (fun v => v.mvarId!)
trace[Diverge.def.valid] "Remaining subgoals: {mvars}"
for mvarId in mvars do
-- Prove the subgoal (i.e., the precondition of the theorem)
let mvarDecl ← mvarId.getDecl
let declType ← instantiateMVars mvarDecl.type
-- Reduce the subgoal before diving in, if necessary
trace[Diverge.def.valid] "Subgoal: {declType}"
-- Dive in the type
forallTelescope declType fun forall_vars mvar_e => do
trace[Diverge.def.valid] "forall_vars: {forall_vars}"
-- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)`
-- We need to retrieve the new `k` variable, and dive into the
-- `λ kk => ...`
mvar_e.consumeMData.withApp fun is_valid args => do
if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 7 then
throwError "Invalid precondition: {mvar_e}"
else do
let k_var := args.get! 5
let e_lam := args.get! 6
trace[Diverge.def.valid] "k_var: {k_var}\ne_lam: {e_lam}"
-- The outer lambda should be for the kk_var
lambdaOne e_lam.consumeMData fun kk_var e => do
-- Continue
trace[Diverge.def.valid] "kk_var: {kk_var}\ne: {e}"
-- We sometimes need to reduce the term
let e ← whnf e
trace[Diverge.def.valid] "e (after reduction): {e}"
let e_valid ← proveExprIsValid k_var kk_var e
trace[Diverge.def.valid] "e_valid (for e): {e_valid}"
let e_valid ← mkLambdaFVars forall_vars e_valid
trace[Diverge.def.valid] "e_valid (with foralls): {e_valid}"
let _ ← inferType e_valid -- Sanity check
-- Assign the meta variable
mvarId.assign e_valid
pure th

-- Prove that a match expression is valid.
Expand Down Expand Up @@ -1442,6 +1450,38 @@ namespace Tests

#check id.unfold

-- set_option pp.explicit true
-- set_option trace.Diverge.def true
-- set_option trace.Diverge.def.genBody true
set_option trace.Diverge.def.valid true
divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) :=
match t with
| .leaf x => .ret (.leaf x)
| .node tl =>
do
let tl ← map (fun x => id1 x) tl
.ret (.node tl)

#check id1.unfold

/-set_option trace.Diverge.def false
-- set_option pp.explicit true
-- set_option trace.Diverge.def true
-- set_option trace.Diverge.def.genBody true
set_option trace.Diverge.def.valid true
divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) :=
match t with
| .leaf x => .ret (.leaf x)
| .node tl =>
do
let tl ← map (fun x => do let _ ← id2 x; id2 x) tl
.ret (.node tl)
#check id2.unfold
set_option trace.Diverge.def false -/

/-set_option trace.Diverge.def true
-- set_option trace.Diverge.def.genBody true
set_option trace.Diverge.def.valid true
Expand Down

0 comments on commit dba35a4

Please sign in to comment.