Skip to content

Commit

Permalink
refactor to use
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Aug 20, 2024
1 parent e10a8ee commit 3bd82dc
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 34 deletions.
43 changes: 22 additions & 21 deletions src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
marker::PhantomData,
};

use explain::existenceReason;
use explain::{ExistenceReason, ExistsOrReason};
#[cfg(feature = "serde-1")]
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -464,8 +464,8 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
left_expr: &RecExpr<L>,
right_expr: &RecExpr<L>,
) -> Explanation<L> {
let left = self.add_expr_uncanonical_with_reason(left_expr, None);
let right = self.add_expr_uncanonical_with_reason(right_expr, None);
let left = self.add_expr_uncanonical_with_reason(left_expr, ExistsOrReason::ExpectExists);
let right = self.add_expr_uncanonical_with_reason(right_expr, ExistsOrReason::ExpectExists);

self.explain_id_equivalence(left, right)
}
Expand Down Expand Up @@ -505,7 +505,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Note that this function can be called again to explain any intermediate terms
/// used in the output [`Explanation`].
pub fn explain_existence(&mut self, expr: &RecExpr<L>) -> Explanation<L> {
let id = self.add_expr_uncanonical_with_reason(expr, None);
let id = self.add_expr_uncanonical_with_reason(expr, ExistsOrReason::ExpectExists);
self.explain_existence_id(id)
}

Expand Down Expand Up @@ -543,7 +543,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
right_pattern: &PatternAst<L>,
subst: &Subst,
) -> Explanation<L> {
let left = self.add_expr_uncanonical_with_reason(left_expr, None);
let left = self.add_expr_uncanonical_with_reason(left_expr, ExistsOrReason::ExpectExists);
let right = self.add_instantiation_noncanonical(right_pattern, subst, None);

if self.find(left) != self.find(right) {
Expand Down Expand Up @@ -853,7 +853,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Calling [`id_to_expr`](EGraph::id_to_expr) on this `Id` return a copy of `expr` when explanations are enabled
pub fn add_expr_uncanonical(&mut self, expr: &RecExpr<L>) -> Id {
eprintln!("Adding {:?} directly", expr);
self.add_expr_uncanonical_with_reason(expr, Some(existenceReason::Direct))
self.add_expr_uncanonical_with_reason(expr, ExistsOrReason::Reason(ExistenceReason::Direct))
}

/// Like `add_expr_uncanonical` but with an existence reason.
Expand All @@ -862,22 +862,23 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
fn add_expr_uncanonical_with_reason(
&mut self,
expr: &RecExpr<L>,
reason: Option<existenceReason>,
reason: ExistsOrReason,
) -> Id {
let reason = reason.to_option();
let nodes = expr.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
for node in nodes {
let new_node = node.clone().map_children(|i| new_ids[usize::from(i)]);
let next_id = self.add_uncanonical_with_reason(
new_node,
reason.as_ref().map(|_e| existenceReason::Unset),
reason.as_ref().map(|_e| ExistenceReason::Unset),
);
if let Some(explain) = &mut self.explain {
node.for_each(|child| {
// Set the existence reason for new nodes to their parent node if it is unset.
explain.set_existence_reason(
new_ids[usize::from(child)],
existenceReason::ChildOf(next_id),
ExistenceReason::ChildOf(next_id),
);
});
}
Expand All @@ -895,7 +896,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// Adds a [`Pattern`] and a substitution to the [`EGraph`], returning
/// the eclass of the instantiated pattern.
pub fn add_instantiation(&mut self, pat: &PatternAst<L>, subst: &Subst) -> Id {
let id = self.add_instantiation_noncanonical(pat, subst, Some(existenceReason::Direct));
let id = self.add_instantiation_noncanonical(pat, subst, Some(ExistenceReason::Direct));
self.find(id)
}

Expand All @@ -912,7 +913,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
&mut self,
pat: &PatternAst<L>,
subst: &Subst,
existence: Option<existenceReason>,
existence: Option<ExistenceReason>,
) -> Id {
let nodes = pat.as_ref();
let mut new_ids = Vec::with_capacity(nodes.len());
Expand All @@ -930,7 +931,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
// When existence is Some, the new node is first added with `Unset` existence reason.
let next_id = self.add_uncanonical_with_reason(
new_node,
existence.as_ref().map(|_e| existenceReason::Unset),
existence.as_ref().map(|_e| ExistenceReason::Unset),
);
if self.unionfind.size() > size_before {
new_node_q.push(true);
Expand All @@ -944,7 +945,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
// Now set the existence reason for children when it's unset.
explain.set_existence_reason(
new_ids[usize::from(child)],
existenceReason::ChildOf(next_id),
ExistenceReason::ChildOf(next_id),
);
}
});
Expand Down Expand Up @@ -1082,7 +1083,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
/// assert_eq!(egraph.id_to_expr(fb), "(f a)".parse().unwrap());
/// ```
pub fn add_uncanonical(&mut self, enode: L) -> Id {
self.add_uncanonical_with_reason(enode, Some(existenceReason::Direct))
self.add_uncanonical_with_reason(enode, Some(ExistenceReason::Direct))
}

/// The private implementation of [`add_uncanonical`](EGraph::add_uncanonical)
Expand All @@ -1093,7 +1094,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
fn add_uncanonical_with_reason(
&mut self,
mut enode: L,
existence: Option<existenceReason>,
existence: Option<ExistenceReason>,
) -> Id {
let original = enode.clone();
if let Some(existing_id) = self.lookup_internal(&mut enode) {
Expand All @@ -1108,7 +1109,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
explain.add(
original.clone(),
new_id,
existence.unwrap_or(existenceReason::EqualTo(existing_id)),
existence.unwrap_or(ExistenceReason::EqualTo(existing_id)),
);
debug_assert_eq!(Id::from(self.nodes.len()), new_id);
self.nodes.push(original);
Expand All @@ -1126,7 +1127,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
if let Some(explain) = self.explain.as_mut() {
// We tried to add a term that wasn't there.
if existence.is_none() {
explain.add(original.clone(), id, existenceReason::Direct);
explain.add(original.clone(), id, ExistenceReason::Direct);
let term = self.id_to_expr(id);

panic!("Expected term {:?} to exist in the egraph, but it was not found. This may happen when calling 'explain_existence' or 'explain_equivalence' on terms that are not in the egraph.", term);
Expand Down Expand Up @@ -1218,10 +1219,10 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
) -> (Id, bool) {
// add the lhs directly
let id1 =
self.add_instantiation_noncanonical(from_pat, subst, Some(existenceReason::Direct));
self.add_instantiation_noncanonical(from_pat, subst, Some(ExistenceReason::Direct));
// add the rhs, with reason equal to lhs
let id2 =
self.add_instantiation_noncanonical(to_pat, subst, Some(existenceReason::EqualTo(id1)));
self.add_instantiation_noncanonical(to_pat, subst, Some(ExistenceReason::EqualTo(id1)));

let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into())));
(self.find(id1), did_union)
Expand All @@ -1242,7 +1243,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let id1 = self.add_instantiation_noncanonical(from_pat, subst, None);
// add the rhs, making it equal to the lhs
let id2 =
self.add_instantiation_noncanonical(to_pat, subst, Some(existenceReason::EqualTo(id1)));
self.add_instantiation_noncanonical(to_pat, subst, Some(ExistenceReason::EqualTo(id1)));

let did_union = self.perform_union(id1, id2, Some(Justification::Rule(rule_name.into())));
(self.find(id1), did_union)
Expand All @@ -1263,7 +1264,7 @@ impl<L: Language, N: Analysis<L>> EGraph<L, N> {
let id2 = self.add_instantiation_noncanonical(
to_pat,
subst,
Some(existenceReason::EqualTo(from_term)),
Some(ExistenceReason::EqualTo(from_term)),
);

let did_union =
Expand Down
40 changes: 27 additions & 13 deletions src/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct Connection {

#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
pub(crate) enum existenceReason {
pub(crate) enum ExistenceReason {
/// The term was added as a child of some other term
ChildOf(Id),
/// The term was added when rewritting from another term
Expand All @@ -54,13 +54,27 @@ pub(crate) enum existenceReason {
Unset,
}

pub enum ExistsOrReason {
ExpectExists,
Reason(ExistenceReason),
}

impl ExistsOrReason {
pub fn to_option(&self) -> Option<ExistenceReason> {
match self {
ExistsOrReason::ExpectExists => None,
ExistsOrReason::Reason(r) => Some(r.clone()),
}
}
}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde-1", derive(serde::Serialize, serde::Deserialize))]
struct ExplainNode {
// neighbors includes parent connections
neighbors: Vec<Connection>,
parent_connection: Connection,
existence: existenceReason,
existence: ExistenceReason,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -925,7 +939,7 @@ impl<L: Language> Explain<L> {
}
}

pub(crate) fn add(&mut self, node: L, set: Id, existence: existenceReason) -> Id {
pub(crate) fn add(&mut self, node: L, set: Id, existence: ExistenceReason) -> Id {
assert_eq!(self.explainfind.len(), usize::from(set));
self.uncanon_memo.insert(node, set);
self.explainfind.push(ExplainNode {
Expand Down Expand Up @@ -1050,8 +1064,8 @@ impl<L: Language> Explain<L> {
}

/// Sets unset existence reasons to the new reason
pub(crate) fn set_existence_reason(&mut self, node: Id, reason: existenceReason) {
if self.explainfind[usize::from(node)].existence == existenceReason::Unset {
pub(crate) fn set_existence_reason(&mut self, node: Id, reason: ExistenceReason) {
if self.explainfind[usize::from(node)].existence == ExistenceReason::Unset {
self.explainfind[usize::from(node)].existence = reason;
}
}
Expand Down Expand Up @@ -1114,10 +1128,10 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
loop {
seen_existence.insert(existence);
let next = match self.explainfind[usize::from(existence)].existence {
existenceReason::ChildOf(id) => id,
existenceReason::EqualTo(id) => id,
existenceReason::Direct => existence,
existenceReason::Unset => panic!("Unset existence!"),
ExistenceReason::ChildOf(id) => id,
ExistenceReason::EqualTo(id) => id,
ExistenceReason::Direct => existence,
ExistenceReason::Unset => panic!("Unset existence!"),
};
if existence == next {
break;
Expand Down Expand Up @@ -1281,7 +1295,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
let existence = node.existence.clone();

match existence {
existenceReason::ChildOf(parent_id) => {
ExistenceReason::ChildOf(parent_id) => {
let mut new_rest_of_proof =
(*self.node_to_explanation(parent_id, enode_cache)).clone();
let mut index_of_child = 0;
Expand All @@ -1306,7 +1320,7 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
enode_cache,
)
}
existenceReason::EqualTo(adjacent_id) => {
ExistenceReason::EqualTo(adjacent_id) => {
let adjacent_node = &self.explainfind[usize::from(adjacent_id)];
// The node should be directly adjacent to another node
let connection = if node.parent_connection.next == adjacent_id {
Expand All @@ -1327,10 +1341,10 @@ impl<'x, L: Language> ExplainNodes<'x, L> {
exp.push(rest_of_proof);
exp
}
existenceReason::Direct => {
ExistenceReason::Direct => {
vec![self.node_to_explanation(term, enode_cache), rest_of_proof]
}
existenceReason::Unset => panic!("Unset existence!"),
ExistenceReason::Unset => panic!("Unset existence!"),
}
}

Expand Down

0 comments on commit 3bd82dc

Please sign in to comment.