Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: FunInd: erase, not clear #6923

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 36 additions & 40 deletions src/Lean/Meta/Tactic/FunInd.lean
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,11 @@ namespace Lean.Tactic.FunInd

open Lean Elab Meta

/-- Opens the body of a lambda, _without_ putting the free variable into the local context.
This is used when replacing parameters with different expressions.
This way it will not be picked up by metavariables.
-/
def removeLamda {n} [MonadLiftT MetaM n] [MonadError n] [MonadNameGenerator n] [Monad n] {α} (e : Expr) (k : FVarId → Expr → n α) : n α := do
let .lam _n _d b _bi := ← whnfD e | throwError "removeLamda: expected lambda, got {e}"
let x ← mkFreshFVarId
let b := b.instantiate1 (.fvar x)
k x b

def lambdaTelescope1 {n} [MonadControlT MetaM n] [MonadError n] [MonadNameGenerator n] [Monad n] {α} (e : Expr) (k : FVarId → Expr → n α) : n α := do
lambdaBoundedTelescope e 1 fun xs body => do
unless xs.size == 1 do
throwError "lambdaTelescope1: expected lambda, got {e}"
k xs[0]!.fvarId! body

/--
A monad to help collecting inductive hypothesis.
Expand Down Expand Up @@ -294,7 +289,7 @@ partial def foldAndCollect (oldIH newIH : FVarId) (isRecCall : Expr → Option E
let dummyGoal := mkConst ``True []
mkArrow eTypeAbst dummyGoal)
(onAlt := fun altType alt => do
removeLamda alt fun oldIH' alt => do
lambdaTelescope1 alt fun oldIH' alt => do
forallBoundedTelescope altType (some 1) fun newIH' _goal' => do
let #[newIH'] := newIH' | unreachable!
let altIHs ← M.exec <| foldAndCollect oldIH' newIH'.fvarId! isRecCall alt
Expand All @@ -311,7 +306,7 @@ partial def foldAndCollect (oldIH newIH : FVarId) (isRecCall : Expr → Option E
let some (_extra, body) := motiveBody.arrow? | throwError "motive not an arrow"
M.eval (foldAndCollect oldIH newIH isRecCall body))
(onAlt := fun _altType alt => do
removeLamda alt fun oldIH alt => do
lambdaTelescope1 alt fun oldIH alt => do
M.eval (foldAndCollect oldIH newIH isRecCall alt))
(onRemaining := fun _ => pure #[])
return matcherApp'.toExpr
Expand Down Expand Up @@ -461,22 +456,23 @@ def M2.branch {α} (act : M2 α) : M2 α :=


/-- Base case of `buildInductionBody`: Construct a case for the final induction hypthesis. -/
def buildInductionCase (oldIH newIH : FVarId) (isRecCall : Expr → Option Expr) (toClear : Array FVarId)
def buildInductionCase (oldIH newIH : FVarId) (isRecCall : Expr → Option Expr) (toErase : Array FVarId)
(goal : Expr) (e : Expr) : M2 Expr := do
withTraceNode `Meta.FunInd (pure m!"{exceptEmoji ·} buildInductionCase:{indentExpr e}") do
let _e' ← foldAndCollect oldIH newIH isRecCall e
let IHs : Array Expr ← M.ask
let IHs ← deduplicateIHs IHs

let mvar ← mkFreshExprSyntheticOpaqueMVar goal (tag := `hyp)
let mut mvarId := mvar.mvarId!
mvarId ← assertIHs IHs mvarId
trace[Meta.FunInd] "Goal before cleanup:{mvarId}"
for fvarId in toClear do
mvarId ← mvarId.clear fvarId
modify (·.push mvarId)
let mvar ← instantiateMVars mvar
pure mvar
withErasedFVars toErase do
let mvar ← mkFreshExprSyntheticOpaqueMVar goal (tag := `hyp)
let mut mvarId := mvar.mvarId!
mvarId ← assertIHs IHs mvarId
trace[Meta.FunInd] "Goal before cleanup:{mvarId}"
-- for fvarId in toErase do
-- mvarId ← mvarId.clear fvarId
modify (·.push mvarId)
let mvar ← instantiateMVars mvar
pure mvar

/--
Like `mkLambdaFVars (usedOnly := true)`, but
Expand Down Expand Up @@ -518,7 +514,7 @@ Builds an expression of type `goal` by replicating the expression `e` into its t
where it calls `buildInductionCase`. Collects the cases of the final induction hypothesis
as `MVars` as it goes.
-/
partial def buildInductionBody (toClear : Array FVarId) (goal : Expr)
partial def buildInductionBody (toErase : Array FVarId) (goal : Expr)
(oldIH newIH : FVarId) (isRecCall : Expr → Option Expr) (e : Expr) : M2 Expr := do
withTraceNode `Meta.FunInd
(pure m!"{exceptEmoji ·} buildInductionBody: {oldIH.name} → {newIH.name}:{indentExpr e}") do
Expand All @@ -529,10 +525,10 @@ partial def buildInductionBody (toClear : Array FVarId) (goal : Expr)
let c' ← foldAndCollect oldIH newIH isRecCall c
let h' ← foldAndCollect oldIH newIH isRecCall h
let t' ← withLocalDecl `h .default c' fun h => M2.branch do
let t' ← buildInductionBody toClear goal oldIH newIH isRecCall t
let t' ← buildInductionBody toErase goal oldIH newIH isRecCall t
mkLambdaFVars #[h] t'
let f' ← withLocalDecl `h .default (mkNot c') fun h => M2.branch do
let f' ← buildInductionBody toClear goal oldIH newIH isRecCall f
let f' ← buildInductionBody toErase goal oldIH newIH isRecCall f
mkLambdaFVars #[h] f'
let u ← getLevel goal
return mkApp5 (mkConst ``dite [u]) goal c' h' t' f'
Expand All @@ -541,11 +537,11 @@ partial def buildInductionBody (toClear : Array FVarId) (goal : Expr)
let h' ← foldAndCollect oldIH newIH isRecCall h
let t' ← withLocalDecl `h .default c' fun h => M2.branch do
let t ← instantiateLambda t #[h]
let t' ← buildInductionBody toClear goal oldIH newIH isRecCall t
let t' ← buildInductionBody toErase goal oldIH newIH isRecCall t
mkLambdaFVars #[h] t'
let f' ← withLocalDecl `h .default (mkNot c') fun h => M2.branch do
let f ← instantiateLambda f #[h]
let f' ← buildInductionBody toClear goal oldIH newIH isRecCall f
let f' ← buildInductionBody toErase goal oldIH newIH isRecCall f
mkLambdaFVars #[h] f'
let u ← getLevel goal
return mkApp5 (mkConst ``dite [u]) goal c' h' t' f'
Expand All @@ -556,8 +552,8 @@ partial def buildInductionBody (toClear : Array FVarId) (goal : Expr)
match_expr goal with
| And goal₁ goal₂ => match_expr e with
| PProd.mk _α _β e₁ e₂ =>
let e₁' ← buildInductionBody toClear goal₁ oldIH newIH isRecCall e₁
let e₂' ← buildInductionBody toClear goal₂ oldIH newIH isRecCall e₂
let e₁' ← buildInductionBody toErase goal₁ oldIH newIH isRecCall e₁
let e₂' ← buildInductionBody toErase goal₂ oldIH newIH isRecCall e₂
return mkApp4 (.const ``And.intro []) goal₁ goal₂ e₁' e₂'
| _ =>
throwError "Goal is PProd, but expression is:{indentExpr e}"
Expand All @@ -580,10 +576,10 @@ partial def buildInductionBody (toClear : Array FVarId) (goal : Expr)
(onParams := (foldAndCollect oldIH newIH isRecCall ·))
(onMotive := fun xs _body => pure (absMotiveBody.beta (maskArray mask xs)))
(onAlt := fun expAltType alt => M2.branch do
removeLamda alt fun oldIH' alt => do
lambdaTelescope1 alt fun oldIH' alt => do
forallBoundedTelescope expAltType (some 1) fun newIH' goal' => do
let #[newIH'] := newIH' | unreachable!
let alt' ← buildInductionBody (toClear.push newIH'.fvarId!) goal' oldIH' newIH'.fvarId! isRecCall alt
let alt' ← buildInductionBody (toErase ++ #[oldIH', newIH'.fvarId!]) goal' oldIH' newIH'.fvarId! isRecCall alt
mkLambdaFVars #[newIH'] alt')
(onRemaining := fun _ => pure #[.fvar newIH])
return matcherApp'.toExpr
Expand All @@ -599,29 +595,29 @@ partial def buildInductionBody (toClear : Array FVarId) (goal : Expr)
(onParams := (foldAndCollect oldIH newIH isRecCall ·))
(onMotive := fun xs _body => pure (absMotiveBody.beta (maskArray mask xs)))
(onAlt := fun expAltType alt => M2.branch do
buildInductionBody toClear expAltType oldIH newIH isRecCall alt)
buildInductionBody toErase expAltType oldIH newIH isRecCall alt)
return matcherApp'.toExpr

-- we look through mdata
if e.isMData then
let b ← buildInductionBody toClear goal oldIH newIH isRecCall e.mdataExpr!
let b ← buildInductionBody toErase goal oldIH newIH isRecCall e.mdataExpr!
return e.updateMData! b

if let .letE n t v b _ := e then
let t' ← foldAndCollect oldIH newIH isRecCall t
let v' ← foldAndCollect oldIH newIH isRecCall v
return ← withLetDecl n t' v' fun x => M2.branch do
let b' ← buildInductionBody toClear goal oldIH newIH isRecCall (b.instantiate1 x)
let b' ← buildInductionBody toErase goal oldIH newIH isRecCall (b.instantiate1 x)
mkLetFVars #[x] b'

if let some (n, t, v, b) := e.letFun? then
let t' ← foldAndCollect oldIH newIH isRecCall t
let v' ← foldAndCollect oldIH newIH isRecCall v
return ← withLocalDeclD n t' fun x => M2.branch do
let b' ← buildInductionBody toClear goal oldIH newIH isRecCall (b.instantiate1 x)
let b' ← buildInductionBody toErase goal oldIH newIH isRecCall (b.instantiate1 x)
mkLetFun x v' b'

liftM <| buildInductionCase oldIH newIH isRecCall toClear goal e
liftM <| buildInductionCase oldIH newIH isRecCall toErase goal e

/--
Given an expression `e` with metavariables `mvars`
Expand Down Expand Up @@ -708,9 +704,9 @@ def deriveUnaryInduction (name : Name) : MetaM Name := do
let extraParams := xs[2:]
-- open body with the same arg
let body ← instantiateLambda body targets
removeLamda body fun oldIH body => do
lambdaTelescope1 body fun oldIH body => do
let body ← instantiateLambda body extraParams
let body' ← buildInductionBody #[genIH.fvarId!] goal oldIH genIH.fvarId! isRecCall body
let body' ← buildInductionBody #[oldIH, genIH.fvarId!] goal oldIH genIH.fvarId! isRecCall body
if body'.containsFVar oldIH then
throwError m!"Did not fully eliminate {mkFVar oldIH} from induction principle body:{indentExpr body}"
mkLambdaFVars (targets.push genIH) (← mkLambdaFVars extraParams body')
Expand Down Expand Up @@ -1013,10 +1009,10 @@ def deriveInductionStructural (names : Array Name) (numFixed : Nat) : MetaM Unit
let extraParams := xs[numTargets+1:]
-- open body with the same arg
let body ← instantiateLambda brecOnMinor targets
removeLamda body fun oldIH body => do
lambdaTelescope1 body fun oldIH body => do
trace[Meta.FunInd] "replacing {Expr.fvar oldIH} with {genIH}"
let body ← instantiateLambda body extraParams
let body' ← buildInductionBody #[genIH.fvarId!] goal oldIH genIH.fvarId! isRecCall body
let body' ← buildInductionBody #[oldIH, genIH.fvarId!] goal oldIH genIH.fvarId! isRecCall body
if body'.containsFVar oldIH then
throwError m!"Did not fully eliminate {mkFVar oldIH} from induction principle body:{indentExpr body}"
mkLambdaFVars (targets.push genIH) (← mkLambdaFVars extraParams body')
Expand Down
Loading