Skip to content

Commit

Permalink
Implement a map-reduce visitor for expressions and fix issues with ge…
Browse files Browse the repository at this point in the history
…t{M,F}VarIds
  • Loading branch information
sonmarcho committed Dec 12, 2023
1 parent 24c5289 commit c23f317
Showing 1 changed file with 81 additions and 31 deletions.
112 changes: 81 additions & 31 deletions backends/lean/Base/Utils.lean
Original file line number Diff line number Diff line change
Expand Up @@ -159,47 +159,96 @@ elab "print_ctx_decls" : tactic => do
let decls ← ctx.getDecls
printDecls decls

-- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`)
-- A map-reduce visitor function for expressions (adapted from `AbstractNestedProofs.visit`)
-- The continuation takes as parameters:
-- - the current depth of the expression (useful for printing/debugging)
-- - the expression to explore
partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do
let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do
partial def mapreduceVisit {a : Type} (k : Nat → a → Expr → MetaM (a × Expr))
(state : a) (e : Expr) : MetaM (a × Expr) := do
let mapreduceVisitBinders (state : a) (xs : Array Expr) (k2 : a → MetaM (a × Expr)) :
MetaM (a × Expr) := do
let localInstances ← getLocalInstances
let mut lctx ← getLCtx
for x in xs do
let xFVarId := x.fvarId!
let localDecl ← xFVarId.getDecl
let type ← mapVisit k localDecl.type
let localDecl := localDecl.setType type
let localDecl ← match localDecl.value? with
| some value => let value ← mapVisit k value; pure <| localDecl.setValue value
| none => pure localDecl
lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl
withLCtx lctx localInstances k2
-- Update the local declarations for the bindings in context `lctx`
let rec visit_xs (lctx : LocalContext) (state : a) (xs : List Expr) : MetaM (LocalContext × a) := do
match xs with
| [] => pure (lctx, state)
| x :: xs => do
let xFVarId := x.fvarId!
let localDecl ← xFVarId.getDecl
let (state, type) ← mapreduceVisit k state localDecl.type
let localDecl := localDecl.setType type
let (state, localDecl) ← match localDecl.value? with
| some value =>
let (state, value) ← mapreduceVisit k state value
pure (state, localDecl.setValue value)
| none => pure (state, localDecl)
let lctx := lctx.modifyLocalDecl xFVarId fun _ => localDecl
-- Recursive call
visit_xs lctx state xs
let (lctx, state) ← visit_xs (← getLCtx) state xs.toList
-- Call the continuation with the updated context
withLCtx lctx localInstances (k2 state)
-- TODO: use a cache? (Lean.checkCache)
let rec visit (i : Nat) (e : Expr) : MetaM Expr := do
let rec visit (i : Nat) (state : a) (e : Expr) : MetaM (a × Expr) := do
-- Explore
let e ← k i e
let (state, e) ← k i state e
match e with
| .bvar _
| .fvar _
| .mvar _
| .sort _
| .lit _
| .const _ _ => pure e
| .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (visit (i + 1)))
| .const _ _ => pure (state, e)
| .app .. => do e.withApp fun f args => do
let (state, args) ← args.foldlM (fun (state, args) arg => do let (state, arg) ← visit (i + 1) state arg; pure (state, arg :: args)) (state, [])
let args := args.reverse
let (state, f) ← visit (i + 1) state f
let e' := mkAppN f (Array.mk args)
return (state, e')
| .lam .. =>
lambdaLetTelescope e fun xs b =>
mapVisitBinders xs do mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false)
mapreduceVisitBinders state xs fun state => do
let (state, b) ← visit (i + 1) state b
let e' ← mkLambdaFVars xs b (usedLetOnly := false)
return (state, e')
| .forallE .. => do
forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← visit (i + 1) b)
forallTelescope e fun xs b =>
mapreduceVisitBinders state xs fun state => do
let (state, b) ← visit (i + 1) state b
let e' ← mkForallFVars xs b
return (state, e')
| .letE .. => do
lambdaLetTelescope e fun xs b => mapVisitBinders xs do
mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false)
| .mdata _ b => return e.updateMData! (← visit (i + 1) b)
| .proj _ _ b => return e.updateProj! (← visit (i + 1) b)
visit 0 e
lambdaLetTelescope e fun xs b =>
mapreduceVisitBinders state xs fun state => do
let (state, b) ← visit (i + 1) state b
let e' ← mkLambdaFVars xs b (usedLetOnly := false)
return (state, e')
| .mdata _ b => do
let (state, b) ← visit (i + 1) state b
return (state, e.updateMData! b)
| .proj _ _ b => do
let (state, b) ← visit (i + 1) state b
return (state, e.updateProj! b)
visit 0 state e

-- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`)
-- The continuation takes as parameters:
-- - the current depth of the expression (useful for printing/debugging)
-- - the expression to explore
partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do
let k' i (_ : Unit) e := do
let e ← k i e
pure ((), e)
let (_, e) ← mapreduceVisit k' () e
pure e

-- A reduce visitor
partial def reduceVisit {a : Type} (k : Nat → a → Expr → MetaM a) (s : a) (e : Expr) : MetaM a := do
let k' i (s : a) e := do
let s ← k i s e
pure (s, e)
let (s, _) ← mapreduceVisit k' s e
pure s

-- Generate a fresh user name for an anonymous proposition to introduce in the
-- assumptions
Expand Down Expand Up @@ -376,16 +425,17 @@ def destEq (e : Expr) : MetaM (Expr × Expr) := do
else throwError "Not an equality: {e}"

-- Return the set of FVarIds in the expression
-- TODO: this collects fvars introduced in the inner bindings
partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do
e.consumeMData.withApp fun body args => do
let hs := if body.isFVar then hs.insert body.fvarId! else hs
args.foldlM (fun hs arg => getFVarIds arg hs) hs
reduceVisit (fun _ (hs : HashSet FVarId) e =>
if e.isFVar then pure (hs.insert e.fvarId!) else pure hs)
hs e

-- Return the set of MVarIds in the expression
partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do
e.consumeMData.withApp fun body args => do
let hs := if body.isMVar then hs.insert body.mvarId! else hs
args.foldlM (fun hs arg => getMVarIds arg hs) hs
reduceVisit (fun _ (hs : HashSet MVarId) e =>
if e.isMVar then pure (hs.insert e.mvarId!) else pure hs)
hs e

-- Tactic to split on a disjunction.
-- The expression `h` should be an fvar.
Expand Down

0 comments on commit c23f317

Please sign in to comment.