Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: monadic generalization of FindExpr #3970

Merged
merged 2 commits into from
Apr 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions src/Lean/Util/FindExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ namespace Lean
namespace Expr
namespace FindImpl

unsafe abbrev FindM := StateT (PtrSet Expr) Id
unsafe abbrev FindM (m) := StateT (PtrSet Expr) m

@[inline] unsafe def checkVisited (e : Expr) : OptionT FindM Unit := do
@[inline] unsafe def checkVisited [Monad m] (e : Expr) : OptionT (FindM m) Unit := do
if (← get).contains e then
failure
modify fun s => s.insert e

unsafe def findM? (p : Expr → Bool) (e : Expr) : OptionT FindM Expr :=
unsafe def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : OptionT (FindM m) Expr :=
let rec visit (e : Expr) := do
checkVisited e
if p e then
if p e then
pure e
else match e with
| .forallE _ d b _ => visit d <|> visit b
Expand All @@ -33,29 +33,35 @@ unsafe def findM? (p : Expr → Bool) (e : Expr) : OptionT FindM Expr :=
| _ => failure
visit e

unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr :=
Id.run <| findM? p e |>.run' mkPtrSet
unsafe def findUnsafeM? {m} [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) :=
findM? p e |>.run' mkPtrSet

@[inline] unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr := findUnsafeM? (m := Id) p e

end FindImpl

@[implemented_by FindImpl.findUnsafe?]
def find? (p : Expr → Bool) (e : Expr) : Option Expr :=
/- This is a reference implementation for the unsafe one above -/
if p e then
some e
@[implemented_by FindImpl.findUnsafeM?]
/- This is a reference implementation for the unsafe one above -/
def findM? [Monad m] (p : Expr → m Bool) (e : Expr) : m (Option Expr) := do
if p e then
return some e
else match e with
| .forallE _ d b _ => find? p d <|> find? p b
| .lam _ d b _ => find? p d <|> find? p b
| .mdata _ b => find? p b
| .letE _ t v b _ => find? p t <|> find? p v <|> find? p b
| .app f a => find? p f <|> find? p a
| .proj _ _ b => find? p b
| _ => none
| .forallE _ d b _ => findM? p d <||> findM? p b
| .lam _ d b _ => findM? p d <||> findM? p b
| .mdata _ b => findM? p b
| .letE _ t v b _ => findM? p t <||> findM? p v <||> findM? p b
| .app f a => findM? p f <||> findM? p a
| .proj _ _ b => findM? p b
| _ => pure none

@[implemented_by FindImpl.findUnsafe?]
def find? (p : Expr → Bool) (e : Expr) : Option Expr := findM? (m := Id) p e

/-- Return true if `e` occurs in `t` -/
def occurs (e : Expr) (t : Expr) : Bool :=
(t.find? fun s => s == e).isSome


/--
Return type for `findExt?` function argument.
-/
Expand All @@ -66,7 +72,7 @@ inductive FindStep where

namespace FindExtImpl

unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT FindImpl.FindM Expr :=
unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT (FindImpl.FindM Id) Expr :=
visit e
where
visitApp (e : Expr) :=
Expand Down
Loading