Skip to content

Commit

Permalink
Fix eta-reduction with Andrés' forward-recursion approach for e-class…
Browse files Browse the repository at this point in the history
… cycles (keeping debug tracing code for now)

Enable eta by default
  • Loading branch information
marcusrossel committed Mar 5, 2024
1 parent 30fb7f5 commit 54d842d
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 94 deletions.
8 changes: 8 additions & 0 deletions C/ffi.c
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,12 @@ lean_obj_res lean_egg_explain_congr(
free(rws);

return lean_mk_string(result.expl);
}

/* Continuation handed to `lean_dbg_trace` by `c_dbg_trace`.
 * It ignores its argument and returns `lean_box(0)`
 * (presumably Lean's boxed unit value -- confirm against the Lean runtime headers). */
lean_object* dbg_trace_thunk(lean_object* t) {
    (void)t; /* The argument is intentionally unused. */
    return lean_box(0);
}
/* Emits `str` through Lean's `lean_dbg_trace`.
 *
 * The C string is converted to a Lean string with `lean_mk_string`, and a closure over
 * `dbg_trace_thunk` serves as the continuation that `lean_dbg_trace` expects as its
 * second argument. The value returned by `lean_dbg_trace` is discarded. */
void c_dbg_trace(char const* str) {
    lean_object* message = lean_mk_string(str);
    lean_object* continuation = lean_alloc_closure(&dbg_trace_thunk, 1, 0);
    lean_dbg_trace(message, continuation);
}
2 changes: 1 addition & 1 deletion Lean/Egg/Core/Config.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ structure Encoding where
structure Gen where
genTcProjRws := true
genNatLitRws := true
genEtaRw := false
genEtaRw := true
explode := true
deriving BEq

Expand Down
3 changes: 0 additions & 3 deletions Lean/Egg/Tactic/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,5 @@ elab "egg " cfg:egg_cfg rws:egg_rws base:(egg_base)? : tactic => do
let rawExpl := req.run
processRawExpl rawExpl goal rws cfg.toDebug amb

-- WORKAROUND: This fixes `Tests/EndOfInput`.
macro "egg" : tactic => `(tactic| egg)

-- WORKAROUND: This fixes `Tests/EndOfInput`.
macro "egg" cfg:egg_cfg : tactic => `(tactic| egg $cfg)
11 changes: 9 additions & 2 deletions Lean/Egg/Tests/Eta.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@ example : (fun x => Nat.succ x) x = Nat.succ x := by
example : (fun x => (fun y => Nat.succ y) x) = Nat.succ := by
egg (config := { genEtaRw := true })

example : (fun x => (fun x => (fun x => Nat.succ x) x) x) = Nat.succ := by
egg (config := { genEtaRw := true })

example : (fun x => (fun y => Nat.succ y) x) x = Nat.succ x := by
egg (config := { genEtaRw := true })

-- TODO: Is this an infinite loop in `eta_shift`?
example : (fun x => (fun x => (fun x => (fun x => Nat.succ x) x) x) x) = Nat.succ := by
egg (config := { genEtaRw := true })

example : id (fun x => (fun y => Nat.succ y) x) = id Nat.succ := by
sorry -- egg (config := { genEtaRw := true })
egg (config := { genEtaRw := true })

-- TODO: Construct a test case where the e-graph contains a cycle.
12 changes: 9 additions & 3 deletions Rust/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use crate::util::*;

#[derive(Debug, Default)]
pub struct LeanAnalysisData {
pub nat_val: Option<u64>,
pub bvars: HashSet<u64> // A bvar is in this set only if it is referenced by *all* e-nodes in the e-class.
pub nat_val: Option<u64>,
pub bvars: HashSet<u64>, // A bvar is in this set only iff it is referenced by *some* e-node in the e-class.
pub has_bvar: bool // This is true iff some e-node in the subgraph of the given e-class is a `LeanExpr::BVar`.
}

#[derive(Default)]
Expand All @@ -23,7 +24,9 @@ impl Analysis<LeanExpr> for LeanAnalysis {
//
// if let (Some(t), Some(f)) = (*to.nat_val, from.nat_val) { assert_eq!(t, f) }

egg::merge_max(&mut to.nat_val, from.nat_val) | intersect_sets(&mut to.bvars, from.bvars)
egg::merge_max(&mut to.nat_val, from.nat_val) |
union_sets(&mut to.bvars, from.bvars) |
egg::merge_max(&mut to.has_bvar, from.has_bvar)
}

fn make(egraph: &EGraph<LeanExpr, Self>, enode: &LeanExpr) -> Self::Data {
Expand All @@ -40,12 +43,14 @@ impl Analysis<LeanExpr> for LeanAnalysis {
Some(n) => vec![n].into_iter().collect(),
None => HashSet::new()
},
has_bvar: true,
..Default::default()
},

LeanExpr::App([fun, arg]) =>
Self::Data {
bvars: union_clone(&egraph[*fun].data.bvars, &egraph[*arg].data.bvars),
has_bvar: egraph[*fun].data.has_bvar || egraph[*arg].data.has_bvar,
..Default::default()
},

Expand All @@ -55,6 +60,7 @@ impl Analysis<LeanExpr> for LeanAnalysis {
&egraph[*ty].data.bvars,
&shift_down(&egraph[*body].data.bvars)
),
has_bvar: egraph[*ty].data.has_bvar || egraph[*body].data.has_bvar,
..Default::default()
},

Expand Down
175 changes: 96 additions & 79 deletions Rust/src/eta.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use egg::*;
use std::collections::VecDeque;
use std::collections::HashMap;
use std::collections::HashSet;
use std::cmp::Ordering;
use crate::lean_expr::*;
use crate::analysis::*;
use crate::util::*;

struct Eta {
fun: Var
Expand All @@ -13,146 +14,162 @@ impl Applier<LeanExpr, LeanAnalysis> for Eta {

fn apply_one(&self, egraph: &mut LeanEGraph, eta_class: Id, subst: &Subst, _: Option<&PatternAst<LeanExpr>>, rule: Symbol) -> Vec<Id> {
let fun_class = subst[self.fun];
match eta_shift(0, fun_class, egraph, &mut HashMap::new(), rule) {
ClassState::New(shifted_fun_class) => {
if egraph.union_trusted(eta_class, shifted_fun_class, rule) {
vec![eta_class]
} else {
vec![]
}
},
ClassState::Removed => vec![],
ClassState::Pending => unreachable!()
if egraph[fun_class].data.bvars.contains(&0) { return vec![] }
let shifted_fun_class = eta_shift(0, fun_class, egraph, &mut HashMap::new(), rule);
if egraph.union_trusted(eta_class, shifted_fun_class, rule) {
vec![eta_class]
} else {
vec![]
}
}
}

#[derive(Clone)]
enum ClassState {
New(Id),
Removed,
Pending
New(Id),
Visited(HashSet<LeanExpr>),
}

impl ToString for ClassState {

fn to_string(&self) -> String {
match self {
ClassState::New(id) => id.to_string(),
ClassState::Visited(ns) => format!("visited {:?}", ns.clone().into_iter().collect::<Vec<_>>().sort_by(|lhs, rhs| nonrec_cmp(lhs, rhs)))
}
}
}

// TODO: Prove termination of this function based on the rooted e-graph spanning tree property.
// The proof should probably somehow reason about the size of the retry queue.

// TODO: Prove that we can simply mutate the e-graph while traversing it without affecting the eta-reduction.
// I think the reason is that we're only ever adding e-nodes but never unioning any e-classes.
// (This isn't actually true as we union in `register_node`).
// Thus, any e-node that is added is either already contained in the subgraph rooted at `target_class`
// anyway, or will end up in an e-class not contained in the subgraph rooted at `target_class`.

// If a bvar's index is below the threshold, we don't shift it.
// If a bvar's index is above the threshold, then we shift it.
// If a bvar's index equals the `threshold` value, then we remove it. This may entail that bvar's e-class is empty
// and hence the parent e-node doesn't have sufficient children and also needs to be removed.
fn eta_shift(threshold: u64, target_class: Id, egraph: &mut LeanEGraph, state: &mut HashMap<Id, ClassState>, rule: Symbol) -> ClassState {
// Optimization: If all of the target's e-nodes contain the bvar value that leads to removal,
// the whole class will be removed.
if egraph[target_class].data.bvars.contains(&threshold) {
return ClassState::Removed
fn eta_shift(threshold: u64, target_class: Id, egraph: &mut LeanEGraph, state: &mut HashMap<Id, ClassState>, rule: Symbol) -> Id {
dbg_trace("");
dbg_trace(format!("threshold: {}", threshold));
dbg_trace(format!("target_class: {}", target_class));
if state.is_empty() {
dbg_trace("∅".to_string());
} else {
for (key, value) in state.clone() { dbg_trace(format!("{} ↦ {}", key, value.to_string())); }
}

// * If we already have a new class for the given target, return that.
// * If the target class contains no applicable e-nodes, i.e. should be removed, propagate that state.
// * If the target's state is pending, we've looped back onto an e-class. In that case, propagate the pending state
// so that the caller can choose a different e-node to explore first.
if let Some(target_state) = state.get(&target_class) {
return target_state.clone()

// If we already have a new class for the given target, return that.
if let Some(ClassState::New(new_class)) = state.get(&target_class) {
dbg_trace("Early exit: cached new class");
return new_class.clone()
}

// If we reach this point, the target e-class has not been visited yet.
state.insert(target_class, ClassState::Pending);

// Gets all the nodes in the target e-class and sorts the non-recursive ones to the front.
let mut nodes = egraph[target_class].nodes.clone();
// Optimization: If the subgraph rooted at `target_class` doesn't contain any bvars, we can
// just return the e-class as is.
if !egraph[target_class].data.has_bvar {
dbg_trace("Early exit: no bvars");
state.insert(target_class, ClassState::New(target_class));
return target_class
}

// TODO: I think it might be really inefficient to fix the set of nodes we're going to visit
// *before* the for-loop as this means we could be revisiting nodes unnecessarily when
// unrolling the recursion. We might also not want to be sharing which nodes have been
// visited for a given e-class but rather keep that local like the `threshold`.

// Gets all the nodes we are going to visit. In e-class cycles, this is reduced by all nodes
// which have already been visited in a previous cycle. Though, if there are no more available
// nodes we simply allow all nodes to be visited again.
let mut nodes: Vec<LeanExpr> = egraph[target_class].nodes.clone().into_iter().filter(|n|
match state.get(&target_class) {
Some(ClassState::Visited(visited)) => !visited.contains(n),
_ => true
}
).collect();
if nodes.is_empty() { nodes = egraph[target_class].nodes.clone() }

// Sorts the nodes we are going to visit by `nonrec_cmp`.
// It is important for termination that we visit nodes according to a fixed total order.
// Moving non-recursive e-nodes to the front is simply an optimization as this means
// that we tend to visit leaves first which reduces the number of iterations we have to
// take on an e-class cycle.
nodes.sort_by(|lhs, rhs| nonrec_cmp(lhs, rhs));

// When a node has a child whose e-class is pending, that node is readded to the end of the queue.
let mut queue: VecDeque<LeanExpr> = nodes.into();
dbg_trace(format!("nodes: {:?}", nodes));

let mut new_class: Option<Id> = None;

'queue_loop: while let Some(node) = queue.pop_front() {
for node in nodes {
dbg_trace(format!("Entering: {}", node));
visit_node(&node, state, target_class);

match node {
LeanExpr::BVar(e) => {
// We expect `LeanExpr::BVar`s to always have a `LeanExpr::Nat` child which in turn has a `nat_val`.
let idx = egraph[e].data.nat_val.unwrap();
match idx.cmp(&threshold) {
Ordering::Less => {
let new_node = LeanExpr::Nat(idx);
// TODO: Optimize this branch by using the existing `target_class` as the new class.
let idx_node = LeanExpr::Nat(idx);
let idx_class = egraph.add(idx_node);
let new_node = LeanExpr::BVar(idx_class);
register_node(&mut new_class, new_node, egraph, state, target_class, rule)
}
Ordering::Greater => {
let new_node = LeanExpr::Nat(idx - 1);
let idx_node = LeanExpr::Nat(idx - 1);
let idx_class = egraph.add(idx_node);
let new_node = LeanExpr::BVar(idx_class);
register_node(&mut new_class, new_node, egraph, state, target_class, rule);
}
Ordering::Equal => continue
}
},

LeanExpr::Lam([ty, body]) | LeanExpr::Forall([ty, body]) => {
// TODO: Is it a problem that we're exploring both paths before checking the result of
// the first (aside from being less efficient)?
let s1 = eta_shift(threshold, ty, egraph, state, rule);
let s2 = eta_shift(threshold + 1, body, egraph, state, rule);
match (s1, s2) {
(ClassState::New(child1_class), ClassState::New(child2_class)) => {
let new_node = swap_children(&node, [child1_class, child2_class]);
register_node(&mut new_class, new_node, egraph, state, target_class, rule);
},
(ClassState::Removed, _) | (_, ClassState::Removed) => continue,
(ClassState::Pending, _) | (_, ClassState::Pending) => queue.push_back(node),
}
let shifted_ty = eta_shift(threshold, ty, egraph, state, rule);
let shifted_body = eta_shift(threshold + 1, body, egraph, state, rule);
let new_node = swap_children(&node, [shifted_ty, shifted_body]);
register_node(&mut new_class, new_node, egraph, state, target_class, rule)
}

LeanExpr::Const(es) => {
let mut child_classes = vec![];
for &e in es.iter() {
match eta_shift(threshold, e, egraph, state, rule) {
ClassState::New(child_class) => child_classes.push(child_class),
ClassState::Removed => continue 'queue_loop,
ClassState::Pending => { queue.push_back(LeanExpr::Const(es)); continue 'queue_loop }
}
}
let new_node = LeanExpr::Const(child_classes.into());
let shifted_children = es.iter().map(|e| eta_shift(threshold, *e, egraph, state, rule));
let new_node = LeanExpr::Const(shifted_children.collect());
register_node(&mut new_class, new_node, egraph, state, target_class, rule);
}

LeanExpr::App([e1, e2]) | LeanExpr::Max([e1, e2]) | LeanExpr::IMax([e1, e2]) => {
// TODO: Is it a problem that we're exploring both paths before checking the result of
// the first (aside from being less efficient)?
let s1 = eta_shift(threshold, e1, egraph, state, rule);
let s2 = eta_shift(threshold, e2, egraph, state, rule);
match (s1, s2) {
(ClassState::New(child1_class), ClassState::New(child2_class)) => {
let new_node = swap_children(&node, [child1_class, child2_class]);
register_node(&mut new_class, new_node, egraph, state, target_class, rule);
},
(ClassState::Removed, _) | (_, ClassState::Removed) => continue,
(ClassState::Pending, _) | (_, ClassState::Pending) => queue.push_back(node),
}
let new_node = swap_children(&node, [s1, s2]);
register_node(&mut new_class, new_node, egraph, state, target_class, rule)
},

LeanExpr::Lit(e) | LeanExpr::FVar(e) | LeanExpr::MVar(e) | LeanExpr::Sort(e) |
LeanExpr::UVar(e) | LeanExpr::Param(e) | LeanExpr::Succ(e) => {
match eta_shift(threshold, e, egraph, state, rule) {
ClassState::New(child_class) => {
let new_node = swap_child(&node, child_class);
register_node(&mut new_class, new_node, egraph, state, target_class, rule);
}
ClassState::Removed => continue,
ClassState::Pending => queue.push_back(node),
}
let shifted_child = eta_shift(threshold, e, egraph, state, rule);
let new_node = swap_child(&node, shifted_child);
register_node(&mut new_class, new_node, egraph, state, target_class, rule);
}

LeanExpr::Nat(_) | LeanExpr::Str(_) | LeanExpr::Erased =>
LeanExpr::Nat(_) | LeanExpr::Str(_) | LeanExpr::Erased =>
register_node(&mut new_class, node, egraph, state, target_class, rule)
}
}

match new_class {
Some(n) => ClassState::New(n),
None => ClassState::Removed
new_class.unwrap()
}

/// Records that `node` has been visited while processing `target_class`.
///
/// * If `state` has no entry for `target_class`, a `ClassState::Visited` set containing
///   just `node` is created.
/// * If the entry is already a `ClassState::Visited` set, `node` is added to it.
/// * Any other entry (i.e. `ClassState::New`) is left untouched.
fn visit_node(node: &LeanExpr, state: &mut HashMap<Id, ClassState>, target_class: Id) {
    // Make sure an entry exists; a fresh one starts out as an empty visited-set.
    let entry = state
        .entry(target_class)
        .or_insert_with(|| ClassState::Visited(HashSet::new()));

    // Only visited-sets are extended; a finished class keeps its state.
    if let ClassState::Visited(visited) = entry {
        visited.insert(node.clone());
    }
}

Expand Down
7 changes: 4 additions & 3 deletions Rust/src/lean_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ pub fn is_nonrec(expr: &LeanExpr) -> bool {
}
}

// An expression `lhs` is smaller than another `rhs` wrt. non-recursiveness if `lhs` not recursive
// but `rhs` is.
// An expression `lhs` is smaller than another `rhs` wrt. non-recursiveness if `lhs` is not
// recursive but `rhs` is. If both are either recursive or non-recursive, the total order
// derived by `define_language!` applies.
pub fn nonrec_cmp(lhs: &LeanExpr, rhs: &LeanExpr) -> Ordering {
match (is_nonrec(lhs), is_nonrec(rhs)) {
(true, false) => Ordering::Less,
(false, true) => Ordering::Greater,
_ => Ordering::Equal,
_ => lhs.cmp(rhs),
}
}

Expand Down
Loading

0 comments on commit 54d842d

Please sign in to comment.