Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @egg tags #27

Merged
merged 20 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Lean/Egg.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ import Egg.Tactic.Basic
import Egg.Tactic.Calc
import Egg.Tactic.Guides
import Egg.Tactic.Trace
import Egg.Tactic.Tags
1 change: 1 addition & 0 deletions Lean/Egg/Core/Config.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ structure Encoding extends Normalization where

structure Gen where
builtins := true
tagged? := some `egg
genTcProjRws := true
genTcSpecRws := true
genGoalTcSpec := true
Expand Down
2 changes: 2 additions & 0 deletions Lean/Egg/Core/Explanation/Parse.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ syntax "*" noWs num : egg_basic_fwd_rw_src
syntax "⊢" : egg_basic_fwd_rw_src
syntax "↣" noWs num : egg_basic_fwd_rw_src
syntax "◯" noWs num : egg_basic_fwd_rw_src
syntax "□" noWs num (noWs "/" noWs num)? : egg_basic_fwd_rw_src

syntax "[" egg_tc_proj_loc num "," num "]" : egg_tc_proj

Expand Down Expand Up @@ -119,6 +120,7 @@ private def parseTcProjLocation : (TSyntax `egg_tc_proj_loc) → Source.TcProjLo

private def parseBasicFwdRwSrc : (TSyntax `egg_basic_fwd_rw_src) → Source
| `(egg_basic_fwd_rw_src|#$idx$[/$eqn?]?) => .explicit idx.getNat (eqn?.map TSyntax.getNat)
| `(egg_basic_fwd_rw_src|□$idx$[/$eqn?]?) => .tagged idx.getNat (eqn?.map TSyntax.getNat)
| `(egg_basic_fwd_rw_src|*$idx) => .star (.fromUniqueIdx idx.getNat)
| `(egg_basic_fwd_rw_src|⊢) => .goal
| `(egg_basic_fwd_rw_src|↣$idx) => .guide idx.getNat
Expand Down
14 changes: 8 additions & 6 deletions Lean/Egg/Tactic/Premises/Gen.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ namespace Egg.Premises
-- TODO: Perform pruning during generation, not after.

private def tracePremises
(basic : WithSyntax Rewrites) (builtins tc pruned : Rewrites) (facts : WithSyntax Facts)
(basic : WithSyntax Rewrites) (tagged builtins tc pruned : Rewrites) (facts : WithSyntax Facts)
(cfg : Config.Gen) : TacticM Unit := do
let cls := `egg.rewrites
withTraceNode cls (fun _ => return "Rewrites") do
withTraceNode cls (fun _ => return m!"Basic ({basic.elems.size})") do basic.elems.trace basic.stxs cls
withTraceNode cls (fun _ => return m!"Tagged ({tagged.size})") do tagged.trace #[] cls
withTraceNode cls (fun _ => return m!"Generated ({tc.size})") do tc.trace #[] cls
withTraceNode cls (fun _ => return m!"Builtin ({builtins.size})") do builtins.trace #[] cls
withTraceNode cls (fun _ => return m!"Hypotheses ({facts.elems.size})") do
Expand All @@ -30,14 +31,15 @@ private def tracePremises
partial def gen
(goal : Congr) (ps : TSyntax `egg_premises) (guides : Guides) (cfg : Config)
(amb : MVars.Ambient) : TacticM (Rewrites × Facts) := do
let tagged ← Premises.buildTagged cfg amb
let ⟨⟨basic, basicStxs⟩, facts⟩ ← Premises.elab { norm? := cfg, amb } ps
let (basic, basicStxs, pruned₁) ← prune basic basicStxs (remove := #[])
let (basic, basicStxs, pruned₁) ← prune basic basicStxs (remove := tagged)
let builtins ← if cfg.builtins then Rewrites.builtins { norm? := cfg, amb } else pure #[]
let (builtins, _, pruned₂) ← prune builtins (remove := basic)
let (builtins, _, pruned₂) ← prune builtins (remove := tagged ++ basic)
let tc ← genTcRws (basic ++ builtins) facts.elems
let (tc, _, pruned₃) ← prune tc (remove := basic ++ builtins)
tracePremises ⟨basic, basicStxs⟩ builtins tc (pruned₁ ++ pruned₂ ++ pruned₃) facts cfg
let rws := basic ++ builtins ++ tc
let (tc, _, pruned₃) ← prune tc (remove := tagged ++ basic ++ builtins)
tracePremises ⟨basic, basicStxs⟩ tagged builtins tc (pruned₁ ++ pruned₂ ++ pruned₃) facts cfg
let rws := tagged ++ basic ++ builtins ++ tc
catchInvalidConditionals rws
return (rws, facts.elems)
where
Expand Down
15 changes: 15 additions & 0 deletions Lean/Egg/Tactic/Premises/Parse.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import Egg.Core.Premise.Rewrites
import Egg.Core.Premise.Facts
import Egg.Tactic.Premises.Validate
import Egg.Tactic.Tags
import Lean

open Lean Meta Elab Tactic
Expand Down Expand Up @@ -117,6 +119,19 @@ private def Premises.taggedRw (prem : Name) (idx : Nat) (cfg : Rewrite.Config) :
let rws ← Premises.explicit ident idx mk .tagged
return rws.elems

private def Premises.elabTagged (prems : Array Name) (cfg : Rewrite.Config) : TacticM Rewrites := do
let mut rws : Rewrites := #[]
for prem in prems, idx in [:prems.size] do
rws := rws ++ (← taggedRw prem idx cfg)
return rws

def Premises.buildTagged (cfg : Config) (amb : MVars.Ambient ): TacticM Rewrites :=
match cfg.tagged? with
| none => return {}
| some _ => do -- This should later use this `Name` to find the proper extension
let prems := eggXtension.getState (← getEnv)
elabTagged prems { norm? := cfg, amb}

-- Note: This function is expected to be called with the lctx which contains the desired premises.
--
-- Note: We need to filter out auxiliary declaration and implementation details, as they are not
Expand Down
53 changes: 53 additions & 0 deletions Lean/Egg/Tactic/Premises/Validate.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import Egg.Core.Premise.Rewrites

open Lean Meta Elab Tactic

inductive Premise.Raw where
| single (expr : Expr) (type? : Option Expr := none)
| eqns (exprs : Array Expr)

inductive Premise.RawRaw where
| eqns (exprs : Array Name)
| single (expr : Expr) (type? : Option Expr := none)
| prem (prem : Term)


def Premise.Raw.validate (prem : Term) : MetaM Premise.RawRaw := do
if let some const ← optional (resolveGlobalConstNoOverload prem) then
if let some eqs ← getEqnsFor? const (nonRec := true) then
-- `prem` is a global definition.
return .eqns eqs
else
-- `prem` is an global constant which is not a definition with equations.
let env ← getEnv
let some info := env.find? const | throwErrorAt prem m!"unknown constant '{mkConst const}'"
match info with
| .defnInfo _ | .axiomInfo _ | .thmInfo _ | .opaqueInfo _ =>
let lvlMVars ← List.replicateM info.numLevelParams mkFreshLevelMVar
let val := if info.hasValue then info.instantiateValueLevelParams! lvlMVars else .const info.name lvlMVars
let type := info.instantiateTypeLevelParams lvlMVars
return .single val type
| _ => throwErrorAt prem "egg requires arguments to be theorems, definitions or axioms"
else
-- `prem` is an invalid identifier or a term which is not an identifier.
-- We must use `Tactic.elabTerm`, not `Term.elabTerm`. Otherwise elaborating `‹...›` doesn't
-- work correctly. See https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Elaborate.20.E2.80.B9.2E.2E.2E.E2.80.BA
return .prem prem

-- We don't just elaborate premises directly as:
-- (1) this can cause problems for global constants with typeclass arguments, as Lean sometimes
-- tries to synthesize the arguments and fails if it can't (instead of inserting mvars).
-- (2) global constants which are definitions with equations (cf. `getEqnsFor?`) are supposed to
-- be replaced by their defining equations.
partial def Premise.Raw.elab (prem : Term) : TacticM Premise.Raw := do
if let some hyp ← optional (getFVarId prem) then
-- `prem` is a local declaration.
let decl ← hyp.getDecl
if decl.isImplementationDetail || decl.isAuxDecl then
throwErrorAt prem "egg does not support using auxiliary declarations"
else
return .single (.fvar hyp) (← hyp.getType)
match (← validate prem) with
| .eqns eqs => return .eqns <| ← eqs.mapM fun eqn => Tactic.elabTerm (mkIdent eqn) none
| .single val type? => return .single val type?
| .prem prem => return .single (← Tactic.elabTerm prem none)
110 changes: 110 additions & 0 deletions Lean/Egg/Tactic/Tags.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import Egg.Tactic.Premises.Validate
import Lean

open Lean Elab Tactic Term

namespace Egg

/--
This validates that a theorem can be used by the `egg` tactic (it ultimately boils down to an equality.)

Unimplemented: Currently, this is a noop.
-/
private def validateEggTheorem (thm : Term) : MetaM Unit := do
let _ ← Premise.Raw.validate thm
return ()

-- Ideally this should be at some point a discrimination tree
abbrev EggTheorems := Array Name

abbrev EggEntry := Name -- later: Lean.Meta.SimpEntry

def addEggTheoremEntry (d : EggTheorems) (e : EggEntry) : EggTheorems :=
d.push e

abbrev EggXtension := SimpleScopedEnvExtension EggEntry EggTheorems

open Lean.Elab
open Lean.Elab.Command

def EggXtension.getTheorems (ext : EggXtension) : CoreM EggTheorems :=
return ext.getState (← getEnv)

/--
This function does the appropriate preprocessing from a `Name` to record a theorem as
an `egg` theorem.

For now, this preprocessing is nothing (just store the name in a singleton `Array`).
However, in the future this can be used like simp to do more efficient preprocessing
and deal with other kinds of `egg` lemmas (or even import `simp` lemmas).
-/
private def mkEggTheoremsFromConst (declName : Name) : MetaM EggTheorems :=
pure #[declName]

def addEggTheorem (ext : EggXtension) (declName : Name) (attrKind : AttributeKind) : MetaM Unit := do
let _ ← validateEggTheorem { raw := Syntax.ident default default declName []} -- ugly!
let eggThms ← mkEggTheoremsFromConst declName
for eggThm in eggThms do
ext.add eggThm attrKind

def mkEggXt (name : Name := by exact decl_name%) : IO EggXtension :=
registerSimpleScopedEnvExtension {
name := name
initial := {}
addEntry := fun d e => addEggTheoremEntry d e
}

def mkEggAttr (attrName : Name) (attrDescr : String) (ext : EggXtension)
(ref : Name := by exact decl_name%) : IO Unit :=
registerBuiltinAttribute {
ref := ref
name := attrName
descr := attrDescr
applicationTime := AttributeApplicationTime.afterCompilation
add := fun declName _stx attrKind => do
let go : MetaM Unit := do
let info ← getConstInfo declName
if (← Meta.isProp info.type) then
addEggTheorem ext declName attrKind
else
throwError "invalid 'egg', it is not a proposition"
discard <| go.run {} {}
erase := fun declName => do
let s := ext.getState (← getEnv)
let s := s.erase (declName)
modifyEnv fun env => ext.modifyState env fun _ => s
}


abbrev EggXtensionMap := HashMap Name EggXtension

initialize eggXtensionMapRef : IO.Ref EggXtensionMap ← IO.mkRef {}

def registerEggAttr (attrName : Name) (attrDescr : String)
(ref : Name := by exact decl_name%) : IO EggXtension := do
let ext ← mkEggXt ref
mkEggAttr attrName attrDescr ext ref -- Remark: it will fail if it is not performed during initialization
eggXtensionMapRef.modify fun map => map.insert attrName ext
return ext

initialize eggXtension : EggXtension ← registerEggAttr `egg "equality saturation theorem theorem"


syntax (name := showEgg) "#show_egg_set" : command

--
--
--#check Lean.Meta.mkSimpAttr
--
--@[command_elab insertEgg] def elabInsertEgg : CommandElab := fun stx => do
-- IO.println s!"inserting {stx[1].getId}"
-- modifyEnv fun env => eggXtension.addEntry env stx[1].getId
--
@[command_elab showEgg] def elabShowEgg : CommandElab := fun _ => do
logInfo m!"egg set: {eggXtension.getState (← getEnv) |>.toList}"
--
--
--initialize eggTag : TagAttribute ←
-- registerTagAttribute `egg "Tag for marking lemmas used for equality saturation" (validate := validateEggTheorem)

end Egg
1 change: 1 addition & 0 deletions Lean/Egg/Tests/Conditional.lean
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ info: [egg.rewrites] Rewrites
expr: [?l₂]
class: []
level: []
[egg.rewrites] Tagged (0)
[egg.rewrites] Generated (0)
[egg.rewrites] Builtin (0)
[egg.rewrites] Hypotheses (0)
Expand Down
71 changes: 71 additions & 0 deletions Lean/Egg/Tests/FreshmanTags.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import Egg

class Inv (α) where inv : α → α
postfix:max "⁻¹" => Inv.inv

class Zero (α) where zero : α
instance [Zero α] : OfNat α 0 where ofNat := Zero.zero

class One (α) where one : α
instance [One α] : OfNat α 1 where ofNat := One.one

class CommRing (α) extends Zero α, One α, Add α, Sub α, Mul α, Div α, Pow α Nat, Inv α, Neg α where
comm_add (a b : α) : a + b = b + a
comm_mul (a b : α) : a * b = b * a
add_assoc (a b c : α) : a + (b + c) = (a + b) + c
mul_assoc (a b c : α) : a * (b * c) = (a * b) * c
sub_canon (a b : α) : a - b = a + -b
neg_add (a : α) : a + -a = 0
div_canon (a b : α) : a / b = a * b⁻¹
zero_add (a : α) : a + 0 = a
zero_mul (a : α) : a * 0 = 0
one_mul (a : α) : a * 1 = a
distrib (a b c : α) : a * (b + c) = (a * b) + (a * c)
pow_zero (a : α) : a ^ 0 = 1
pow_one (a : α) : a ^ 1 = a
pow_two (a : α) : a ^ 2 = (a ^ 1) * a
pow_three (a : α) : a ^ 3 = (a ^ 2) * a

attribute [egg] CommRing.comm_add
attribute [egg] CommRing.comm_mul
attribute [egg] CommRing.add_assoc
attribute [egg] CommRing.mul_assoc
attribute [egg] CommRing.sub_canon
attribute [egg] CommRing.neg_add
attribute [egg] CommRing.div_canon
attribute [egg] CommRing.zero_add
attribute [egg] CommRing.zero_mul
attribute [egg] CommRing.one_mul
attribute [egg] CommRing.distrib
attribute [egg] CommRing.pow_zero
attribute [egg] CommRing.pow_one
attribute [egg] CommRing.pow_two
attribute [egg] CommRing.pow_three

class CharTwoRing (α) extends CommRing α where
char_two (a : α) : a + a = 0

variable [CharTwoRing α] (x y : α)

theorem freshmans_dream₂ : (x + y) ^ 2 = (x ^ 2) + (y ^ 2) := by
egg calc (x + y) ^ 2
_ = (x + y) * (x + y)
_ = x * (x + y) + y * (x + y)
_ = x ^ 2 + x * y + y * x + y ^ 2
_ = x ^ 2 + y ^ 2 with [CharTwoRing.char_two]

theorem freshmans_dream₂' : (x + y) ^ 2 = (x ^ 2) + (y ^ 2) := by
egg [CharTwoRing.char_two]

theorem freshmans_dream₃ : (x + y) ^ 3 = x ^ 3 + x * y ^ 2 + x ^ 2 * y + y ^ 3 := by
egg calc [CharTwoRing.char_two] (x + y) ^ 3
_ = (x + y) * (x + y) * (x + y)
_ = (x + y) * (x * (x + y) + y * (x + y))
_ = (x + y) * (x ^ 2 + x * y + y * x + y ^ 2)
_ = (x + y) * (x ^ 2 + y ^ 2)
_ = x * (x ^ 2 + y ^ 2) + y * (x ^ 2 + y ^ 2)
_ = (x * x ^ 2) + x * y ^ 2 + y * x ^ 2 + y * y ^ 2
_ = x ^ 3 + x * y ^ 2 + x ^ 2 * y + y ^ 3

theorem freshmans_dream₃' : (x + y) ^ 3 = x ^ 3 + x * y ^ 2 + x ^ 2 * y + y ^ 3 := by
egg [CharTwoRing.char_two]
2 changes: 2 additions & 0 deletions Lean/Egg/Tests/Prune.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ info: [egg.rewrites] Rewrites
expr: []
class: []
level: []
[egg.rewrites] Tagged (0)
[egg.rewrites] Generated (0)
[egg.rewrites] Builtin (0)
[egg.rewrites] Hypotheses (0)
Expand Down Expand Up @@ -55,6 +56,7 @@ info: [egg.rewrites] Rewrites
expr: [?n, ?m]
class: []
level: []
[egg.rewrites] Tagged (0)
[egg.rewrites] Generated (0)
[egg.rewrites] Builtin (0)
[egg.rewrites] Hypotheses (0)
Expand Down
1 change: 1 addition & 0 deletions Lean/Egg/Tests/TC Proj Binders.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ info: [egg.rewrites] Rewrites
expr: []
class: []
level: []
[egg.rewrites] Tagged (0)
[egg.rewrites] Generated (2)
[egg.rewrites] #0[0?69632,0](⇔)
[egg.rewrites] HMul.hMul = Mul.mul
Expand Down
Loading
Loading