diff --git a/src/Lean/Meta/Injective.lean b/src/Lean/Meta/Injective.lean index d4d5cef40826..c9fb53a20e19 100644 --- a/src/Lean/Meta/Injective.lean +++ b/src/Lean/Meta/Injective.lean @@ -30,6 +30,14 @@ def elimOptParam (type : Expr) : CoreM Expr := do else return .continue +def occursOrInType (e : Expr) (t : Expr) : MetaM Bool := do + let_fun f (s : Expr) := do + if !s.isFVar then + return s == e + let ty ← inferType s + return s == e || e.occurs ty + return (← t.findM? f).isSome + private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useEq : Bool) : MetaM (Option Expr) := do let us := ctorVal.levelParams.map mkLevelParam let type ← elimOptParam ctorVal.type @@ -57,7 +65,7 @@ private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useE match (← whnf type) with | Expr.forallE n d b _ => let arg1 := args1.get ⟨i, h⟩ - if arg1.occurs resultType then + if ← occursOrInType arg1 resultType then mkArgs2 (i + 1) (b.instantiate1 arg1) (args2.push arg1) args2New else withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 => @@ -103,6 +111,7 @@ private def mkInjectiveTheorem (ctorVal : ConstructorVal) : MetaM Unit := do | return () let value ← mkInjectiveTheoremValue ctorVal.name type let name := mkInjectiveTheoremNameFor ctorVal.name + trace[Meta.injective] "theorem {name} : {type} := {value}" addDecl <| Declaration.thmDecl { name levelParams := ctorVal.levelParams @@ -134,6 +143,7 @@ private def mkInjectiveEqTheorem (ctorVal : ConstructorVal) : MetaM Unit := do | return () let value ← mkInjectiveEqTheoremValue ctorVal.name type let name := mkInjectiveEqTheoremNameFor ctorVal.name + trace[Meta.injective] "theorem {name} : {type} := {value}" addDecl <| Declaration.thmDecl { name levelParams := ctorVal.levelParams diff --git a/src/Lean/Util/FindExpr.lean b/src/Lean/Util/FindExpr.lean index 56217625b889..f37fbbe00259 100644 --- a/src/Lean/Util/FindExpr.lean +++ b/src/Lean/Util/FindExpr.lean @@ -55,6 +55,18 @@ def find? (p : Expr → Bool) (e : Expr) : Option Expr := def occurs (e : Expr) (t : Expr) : Bool := (t.find? fun s => s == e).isSome +def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) := do + if ← p e then + return some e + else match e with + | .forallE _ d b _ => findM? p d <||> findM? p b + | .lam _ d b _ => findM? p d <||> findM? p b + | .mdata _ b => findM? p b + | .letE _ t v b _ => findM? p t <||> findM? p v <||> findM? p b + | .app f a => findM? p f <||> findM? p a + | .proj _ _ b => findM? p b + | _ => pure none + /-- Return type for `findExt?` function argument. -/ diff --git a/tests/lean/run/3386.lean b/tests/lean/run/3386.lean new file mode 100644 index 000000000000..f470fdb72993 --- /dev/null +++ b/tests/lean/run/3386.lean @@ -0,0 +1,9 @@ +/- Verify that injectivity lemmas are constructed with the right level of generality + in order to avoid type errors. +-/ + +inductive Tyₛ : Type (u+1) +| SPi : (T : Type u) -> (T -> Tyₛ) -> Tyₛ + +inductive Tmₛ.{u} : Tyₛ.{u} -> Type (u+1) +| app : Tmₛ (.SPi T A) -> (arg : T) -> Tmₛ (A arg)