diff --git a/src/Lean/Meta/Basic.lean b/src/Lean/Meta/Basic.lean index f5f9c79ad10b..05e33b47804a 100644 --- a/src/Lean/Meta/Basic.lean +++ b/src/Lean/Meta/Basic.lean @@ -129,7 +129,7 @@ structure ParamInfo where hasFwdDeps : Bool := false /-- `backDeps` contains the backwards dependencies. That is, the (0-indexed) position of previous parameters that this one depends on. -/ backDeps : Array Nat := #[] - /-- `isProp` is true if the parameter is always a proposition. -/ + /-- `isProp` is true if the parameter type is always a proposition. -/ isProp : Bool := false /-- `isDecInst` is true if the parameter's type is of the form `Decidable ...`. diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index c948bd2edb88..8587c3b70ccf 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -108,13 +108,19 @@ where trace[Meta.Tactic.simp.discharge] "{← ppOrigin thmId}, failed to synthesize instance{indentExpr type}" return false +private def useImplicitDefEqProof (thm : SimpTheorem) : SimpM Bool := do + if thm.rfl then + return (← getConfig).implicitDefEqProofs + else + return false + private def tryTheoremCore (lhs : Expr) (xs : Array Expr) (bis : Array BinderInfo) (val : Expr) (type : Expr) (e : Expr) (thm : SimpTheorem) (numExtraArgs : Nat) : SimpM (Option Result) := do recordTriedSimpTheorem thm.origin let rec go (e : Expr) : SimpM (Option Result) := do if (← isDefEq lhs e) then unless (← synthesizeArgs thm.origin bis xs) do return none - let proof? ← if thm.rfl then + let proof? ← if (← useImplicitDefEqProof thm) then pure none else let proof ← instantiateMVars (mkAppN val xs) diff --git a/src/Lean/Meta/Tactic/Split.lean b/src/Lean/Meta/Tactic/Split.lean index bd20c779acfd..650a9a1014ce 100644 --- a/src/Lean/Meta/Tactic/Split.lean +++ b/src/Lean/Meta/Tactic/Split.lean @@ -270,7 +270,7 @@ def mkDiscrGenErrorMsg (e : Expr) : MessageData := def throwDiscrGenError (e : Expr) : MetaM α := throwError (mkDiscrGenErrorMsg e) -def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do +def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do let some app ← matchMatcherApp? e | throwError "internal error in `split` tactic: match application expected{indentExpr e}\nthis error typically occurs when the `split` tactic internal functions have been used in a new meta-program" let matchEqns ← Match.getEquationsFor app.matcherName let mvarIds ← applyMatchSplitter mvarId app.matcherName app.matcherLevels app.params app.discrs @@ -279,43 +279,14 @@ def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do return (i+1, mvarId::mvarIds) return mvarIds.reverse -/-- Return an `if-then-else` or `match-expr` to split. -/ -partial def findSplit? (env : Environment) (e : Expr) (splitIte := true) (exceptionSet : ExprSet := {}) : Option Expr := - go e -where - go (e : Expr) : Option Expr := - if let some target := e.find? isCandidate then - if e.isIte || e.isDIte then - let cond := target.getArg! 1 5 - -- Try to find a nested `if` in `cond` - go cond |>.getD target - else - some target - else - none - - isCandidate (e : Expr) : Bool := Id.run do - if exceptionSet.contains e then - false - else if splitIte && (e.isIte || e.isDIte) then - !(e.getArg! 1 5).hasLooseBVars - else if let some info := isMatcherAppCore? env e then - let args := e.getAppArgs - for i in [info.getFirstDiscrPos : info.getFirstDiscrPos + info.numDiscrs] do - if args[i]!.hasLooseBVars then - return false - return true - else - false - end Split open Split -partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List MVarId)) := commitWhenSome? do +partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List MVarId)) := commitWhenSome? do mvarId.withContext do let target ← instantiateMVars (← mvarId.getType) let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do - if let some e := findSplit? (← getEnv) target splitIte badCases then + if let some e ← findSplit? target (if splitIte then .both else .match) badCases then if e.isIte || e.isDIte then return (← splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId] else @@ -334,7 +305,7 @@ partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (L def splitLocalDecl? (mvarId : MVarId) (fvarId : FVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do mvarId.withContext do - if let some e := findSplit? (← getEnv) (← instantiateMVars (← inferType (mkFVar fvarId))) then + if let some e ← findSplit? (← instantiateMVars (← inferType (mkFVar fvarId))) then if e.isIte || e.isDIte then return (← splitIfLocalDecl? mvarId fvarId).map fun (mvarId₁, mvarId₂) => [mvarId₁, mvarId₂] else diff --git a/src/Lean/Meta/Tactic/SplitIf.lean b/src/Lean/Meta/Tactic/SplitIf.lean index d6a21e31495f..d148dfa8a003 100644 --- a/src/Lean/Meta/Tactic/SplitIf.lean +++ b/src/Lean/Meta/Tactic/SplitIf.lean @@ -8,6 +8,110 @@ import Lean.Meta.Tactic.Cases import Lean.Meta.Tactic.Simp.Main namespace Lean.Meta + +inductive SplitKind where + | ite | match | both + +def SplitKind.considerIte : SplitKind → Bool + | .ite | .both => true + | _ => false + +def SplitKind.considerMatch : SplitKind → Bool + | .match | .both => true + | _ => false + +namespace FindSplitImpl + +structure Context where + exceptionSet : ExprSet := {} + kind : SplitKind := .both + +unsafe abbrev FindM := ReaderT Context $ StateT (PtrSet Expr) MetaM + +private def isCandidate (env : Environment) (ctx : Context) (e : Expr) : Bool := Id.run do + if ctx.exceptionSet.contains e then + return false + if ctx.kind.considerIte && (e.isIte || e.isDIte) then + return !(e.getArg! 1 5).hasLooseBVars + if ctx.kind.considerMatch then + if let some info := isMatcherAppCore? env e then + let args := e.getAppArgs + for i in [info.getFirstDiscrPos : info.getFirstDiscrPos + info.numDiscrs] do + if args[i]!.hasLooseBVars then + return false + return true + return false + +@[inline] unsafe def checkVisited (e : Expr) : OptionT FindM Unit := do + if (← get).contains e then + failure + modify fun s => s.insert e + +unsafe def visit (e : Expr) : OptionT FindM Expr := do + checkVisited e + if isCandidate (← getEnv) (← read) e then + return e + else + -- We do not look for split candidates in proofs. + unless e.hasLooseBVars do + if (← isProof e) then + failure + match e with + | .lam _ _ b _ | .proj _ _ b -- We do not look for split candidates in the binder of lambdas. + | .mdata _ b => visit b + | .forallE _ d b _ => visit d <|> visit b -- We want to look for candidates at `A → B` + | .letE _ _ v b _ => visit v <|> visit b + | .app .. => visitApp? e + | _ => failure +where + visitApp? (e : Expr) : FindM (Option Expr) := + e.withApp fun f args => do + let info ← getFunInfo f + for u : i in [0:args.size] do + let arg := args[i] + if h : i < info.paramInfo.size then + let info := info.paramInfo[i] + unless info.isProp do + if info.isExplicit then + let some found ← visit arg | pure () + return found + else + let some found ← visit arg | pure () + return found + visit f + +end FindSplitImpl + +/-- Return an `if-then-else` or `match-expr` to split. -/ +partial def findSplit? (e : Expr) (kind : SplitKind := .both) (exceptionSet : ExprSet := {}) : MetaM (Option Expr) := do + go (← instantiateMVars e) +where + go (e : Expr) : MetaM (Option Expr) := do + if let some target ← find? e then + if target.isIte || target.isDIte then + let cond := target.getArg! 1 5 + -- Try to find a nested `if` in `cond` + return (← go cond).getD target + else + return some target + else + return none + + find? (e : Expr) : MetaM (Option Expr) := do + let some candidate ← unsafe FindSplitImpl.visit e { kind, exceptionSet } |>.run' mkPtrSet + | return none + trace[split.debug] "candidate:{indentExpr candidate}" + return some candidate + +/-- Return the condition and decidable instance of an `if` expression to case split. -/ +private partial def findIfToSplit? (e : Expr) : MetaM (Option (Expr × Expr)) := do + if let some iteApp ← findSplit? e .ite then + let cond := iteApp.getArg! 1 5 + let dec := iteApp.getArg! 2 5 + return (cond, dec) + else + return none + namespace SplitIf /-- @@ -63,19 +167,9 @@ private def discharge? (numIndices : Nat) (useDecide : Bool) : Simp.Discharge := def mkDischarge? (useDecide := false) : MetaM Simp.Discharge := return discharge? (← getLCtx).numIndices useDecide -/-- Return the condition and decidable instance of an `if` expression to case split. -/ -private partial def findIfToSplit? (e : Expr) : Option (Expr × Expr) := - if let some iteApp := e.find? fun e => (e.isIte || e.isDIte) && !(e.getArg! 1 5).hasLooseBVars then - let cond := iteApp.getArg! 1 5 - let dec := iteApp.getArg! 2 5 - -- Try to find a nested `if` in `cond` - findIfToSplit? cond |>.getD (cond, dec) - else - none - -def splitIfAt? (mvarId : MVarId) (e : Expr) (hName? : Option Name) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := do +def splitIfAt? (mvarId : MVarId) (e : Expr) (hName? : Option Name) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := mvarId.withContext do let e ← instantiateMVars e - if let some (cond, decInst) := findIfToSplit? e then + if let some (cond, decInst) ← findIfToSplit? e then let hName ← match hName? with | none => mkFreshUserName `h | some hName => pure hName @@ -107,6 +201,7 @@ def splitIfTarget? (mvarId : MVarId) (hName? : Option Name := none) : MetaM (Opt let mvarId₁ ← simpIfTarget s₁.mvarId let mvarId₂ ← simpIfTarget s₂.mvarId if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then + trace[split.failure] "`split` tactic failed to simplify target using new hypotheses Goals:\n{mvarId₁}\n{mvarId₂}" return none else return some ({ s₁ with mvarId := mvarId₁ }, { s₂ with mvarId := mvarId₂ }) @@ -119,6 +214,7 @@ def splitIfLocalDecl? (mvarId : MVarId) (fvarId : FVarId) (hName? : Option Name let mvarId₁ ← simpIfLocalDecl s₁.mvarId fvarId let mvarId₂ ← simpIfLocalDecl s₂.mvarId fvarId if s₁.mvarId == mvarId₁ && s₂.mvarId == mvarId₂ then + trace[split.failure] "`split` tactic failed to simplify target using new hypotheses Goals:\n{mvarId₁}\n{mvarId₂}" return none else return some (mvarId₁, mvarId₂) diff --git a/tests/compiler/uset.lean b/tests/compiler/uset.lean index 9562cd710501..d83389f50707 100644 --- a/tests/compiler/uset.lean +++ b/tests/compiler/uset.lean @@ -7,4 +7,3 @@ structure Point where def main : IO Unit := IO.println (Point.right ⟨0, 0⟩).x - diff --git a/tests/lean/1113.lean b/tests/lean/1113.lean index f4036fd1e3a3..0c4035e85a44 100644 --- a/tests/lean/1113.lean +++ b/tests/lean/1113.lean @@ -4,7 +4,7 @@ def foo: {n: Nat} → Fin n → Nat theorem t3 {f: Fin (n+1)}: foo f = 0 := by - simp only [←Nat.succ_eq_add_one n] at f + dsimp only [←Nat.succ_eq_add_one n] at f -- use `dsimp` to ensure we don't copy `f` trace_state simp only [←Nat.succ_eq_add_one n, foo] diff --git a/tests/lean/rfl_simp_thm.lean b/tests/lean/rfl_simp_thm.lean index 5f84fab634a4..88e3f12ca0e3 100644 --- a/tests/lean/rfl_simp_thm.lean +++ b/tests/lean/rfl_simp_thm.lean @@ -3,6 +3,6 @@ def inc (x : Nat) := x + 1 @[simp] theorem inc_eq : inc x = x + 1 := rfl theorem ex (a b : Fin (inc n)) (h : a = b) : b = a := by - simp only [inc_eq] at a + simp (config := { implicitDefEqProofs := true }) only [inc_eq] at a trace_state exact h.symm diff --git a/tests/lean/run/implicitRflProofs.lean b/tests/lean/run/implicitRflProofs.lean new file mode 100644 index 000000000000..fcf5acb04d49 --- /dev/null +++ b/tests/lean/run/implicitRflProofs.lean @@ -0,0 +1,25 @@ +def f (x : Nat) := x + 1 + +theorem f_eq (x : Nat) : f (x + 1) = x + 2 := rfl + +theorem ex1 : f (f (x + 1)) = x + 3 := by + simp (config := { implicitDefEqProofs := false }) [f_eq] + +/-- +info: theorem ex1 : ∀ {x : Nat}, f (f (x + 1)) = x + 3 := +fun {x} => + of_eq_true + (Eq.trans (congrArg (fun x_1 => x_1 = x + 3) (Eq.trans (congrArg f (f_eq x)) (f_eq (x + 1)))) (eq_self (x + 1 + 2))) +-/ +#guard_msgs in +#print ex1 + +theorem ex2 : f (f (x + 1)) = x + 3 := by + simp (config := { implicitDefEqProofs := true }) [f_eq] + +/-- +info: theorem ex2 : ∀ {x : Nat}, f (f (x + 1)) = x + 3 := +fun {x} => of_eq_true (eq_self (x + 1 + 2)) +-/ +#guard_msgs in +#print ex2 diff --git a/tests/lean/run/simp2.lean b/tests/lean/run/simp2.lean index 89f82bdda336..3d25ff858c85 100644 --- a/tests/lean/run/simp2.lean +++ b/tests/lean/run/simp2.lean @@ -4,7 +4,7 @@ def p (x : Prop) := x rfl theorem ex1 (x : Prop) (h : x) : p x := by - simp + simp (config := { implicitDefEqProofs := true }) assumption /-- @@ -14,6 +14,17 @@ fun x h => id h #guard_msgs in #print ex1 +theorem ex1' (x : Prop) (h : x) : p x := by + simp (config := { implicitDefEqProofs := false }) + assumption + +/-- +info: theorem ex1' : ∀ (x : Prop), x → p x := +fun x h => Eq.mpr (id (lemma1 x)) h +-/ +#guard_msgs in +#print ex1' + theorem ex2 (x : Prop) (q : Prop → Prop) (h₁ : x) (h₂ : q x = x) : q x := by simp [h₂] assumption diff --git a/tests/lean/run/simp5.lean b/tests/lean/run/simp5.lean index c2163c98aadb..8645b488b3bf 100644 --- a/tests/lean/run/simp5.lean +++ b/tests/lean/run/simp5.lean @@ -4,15 +4,27 @@ theorem f_Eq {α} (a b : α) : f a b = a := rfl theorem ex1 (a b c : α) : f (f a b) c = a := by - simp [f_Eq] + simp (config := { implicitDefEqProofs := false }) [f_Eq] /-- info: theorem ex1.{u_1} : ∀ {α : Sort u_1} (a b c : α), f (f a b) c = a := -fun {α} a b c => of_eq_true (eq_self a) +fun {α} a b c => + of_eq_true + (Eq.trans (congrArg (fun x => x = a) (Eq.trans (congrArg (fun x => f x c) (f_Eq a b)) (f_Eq a c))) (eq_self a)) -/ #guard_msgs in #print ex1 +theorem ex1' (a b c : α) : f (f a b) c = a := by + simp (config := { implicitDefEqProofs := true }) [f_Eq] + +/-- +info: theorem ex1'.{u_1} : ∀ {α : Sort u_1} (a b c : α), f (f a b) c = a := +fun {α} a b c => of_eq_true (eq_self a) +-/ +#guard_msgs in +#print ex1' + theorem ex2 (p : Nat → Bool) (x : Nat) (h : p x = true) : (if p x then 1 else 2) = 1 := by simp [h] diff --git a/tests/lean/run/splitIssue2.lean b/tests/lean/run/splitIssue2.lean new file mode 100644 index 000000000000..f9d3317793cb --- /dev/null +++ b/tests/lean/run/splitIssue2.lean @@ -0,0 +1,71 @@ +namespace Batteries + +/-- Union-find node type -/ +structure UFNode where + /-- Parent of node -/ + parent : Nat + +namespace UnionFind + +/-- Parent of a union-find node, defaults to self when the node is a root -/ +def parentD (arr : Array UFNode) (i : Nat) : Nat := + if h : i < arr.size then (arr.get ⟨i, h⟩).parent else i + +/-- Rank of a union-find node, defaults to 0 when the node is a root -/ +def rankD (arr : Array UFNode) (i : Nat) : Nat := 0 + +theorem parentD_of_not_lt : ¬i < arr.size → parentD arr i = i := sorry + +theorem parentD_set {arr : Array UFNode} {x v i} : + parentD (arr.set x v) i = if x.1 = i then v.parent else parentD arr i := by + rw [parentD] + sorry + +end UnionFind + +open UnionFind + +structure UnionFind where + arr : Array UFNode + +namespace UnionFind + +/-- Size of union-find structure. -/ +@[inline] abbrev size (self : UnionFind) := self.arr.size + +/-- Parent of union-find node -/ +abbrev parent (self : UnionFind) (i : Nat) : Nat := parentD self.arr i + +theorem parent_lt (self : UnionFind) (i : Nat) : self.parent i < self.size ↔ i < self.size := + sorry + +/-- Rank of union-find node -/ +abbrev rank (self : UnionFind) (i : Nat) : Nat := rankD self.arr i + +/-- Maximum rank of nodes in a union-find structure -/ +noncomputable def rankMax (self : UnionFind) := 0 + +/-- Root of a union-find node. -/ +def root (self : UnionFind) (x : Fin self.size) : Fin self.size := + let y := (self.arr.get x).parent + if h : y = x then + x + else + have : self.rankMax - self.rank (self.arr.get x).parent < self.rankMax - self.rank x := + sorry + self.root ⟨y, sorry⟩ +termination_by self.rankMax - self.rank x + +/-- Root of a union-find node. Returns input if index is out of bounds. -/ +def rootD (self : UnionFind) (x : Nat) : Nat := + if h : x < self.size then self.root ⟨x, h⟩ else x + +theorem rootD_parent (self : UnionFind) (x : Nat) : self.rootD (self.parent x) = self.rootD x := by + simp only [rootD, Array.data_length, parent_lt] + split + · simp only [parentD, ↓reduceDIte, *] + conv => rhs; rw [root] + split + · rw [root, dif_pos] <;> simp_all + · simp + · simp only [not_false_eq_true, parentD_of_not_lt, *] diff --git a/tests/lean/run/splitOrderIssue.lean b/tests/lean/run/splitOrderIssue.lean new file mode 100644 index 000000000000..457b31a551a1 --- /dev/null +++ b/tests/lean/run/splitOrderIssue.lean @@ -0,0 +1,10 @@ +example (b : Bool) : (if (if b then true else true) then 1 else 2) = 1 := by + split + next h => + guard_target =ₛ (if true = true then 1 else 2) = 1 + guard_hyp h : b = true + simp + next h => + guard_target =ₛ (if true = true then 1 else 2) = 1 + guard_hyp h : ¬b = true + simp