diff --git a/src/egraph.rs b/src/egraph.rs index 6af452b2..e449d0c8 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -7,6 +7,7 @@ use std::{ #[cfg(feature = "serde-1")] use serde::{Deserialize, Serialize}; +use crate::semi_persistent::UndoLogT; use log::*; /** A data structure to keep track of equalities between expressions. @@ -56,16 +57,24 @@ pub struct EGraph> { pub analysis: N, /// The `Explain` used to explain equivalences in this `EGraph`. pub(crate) explain: Option>, - unionfind: UnionFind, + #[cfg_attr( + feature = "serde-1", + serde(bound( + serialize = "N::UndoLog: Serialize", + deserialize = "N::UndoLog: for<'a> Deserialize<'a>", + )) + )] + pub(crate) undo_log: N::UndoLog, + pub(crate) unionfind: UnionFind, /// Stores each enode's `Id`, not the `Id` of the eclass. /// Enodes in the memo are canonicalized at each rebuild, but after rebuilding new /// unions can cause them to become out of date. #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] - memo: HashMap, + pub(crate) memo: HashMap, /// Nodes which need to be processed for rebuilding. The `Id` is the `Id` of the enode, /// not the canonical id of the eclass. - pending: Vec<(L, Id)>, - analysis_pending: UniqueQueue<(L, Id)>, + pub(crate) pending: Vec<(L, Id)>, + pub(crate) analysis_pending: UniqueQueue<(L, Id)>, #[cfg_attr( feature = "serde-1", serde(bound( @@ -103,6 +112,8 @@ impl> Debug for EGraph { f.debug_struct("EGraph") .field("memo", &self.memo) .field("classes", &self.classes) + .field("undo_log", &self.undo_log) + .field("explain", &self.explain) .finish() } } @@ -120,6 +131,7 @@ impl> EGraph { memo: Default::default(), analysis_pending: Default::default(), classes_by_op: Default::default(), + undo_log: Default::default(), } } @@ -769,9 +781,12 @@ impl> EGraph { *existing_explain } else { let new_id = self.unionfind.make_set(); + self.undo_log.add_node(&original, new_id); explain.add(original, new_id, new_id); self.unionfind.union(id, new_id); + self.undo_log.union(id, new_id); explain.union(existing_id, new_id, Justification::Congruence, true); + self.undo_log.union_explain(existing_id, new_id); new_id } } else { @@ -780,6 +795,7 @@ impl> EGraph { } else { let id = self.make_new_eclass(enode); if let Some(explain) = self.explain.as_mut() { + self.undo_log.add_node(&original, id); explain.add(original, id, id); } @@ -811,7 +827,8 @@ impl> EGraph { self.pending.push((enode.clone(), id)); self.classes.insert(id, class); - assert!(self.memo.insert(enode, id).is_none()); + let old = self.undo_log.modify_memo(&mut self.memo, enode, Some(id)); + assert!(old.is_none()); id } @@ -919,7 +936,12 @@ impl> EGraph { if id1 == id2 { if let Some(Justification::Rule(_)) = rule { if let Some(explain) = &mut self.explain { - explain.alternate_rewrite(enode_id1, enode_id2, rule.unwrap()); + explain.alternate_rewrite( + enode_id1, + enode_id2, + rule.unwrap(), + &mut self.undo_log, + ); } } return false; @@ -933,10 +955,12 @@ impl> EGraph { if let Some(explain) = &mut self.explain { explain.union(enode_id1, enode_id2, rule.unwrap(), any_new_rhs); + self.undo_log.union_explain(enode_id1, enode_id2); } // make id1 the new root self.unionfind.union(id1, id2); + self.undo_log.union(id1, id2); assert_ne!(id1, id2); let class2 = self.classes.remove(&id2).unwrap(); @@ -1105,7 +1129,9 @@ impl> EGraph { while !self.pending.is_empty() || !self.analysis_pending.is_empty() { while let Some((mut node, class)) = self.pending.pop() { node.update_children(|id| self.find_mut(id)); - if let Some(memo_class) = self.memo.insert(node, class) { + if let Some(memo_class) = + self.undo_log.modify_memo(&mut self.memo, node, Some(class)) + { let did_something = self.perform_union( memo_class, class, diff --git a/src/explain.rs b/src/explain.rs index 187aecfc..b08d0385 100644 --- a/src/explain.rs +++ b/src/explain.rs @@ -9,6 +9,7 @@ use std::collections::{BinaryHeap, VecDeque}; use std::fmt::{self, Debug, Display, Formatter}; use std::rc::Rc; +use crate::semi_persistent::UndoLogT; use symbolic_expressions::Sexp; type ProofCost = Saturating; @@ -29,8 +30,8 @@ pub enum Justification { #[derive(Debug, Clone, Hash, PartialEq, Eq)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -struct Connection { - next: Id, +pub(crate) struct Connection { + pub(crate) next: Id, current: Id, justification: Justification, is_rewrite_forward: bool, @@ -38,23 +39,34 @@ struct Connection { #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] -struct ExplainNode { - node: L, +pub(crate) struct ExplainNode { + pub(crate) node: L, // neighbors includes parent connections - neighbors: Vec, - parent_connection: Connection, + pub(crate) neighbors: Vec, + pub(crate) parent_connection: Connection, // it was inserted because of: // 1) it's parent is inserted (points to parent enode) // 2) a rewrite instantiated it (points to adjacent enode) // 3) it was inserted directly (points to itself) // if 1 is true but it's also adjacent (2) then either works and it picks 2 - existance_node: Id, + pub(crate) existance_node: Id, +} + +impl Connection { + pub(crate) fn dummy(set: Id) -> Self { + Connection { + justification: Justification::Congruence, + is_rewrite_forward: false, + next: set, + current: set, + } + } } #[derive(Debug, Clone)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct Explain { - explainfind: Vec>, + pub(crate) explainfind: Vec>, #[cfg_attr(feature = "serde-1", serde(with = "vectorize"))] pub uncanon_memo: HashMap, /// By default, egg uses a greedy algorithm to find shorter explanations when they are extracted. @@ -66,7 +78,7 @@ pub struct Explain { // Invariant: The distance is always <= the unoptimized distance // That is, less than or equal to the result of `distance_between` #[cfg_attr(feature = "serde-1", serde(skip))] - shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, + pub(crate) shortest_explanation_memo: HashMap<(Id, Id), (ProofCost, Id)>, } #[derive(Default)] @@ -1048,12 +1060,7 @@ impl Explain { self.explainfind.push(ExplainNode { node, neighbors: vec![], - parent_connection: Connection { - justification: Justification::Congruence, - is_rewrite_forward: false, - next: set, - current: set, - }, + parent_connection: Connection::dummy(set), existance_node, }); set @@ -1075,7 +1082,13 @@ impl Explain { } } - pub(crate) fn alternate_rewrite(&mut self, node1: Id, node2: Id, justification: Justification) { + pub(crate) fn alternate_rewrite( + &mut self, + node1: Id, + node2: Id, + justification: Justification, + undo: &mut impl UndoLogT, + ) { if node1 == node2 { return; } @@ -1084,6 +1097,7 @@ impl Explain { return; } } + undo.union_explain(node1, node2); let lconnection = Connection { justification: justification.clone(), diff --git a/src/language.rs b/src/language.rs index 6414c63a..261167e6 100644 --- a/src/language.rs +++ b/src/language.rs @@ -8,6 +8,7 @@ use std::{hash::Hash, str::FromStr}; use crate::*; +use crate::semi_persistent::{UndoLog, UndoLogT}; use fmt::Formatter; use symbolic_expressions::{Sexp, SexpError}; use thiserror::Error; @@ -655,6 +656,7 @@ define_language! { struct ConstantFolding; impl Analysis for ConstantFolding { type Data = Option; + type UndoLog = (); fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { egg::merge_max(to, from) @@ -700,6 +702,12 @@ pub trait Analysis: Sized { /// The per-[`EClass`] data for this analysis. type Data: Debug; + /// Determines whether the [`EGraph`] supports [`push`](EGraph::push) and [`pop`](EGraph::pop) + /// Setting this to `()` disables [`push`](EGraph::push) and [`pop`](EGraph::pop) + /// Setting this to [`UndoLog`](UndoLog) enables [`push`](EGraph::push) and [`pop`](EGraph::pop) + /// Doing this requires that the [`EGraph`] has explanations enabled + type UndoLog: UndoLogT; + /// Makes a new [`Analysis`] data for a given e-node. /// /// Note the mutable `egraph` parameter: this is needed for some @@ -765,6 +773,22 @@ pub trait Analysis: Sized { impl Analysis for () { type Data = (); + + type UndoLog = (); + fn make(_egraph: &mut EGraph, _enode: &L) -> Self::Data {} + fn merge(&mut self, _: &mut Self::Data, _: Self::Data) -> DidMerge { + DidMerge(false, false) + } +} + +/// Simple [`Analysis`], similar to `()` but enables [`push`](EGraph::push) and [`pop`](EGraph::pop) +/// Doing this requires that the [`EGraph`] has explanations enabled +pub struct WithUndo; + +impl Analysis for WithUndo { + type Data = (); + + type UndoLog = UndoLog; fn make(_egraph: &mut EGraph, _enode: &L) -> Self::Data {} fn merge(&mut self, _: &mut Self::Data, _: Self::Data) -> DidMerge { DidMerge(false, false) diff --git a/src/lib.rs b/src/lib.rs index 5a293a58..76098701 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,6 +50,7 @@ mod multipattern; mod pattern; mod rewrite; mod run; +mod semi_persistent; mod subst; mod unionfind; mod util; @@ -101,6 +102,7 @@ pub use { pattern::{ENodeOrVar, Pattern, PatternAst, SearchMatches}, rewrite::{Applier, Condition, ConditionEqual, ConditionalApplier, Rewrite, Searcher}, run::*, + semi_persistent::UndoLog, subst::{Subst, Var}, util::*, }; diff --git a/src/rewrite.rs b/src/rewrite.rs index 6b34b48b..48edcf26 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -252,6 +252,7 @@ where /// struct MinSize; /// impl Analysis for MinSize { /// type Data = usize; +/// type UndoLog = (); /// fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { /// merge_min(to, from) /// } diff --git a/src/semi_persistent.rs b/src/semi_persistent.rs new file mode 100644 index 00000000..2afc7f6b --- /dev/null +++ b/src/semi_persistent.rs @@ -0,0 +1,388 @@ +use crate::explain::Connection; +use crate::util::HashMap; +use crate::{EClass, EGraph, Id, Language, WithUndo}; +use indexmap::IndexSet; +use std::fmt::Debug; + +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +struct UndoNode { + /// Other ENodes that were unioned with this ENode and chose it as their representative + representative_of: Vec, + /// Non-canonical Id's of direct parents of this non-canonical node + parents: Vec, +} + +fn visit_undo_node(id: Id, undo_find: &[UndoNode], f: &mut impl FnMut(Id, &UndoNode)) { + let node = &undo_find[usize::from(id)]; + f(id, node); + node.representative_of + .iter() + .for_each(|&id| visit_undo_node(id, undo_find, &mut *f)) +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +struct PushInfo { + node_count: Id, + union_count: usize, + explain_count: usize, + memo_log_count: usize, +} + +/// Value for [`Analysis::UndoLog`](crate::Analysis::UndoLog) that enables [`push`](EGraph::push) and +/// [`pop`](EGraph::pop) +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] +pub struct UndoLog { + undo_find: Vec, + union_log: Vec, + explain_log: Vec<(Id, Id)>, + memo_log: Vec<(L, Option)>, + push_log: Vec, + // Scratch space, should be empty other that when inside `pop` + #[cfg_attr(feature = "serde-1", serde(skip))] + dirty: IndexSet, +} + +impl Default for UndoLog { + fn default() -> Self { + UndoLog { + undo_find: Default::default(), + union_log: Default::default(), + explain_log: Default::default(), + memo_log: Default::default(), + push_log: Default::default(), + dirty: Default::default(), + } + } +} + +pub trait UndoLogT: Default + Debug { + fn add_node(&mut self, node: &L, node_id: Id); + + fn union(&mut self, id1: Id, id2: Id); + + fn union_explain(&mut self, node1: Id, node2: Id); + + fn modify_memo(&mut self, memo: &mut HashMap, key: L, new_val: Option) + -> Option; +} + +impl UndoLogT for UndoLog { + fn add_node(&mut self, node: &L, node_id: Id) { + assert_eq!(self.undo_find.len(), usize::from(node_id)); + self.undo_find.push(UndoNode::default()); + for id in node.children() { + self.undo_find[usize::from(*id)].parents.push(node_id) + } + } + + fn union(&mut self, id1: Id, id2: Id) { + self.undo_find[usize::from(id1)].representative_of.push(id2); + self.union_log.push(id1) + } + + fn union_explain(&mut self, node1: Id, node2: Id) { + self.explain_log.push((node1, node2)); + } + + fn modify_memo( + &mut self, + memo: &mut HashMap, + key: L, + new_val: Option, + ) -> Option { + let res = match new_val { + None => memo.remove(&key), + Some(id) => memo.insert(key.clone(), id), + }; + self.memo_log.push((key, res)); + res + } +} + +impl UndoLogT for () { + fn add_node(&mut self, _: &L, _: Id) {} + + fn union(&mut self, _: Id, _: Id) {} + + fn union_explain(&mut self, _: Id, _: Id) {} + + fn modify_memo( + &mut self, + memo: &mut HashMap, + key: L, + new_val: Option, + ) -> Option { + match new_val { + None => memo.remove(&key), + Some(id) => memo.insert(key, id), + } + } +} + +impl EGraph { + /// Push the current egraph off the stack + /// Requires that the egraph is clean + /// + /// See [`EGraph::pop`] + pub fn push(&mut self) { + assert!(self.pending.is_empty() && self.analysis_pending.is_empty()); + let undo = &mut self.undo_log; + undo.push_log.push(PushInfo { + node_count: undo.undo_find.len().into(), + union_count: undo.union_log.len(), + explain_count: undo.explain_log.len(), + memo_log_count: undo.memo_log.len(), + }) + } + + /// Pop the current egraph off the stack, replacing + /// it with the previously [`push`](EGraph::push)ed egraph + /// + /// ``` + /// use egg::{EGraph, SymbolLang, WithUndo}; + /// let mut egraph = EGraph::new(WithUndo).with_explanations_enabled(); + /// let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + /// let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + /// egraph.rebuild(); + /// egraph.push(); + /// egraph.union(a, b); + /// assert_eq!(egraph.find(a), egraph.find(b)); + /// egraph.pop(); + /// assert_ne!(egraph.find(a), egraph.find(b)); + /// ``` + pub fn pop(&mut self) { + self.pop_n(1); + } + + /// Equivalent to calling [`pop`](EGraph::pop) `n` times but possibly more efficient + pub fn pop_n(&mut self, count: usize) { + if count == 0 { + return; + } + self.pending.clear(); + self.analysis_pending.clear(); + let mut push_info = None; + for _ in 0..count { + push_info = self.undo_log.push_log.pop(); + } + let PushInfo { + node_count, + union_count, + explain_count, + memo_log_count, + } = push_info.unwrap_or_else(|| panic!("Not enough pushes to pop")); + self.pop_memo(memo_log_count); + self.pop_unions(union_count, node_count); + self.pop_explain(explain_count); + self.pop_nodes(usize::from(node_count)); + } + + fn pop_memo(&mut self, old_count: usize) { + for (k, v) in self.undo_log.memo_log.drain(old_count..).rev() { + match v { + Some(v) => self.memo.insert(k, v), + None => self.memo.remove(&k), + }; + } + } + + fn pop_unions(&mut self, old_count: usize, node_count: Id) { + let explain = self.explain.as_mut().unwrap(); + explain.shortest_explanation_memo.clear(); + let undo = &mut self.undo_log; + for id in undo.union_log.drain(old_count..) { + let id2 = undo.undo_find[usize::from(id)] + .representative_of + .pop() + .unwrap(); + for id in [id, id2] { + if id < node_count { + undo.dirty.insert(id); + } + } + } + let very_dirty_count = undo.dirty.len(); + // Very dirty nodes were canonical in the state we are reverting to and were then unioned with other node + // either becoming non-canonical or having there EClass merged so the their EClasses must be completely rebuilt + // all the nodes they represented in the old state must have there parent field reset since it may have been + // path-compressed using unions that are now being reverted + for i in 0..very_dirty_count { + let root = *undo.dirty.get_index(i).unwrap(); + let mut class = EClass { + id: root, + nodes: Default::default(), + data: (), + parents: Default::default(), + }; + let union_find = &mut self.unionfind; + let dirty = &mut undo.dirty; + visit_undo_node(root, &undo.undo_find, &mut |id, node| { + union_find.union(root, id); + dirty.extend(node.parents.iter().copied().filter(|&id| id < node_count)); + class.parents.extend( + node.parents + .iter() + .map(|&id| (explain.node(id).clone(), id)), + ); + class.nodes.push(explain.node(id).clone()) + }); + self.classes.insert(root, class); + self.classes_by_op.values_mut().for_each(|ids| ids.clear()); + } + + let dirty_count = undo.dirty.len(); + // Dirty nodes are nodes that have very dirty children so canonicalization applied their `nodes` fields may no + // longer be correct and must be reverted + for i in very_dirty_count..dirty_count { + let root = *undo.dirty.get_index(i).unwrap(); + let class = self.classes.get_mut(&root).unwrap(); + class.nodes.clear(); + visit_undo_node(root, &undo.undo_find, &mut |id, _| { + class.nodes.push(explain.node(id).clone()) + }); + } + undo.dirty.clear() + } + + fn pop_explain(&mut self, old_count: usize) { + if let Some(explain) = self.explain.as_mut() { + for (node1, node2) in self + .undo_log + .explain_log + .drain(old_count..) + .rev() + .flat_map(|(n1, n2)| [(n1, n2), (n2, n1)]) + { + let exp_node = &mut explain.explainfind[usize::from(node1)]; + if exp_node.parent_connection.next == node2 { + exp_node.parent_connection = Connection::dummy(node1); + } + let c = exp_node.neighbors.pop().unwrap(); + debug_assert_eq!(c.next, node2); + } + } else { + debug_assert!(self.undo_log.explain_log.is_empty()) + } + } + + fn pop_nodes(&mut self, old_count: usize) { + if let Some(explain) = self.explain.as_mut() { + for x in explain.explainfind.drain(old_count..).rev() { + explain.uncanon_memo.remove(&x.node); + } + } + let new_count = self.undo_log.undo_find.len(); + for i in old_count..new_count { + self.classes.remove(&Id::from(i)); + } + self.undo_log.undo_find.truncate(old_count); + self.unionfind.parents.truncate(old_count); + } +} + +#[test] +fn simple_push_pop() { + use crate::{Pattern, Searcher, SymbolLang}; + use core::str::FromStr; + crate::init_logger(); + let mut egraph = EGraph::new(WithUndo).with_explanations_enabled(); + + let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + let c = egraph.add_uncanonical(SymbolLang::leaf("c")); + egraph.rebuild(); + egraph.push(); + let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + let _fb = egraph.add_uncanonical(SymbolLang::new("g", vec![b])); + egraph.union_trusted(b, a, "b=a"); + egraph.union_trusted(b, c, "b=c"); + egraph.rebuild(); + assert_eq!(egraph.find(a), b); + egraph.pop(); + assert_eq!(egraph.lookup(SymbolLang::leaf("a")), Some(a)); + assert_eq!(egraph.lookup(SymbolLang::new("f", vec![a])), Some(fa)); + assert_eq!(egraph.lookup(SymbolLang::leaf("b")), None); + + egraph.rebuild(); + let f_pat = Pattern::from_str("(f ?a)").unwrap(); + let s = f_pat.search(&egraph); + + assert_eq!(s.len(), 1); + assert_eq!(s[0].substs.len(), 1); + assert_eq!(s[0].substs[0].vec[0].1, a); +} + +#[test] +fn push_pop_explain() { + use crate::SymbolLang; + crate::init_logger(); + let mut egraph = EGraph::new(WithUndo).with_explanations_enabled(); + + let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + let c = egraph.add_uncanonical(SymbolLang::leaf("c")); + let d = egraph.add_uncanonical(SymbolLang::leaf("d")); + egraph.union_trusted(a, b, "a=b"); + egraph.rebuild(); + let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + egraph.union_trusted(c, fa, "c=fa"); + egraph.union_trusted(d, fb, "d=fb"); + egraph.rebuild(); + egraph.push(); + egraph.union_trusted(c, d, "bad"); + egraph.pop(); + let mut exp = egraph.explain_id_equivalence(c, d); + assert_eq!(exp.make_flat_explanation().len(), 4); +} + +#[test] +fn push_pop_explain2() { + use crate::SymbolLang; + crate::init_logger(); + let mut egraph = EGraph::new(WithUndo).with_explanations_enabled(); + + let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + let c = egraph.add_uncanonical(SymbolLang::leaf("c")); + let d = egraph.add_uncanonical(SymbolLang::leaf("d")); + let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + egraph.union_trusted(c, fa, "c=fa"); + egraph.union_trusted(d, fb, "d=fb"); + egraph.rebuild(); + egraph.push(); + egraph.union_trusted(c, d, "bad"); + egraph.pop(); + egraph.union_trusted(a, b, "a=b"); + egraph.rebuild(); + let mut exp = egraph.explain_id_equivalence(c, d); + assert_eq!(exp.make_flat_explanation().len(), 4); +} + +#[test] +fn push_pop_explain3() { + use crate::SymbolLang; + crate::init_logger(); + let mut egraph = EGraph::new(WithUndo).with_explanations_enabled(); + + let a = egraph.add_uncanonical(SymbolLang::leaf("a")); + let b = egraph.add_uncanonical(SymbolLang::leaf("b")); + let c = egraph.add_uncanonical(SymbolLang::leaf("c")); + let fa = egraph.add_uncanonical(SymbolLang::new("f", vec![a])); + egraph.rebuild(); + egraph.push(); + egraph.union_trusted(a, b, "a=b"); + let _ = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + egraph.rebuild(); + egraph.pop(); + let fb = egraph.add_uncanonical(SymbolLang::new("f", vec![b])); + egraph.union_trusted(fb, c, "fb=c"); + egraph.union_trusted(c, fa, "c=fa"); + egraph.rebuild(); + let mut exp = egraph.explain_id_equivalence(fa, fb); + assert_eq!(exp.make_flat_explanation().len(), 3); +} diff --git a/src/unionfind.rs b/src/unionfind.rs index 39e9bc58..9652b021 100644 --- a/src/unionfind.rs +++ b/src/unionfind.rs @@ -4,7 +4,7 @@ use std::fmt::Debug; #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))] pub struct UnionFind { - parents: Vec, + pub(crate) parents: Vec, } impl UnionFind { diff --git a/src/util.rs b/src/util.rs index 0e9051ee..628e185b 100644 --- a/src/util.rs +++ b/src/util.rs @@ -171,4 +171,9 @@ where debug_assert_eq!(r, self.set.is_empty()); r } + + pub fn clear(&mut self) { + self.set.clear(); + self.queue.clear(); + } } diff --git a/tests/lambda.rs b/tests/lambda.rs index 80ea4fbd..ad7d9497 100644 --- a/tests/lambda.rs +++ b/tests/lambda.rs @@ -61,6 +61,7 @@ fn eval(egraph: &EGraph, enode: &Lambda) -> Option<(Lambda, PatternAst)> impl Analysis for LambdaAnalysis { type Data = Data; + type UndoLog = (); fn merge(&mut self, to: &mut Data, from: Data) -> DidMerge { let before_len = to.free.len(); // to.free.extend(from.free); diff --git a/tests/math.rs b/tests/math.rs index a0d8c07a..a28db525 100644 --- a/tests/math.rs +++ b/tests/math.rs @@ -48,6 +48,7 @@ impl egg::CostFunction for MathCostFn { #[derive(Default)] pub struct ConstantFold; impl Analysis for ConstantFold { + type UndoLog = (); type Data = Option<(Constant, PatternAst)>; fn make(egraph: &mut EGraph, enode: &Math) -> Self::Data { diff --git a/tests/prop.rs b/tests/prop.rs index ed1c7469..6c356ddd 100644 --- a/tests/prop.rs +++ b/tests/prop.rs @@ -17,6 +17,7 @@ type Rewrite = egg::Rewrite; #[derive(Default)] struct ConstantFold; impl Analysis for ConstantFold { + type UndoLog = (); type Data = Option<(bool, PatternAst)>; fn merge(&mut self, to: &mut Self::Data, from: Self::Data) -> DidMerge { merge_option(to, from, |a, b| {