Skip to content

Commit

Permalink
Add rule pattern index
Browse files Browse the repository at this point in the history
Rule patterns are now stored in a discrimination tree and lookups in
this tree are cached. This should make rule pattern matching much more
efficient.
  • Loading branch information
JLimperg committed Oct 22, 2024
1 parent 568d91f commit 66834e0
Show file tree
Hide file tree
Showing 11 changed files with 292 additions and 73 deletions.
10 changes: 4 additions & 6 deletions Aesop/Index.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Jannis Limperg

import Aesop.Forward.Match
import Aesop.Index.Basic
import Aesop.Index.RulePattern
import Aesop.RulePattern
import Aesop.Rule.Basic
import Aesop.Tracing
Expand Down Expand Up @@ -134,7 +135,8 @@ private def applicableUnindexedRules (ri : Index α) (include? : Rule α → Boo
-- Returns the rules in the order given by the `Ord α` instance.
@[specialize]
def applicableRules (ri : Index α) (goal : MVarId)
(additionalRules : Array (Rule α)) (include? : Rule α → Bool) :
(patInstMap : RulePatternInstMap) (additionalRules : Array (Rule α))
(include? : Rule α → Bool) :
MetaM (Array (IndexMatchResult (Rule α))) := do
withConstAesopTraceNode .debug (return "rule selection") do
goal.instantiateMVars
Expand All @@ -148,19 +150,15 @@ def applicableRules (ri : Index α) (goal : MVarId)
(applicableUnindexedRules ri include?)
ruleMap := additionalRules.foldl (init := ruleMap) λ ruleMap r =>
ruleMap.insert r #[] -- NOTE: additional rules are not checked with include?
aesop_trace[debug] "selected rules before pattern check:{indentD $ flip MessageData.joinSep "\n" $ ruleMap.toList.map (toMessageData ·.fst.name)}"
let mut patterns := Array.mkEmpty ruleMap.size
for (rule, _) in ruleMap do
if let some pattern := rule.pattern? then
patterns := patterns.push (rule.name, pattern)
aesop_trace[debug] "patterns:{indentD $ flip MessageData.joinSep "\n" $ patterns.map (λ (name, pat) => m!"{name}: {pat.pattern.expr}") |>.toList}"
let patternInstsMap ← matchRulePatterns patterns goal
aesop_trace[debug] "found pattern instantiations:{indentD $ flip MessageData.joinSep "\n" $ patternInstsMap.toList.map λ (name, insts) => m!"{name}: {insts.toArray.map (·.toArray)}"}"
let mut result := Array.mkEmpty ruleMap.size
for (rule, locs) in ruleMap do
let locations := (∅ : Std.HashSet _).insertMany locs
if rule.pattern?.isSome then
if let some patternInstantiations := patternInstsMap[rule.name]? then
if let some patternInstantiations := patInstMap[rule.name]? then
result := result.push { rule, locations, patternInstantiations }
else
result := result.push { rule, locations, patternInstantiations := ∅ }
Expand Down
4 changes: 0 additions & 4 deletions Aesop/Index/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ open Lean Lean.Meta

namespace Aesop

-- This value controls whether we use 'powerful' reductions, e.g. iota, when
-- indexing Aesop rules. See the `DiscrTree` docs for details.
def discrTreeConfig : WhnfCoreConfig := { iota := false }

inductive IndexingMode : Type
| unindexed
| target (keys : Array DiscrTree.Key)
Expand Down
16 changes: 16 additions & 0 deletions Aesop/Index/DiscrTreeConfig.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/-
Copyright (c) 2021 Jannis Limperg. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jannis Limperg
-/

import Lean

open Lean.Meta

namespace Aesop

/-- Discrimination tree configuration used by all Aesop indices. -/
def discrTreeConfig : WhnfCoreConfig := { iota := false }

end Aesop
164 changes: 164 additions & 0 deletions Aesop/Index/RulePattern.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/-
Copyright (c) 2024 Jannis Limperg. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Jannis Limperg
-/

import Aesop.RuleTac.GoalDiff
import Aesop.Index.Basic
import Aesop.RulePattern
import Batteries.Lean.Meta.DiscrTree

set_option linter.missingDocs true

open Lean Lean.Meta

namespace Aesop

/-- A map from rule names to rule pattern instantiations. When run on a goal,
the rule pattern index returns such a map. -/
abbrev RulePatternInstMap :=
Std.HashMap RuleName (Std.HashSet RulePatternInstantiation)

namespace RulePatternInstMap

instance : EmptyCollection RulePatternInstMap :=
⟨{}⟩

/-- Insert an array of rule pattern instantiations into a rule pattern
instantiation map. -/
def insertArray (xs : Array (RuleName × RulePatternInstantiation))
(m : RulePatternInstMap) : RulePatternInstMap :=
xs.foldl (init := m) λ m (r, inst) =>
match m[r]? with
| none => m.insert r $ (∅ : Std.HashSet _).insert inst
| some insts => m.insert r $ insts.insert inst

end RulePatternInstMap

set_option linter.missingDocs false in
/-- A cache for the rule pattern index. -/
structure RulePatternCache where
map : Std.HashMap Expr (Array (RuleName × RulePatternInstantiation))
deriving Inhabited

instance : EmptyCollection RulePatternCache :=
⟨⟨∅⟩⟩

/-- Type class for monads with access to a rule pattern cache. -/
abbrev MonadRulePatternCache m :=
MonadCache Expr (Array (RuleName × RulePatternInstantiation)) m

instance [Monad m] [MonadLiftT (ST ω) m] [STWorld ω m]
[MonadStateOf RulePatternCache m] :
MonadHashMapCacheAdapter Expr (Array (RuleName × RulePatternInstantiation)) m where
getCache := return (← getThe RulePatternCache).map
modifyCache f := modifyThe RulePatternCache λ s => { s with map := f s.map }

-- TODO upstream
scoped instance [MonadCache α β m] : MonadCache α β (StateRefT' ω σ m) where
findCached? a := MonadCache.findCached? (m := m) a
cache a b := MonadCache.cache (m := m) a b

/-- An entry of the rule pattern index. -/
structure RulePatternIndex.Entry where
/-- The name of the rule to which the pattern belongs. -/
name : RuleName
/-- The rule's pattern. We assume that there is at most one pattern per
rule. -/
pattern : RulePattern
deriving Inhabited

instance : BEq RulePatternIndex.Entry where
beq e₁ e₂ := e₁.name == e₂.name

set_option linter.missingDocs false in
/-- A rule pattern index. Maps expressions `e` to rule patterns that likely
unify with `e`. -/
structure RulePatternIndex where
tree : DiscrTree RulePatternIndex.Entry
deriving Inhabited

namespace RulePatternIndex

instance : EmptyCollection RulePatternIndex :=
⟨⟨{}⟩⟩

/-- Add a rule pattern to the index. -/
def add (name : RuleName) (pattern : RulePattern) (idx : RulePatternIndex) :
RulePatternIndex :=
⟨idx.tree.insertCore pattern.discrTreeKeys { name, pattern }⟩

/-- Merge two rule pattern indices. Patterns that appear in both indices appear
twice in the result. -/
def merge (idx₁ idx₂ : RulePatternIndex) : RulePatternIndex :=
⟨idx₁.tree.mergePreservingDuplicates idx₂.tree⟩

section Get

variable [Monad m] [MonadRulePatternCache m] [MonadLiftT MetaM m]
[MonadControlT MetaM m]

local instance : STWorld IO.RealWorld m where

local instance : MonadLiftT (ST IO.RealWorld) m where
monadLift x := (x : MetaM _)

local instance : MonadMCtx m where
getMCtx := (getMCtx : MetaM _)
modifyMCtx f := (modifyMCtx f : MetaM _)

/-- Get rule pattern instantiations for the patterns that match `e`. -/
def getSingle (e : Expr) (idx : RulePatternIndex) :
MetaM (Array (RuleName × RulePatternInstantiation)) := do
let ms ← idx.tree.getUnify e discrTreeConfig
ms.foldlM (init := #[]) λ acc { name := r, pattern } =>
withNewMCtxDepth do
let (mvarIds, p) ← pattern.open
if ← isDefEq e p then
let inst ← mvarIds.mapM λ mvarId => do
let mvar := .mvar mvarId
let result ← instantiateMVars mvar
if result == mvar then
throwError "matchRulePatterns: while matching pattern '{p}' against expression '{e}': expected metavariable ?{(← mvarId.getDecl).userName} ({mvarId.name}) to be assigned"
pure result
return acc.push (r, inst)
else
return acc

/-- Get all instantiations of the rule patterns that match a subexpression of
`e`. Subexpressions containing bound variables are not considered. -/
def get (e : Expr) (idx : RulePatternIndex) :
m (Array (RuleName × RulePatternInstantiation)) := do
let e ← instantiateMVars e
checkCache e λ _ => (·.snd) <$> (e.forEach getSubexpr |>.run #[])
where
getSubexpr (e : Expr) :
StateRefT (Array (RuleName × RulePatternInstantiation)) m Unit := do
if e.hasLooseBVars then
-- We don't visit subexpressions with loose bvars. Instantiations
-- derived from such subexpressions would not be valid in the goal's
-- context. E.g. if a rule `(x : T) → P x` has pattern `x` and we
-- have the expression `λ (y : T), y` in the goal, then it makes no
-- sense to match `y` and generate `P y`.
return
let ms ← idx.getSingle e
modify (· ++ ms)

/-- Get all instantiations of the rule patterns that match a subexpression of
a hypothesis or the target. Subexpressions containing bound variables are not
considered. -/
def getInGoal (goal : MVarId) (idx : RulePatternIndex) : m RulePatternInstMap :=
goal.withContext do
let mut result := ∅
for ldecl in (← goal.getDecl).lctx do
result := result.insertArray $ ← idx.get ldecl.toExpr
result := result.insertArray $ ← idx.get ldecl.type
if let some val := ldecl.value? then
result := result.insertArray $ ← idx.get val
result := result.insertArray $ ← idx.get (← goal.getType)
return result

end Get

end Aesop.RulePatternIndex
48 changes: 7 additions & 41 deletions Aesop/RulePattern.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Jannis Limperg

import Aesop.Rule.Name
import Aesop.Tracing
import Aesop.Index.DiscrTreeConfig

open Lean Lean.Meta

Expand Down Expand Up @@ -38,6 +39,10 @@ structure RulePattern where
instantiation `tⱼ` of `yⱼ` should be substituted for `xᵢ`.
-/
argMap : Array (Option Nat)
/--
Discrimination tree keys for `p`.
-/
discrTreeKeys : Array DiscrTree.Key
deriving Inhabited

namespace RulePattern
Expand All @@ -57,46 +62,6 @@ def RulePatternInstantiation.toArray : RulePatternInstantiation → Array Expr :
instance : EmptyCollection RulePatternInstantiation :=
⟨.empty⟩

def matchRulePatternsCore (pats : Array (RuleName × RulePattern))
(mvarId : MVarId) :
StateRefT (Std.HashMap RuleName (Std.HashSet RulePatternInstantiation)) MetaM Unit :=
withNewMCtxDepth do -- TODO use (allowLevelAssignments := true)?
let openPats ← pats.mapM λ (name, pat) => return (name, ← pat.open)
let initialState ← show MetaM _ from saveState
forEachExprInGoal mvarId λ e => do
if e.hasLooseBVars then
-- We don't visit subexpressions with loose bvars. Instantiations
-- derived from such subexpressions would not be valid in the goal's
-- context. E.g. if a rule `(x : T) → P x` has pattern `x` and we
-- have the expression `λ (y : T), y` in the goal, then it makes no
-- sense to match `y` and generate `P y`.
return
for (name, mvarIds, p) in openPats do
initialState.restore
-- The many `isDefEq` checks here are quite expensive. Perhaps a better
-- strategy would be to reducibly normalise the goal once and for all.
-- Then we could use a variant of `isDefEq` that only checks for
-- syntactic equality up to mvars.
if ← isDefEq e p then
let instances ← mvarIds.mapM λ mvarId => do
let mvar := .mvar mvarId
let result ← instantiateMVars mvar
if result == mvar then
initialState.restore
throwError "matchRulePatterns: while matching pattern '{p}' against expression '{e}': expected metavariable ?{(← mvarId.getDecl).userName} ({mvarId.name}) to be assigned"
pure result
modify λ m =>
-- TODO loss of linearity?
if let some instanceSet := m[name]? then
m.insert name (instanceSet.insert instances)
else
m.insert name (.empty |>.insert instances)

def matchRulePatterns (pats : Array (RuleName × RulePattern))
(mvarId : MVarId) :
MetaM (Std.HashMap RuleName (Std.HashSet RulePatternInstantiation)) :=
(·.snd) <$> (matchRulePatternsCore pats mvarId |>.run ∅)

namespace RulePattern

def getInstantiation [Monad m] [MonadError m] (pat : RulePattern)
Expand Down Expand Up @@ -149,10 +114,11 @@ def «elab» (stx : Term) (ruleType : Expr) : TermElabM RulePattern :=
forallTelescope ruleType λ fvars _ => do
let pat := (← elabPattern stx).consumeMData
let (pat, mvarIds) ← fvarsToMVars fvars pat
let discrTreeKeys ← DiscrTree.mkPath pat discrTreeConfig
let (pat, mvarIdToPatternPos) ← abstractMVars' pat
let argMap := mvarIds.map (mvarIdToPatternPos[·]?)
aesop_trace[debug] "pattern '{stx}' elaborated into '{pat.expr}'"
return { pattern := pat, argMap }
return { pattern := pat, argMap, discrTreeKeys }
where
fvarsToMVars (fvars : Array Expr) (e : Expr) :
MetaM (Expr × Array MVarId) := do
Expand Down
Loading

0 comments on commit 66834e0

Please sign in to comment.