Skip to content

Commit

Permalink
Merge pull request #38 from ninehusky/ninehusky-conditional-rewrite-a…
Browse files Browse the repository at this point in the history
…pplication

Apply conditional rewrites, instead of just leaving them be :p
  • Loading branch information
ninehusky authored Nov 14, 2024
2 parents 53581dd + 7f64be3 commit 40bed9d
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 64 deletions.
202 changes: 202 additions & 0 deletions src/ite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
use std::{str::FromStr, sync::Arc};

use log::info;

use std::fmt::Debug;

use egglog::{
ast::{Span, Symbol},
constraint::{SimpleTypeConstraint, TypeConstraint},
sort::Sort,
ArcSort, PrimitiveLike,
};
use ruler::enumo::Sexp;

pub trait PredicateInterpreter: Debug + Send + Sync {
fn interp_cond(&self, sexp: &Sexp) -> bool;
}

#[derive(Debug)]
pub struct DummySort {
// the language that a condition will operate on
pub sort: ArcSort,
pub interpreter: Arc<dyn PredicateInterpreter>,
}

impl Sort for DummySort {
fn name(&self) -> Symbol {
"dummy".into()
}

fn as_arc_any(
self: std::sync::Arc<Self>,
) -> std::sync::Arc<dyn std::any::Any + Send + Sync + 'static> {
self
}

fn make_expr(
&self,
_egraph: &egglog::EGraph,
_value: egglog::Value,
) -> (usize, egglog::ast::Expr) {
(0, egglog::ast::Expr::lit_no_span(Symbol::from("dummy")))
}

fn register_primitives(self: std::sync::Arc<Self>, info: &mut egglog::TypeInfo) {
info.add_primitive(Ite {
name: "ite".into(),
sort: self.sort.clone(),
interpreter: self.interpreter.clone(),
});
}
}

pub struct Ite {
name: Symbol,
sort: Arc<dyn Sort>,
interpreter: Arc<dyn PredicateInterpreter>,
}

// (ite pred_expr expr expr) -> expr.
// will evaluate to first expr if pred_expr = true (according to interpreter semantics), else the other expr.
impl PrimitiveLike for Ite {
fn name(&self) -> Symbol {
self.name
}

fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
SimpleTypeConstraint::new(
self.name(),
vec![
self.sort.clone(),
self.sort.clone(),
self.sort.clone(),
self.sort.clone(),
],
span.clone(),
)
.into_box()
}

fn apply(
&self,
values: &[egglog::Value],
egraph: Option<&mut egglog::EGraph>,
) -> Option<egglog::Value> {
let egraph = egraph.unwrap();
let sexp = Sexp::from_str(&egraph.extract_value_to_string(values[0])).unwrap();

info!("apply on {}", sexp);

if self.interpreter.interp_cond(&sexp) {
Some(values[1])
} else {
Some(values[2])
}
}
}

// idk why clippy complains about the two use statements below.
#[allow(unused_imports)]
pub mod tests {
use super::*;
use egglog::sort::EqSort;

#[test]
fn test_ite_create() {
#[derive(Debug)]
struct MathInterpreter;

impl PredicateInterpreter for MathInterpreter {
fn interp_cond(&self, sexp: &Sexp) -> bool {
fn interp_internal(sexp: &Sexp) -> i64 {
match sexp {
Sexp::Atom(atom) => panic!("Unexpected atom: {}", atom),
Sexp::List(l) => {
if let Sexp::Atom(op) = &l[0] {
match op.as_str() {
"Eq" => {
let a = interp_internal(&l[1]);
let b = interp_internal(&l[2]);
if a == b {
1
} else {
0
}
}
"Mul" => interp_internal(&l[1]) * interp_internal(&l[2]),
"Num" => l[1].to_string().parse().unwrap(),
_ => panic!("Unexpected operator: {:?}", op),
}
} else {
panic!("Unexpected list operator: {:?}", l[0]);
}
}
}
}

interp_internal(sexp) == 1
}
}

let math_sort = Arc::new(EqSort {
name: "Math".into(),
});
let dummy_sort = Arc::new(DummySort {
sort: math_sort.clone(),
interpreter: Arc::new(MathInterpreter),
});

let mut egraph = egglog::EGraph::default();

egraph.add_arcsort(math_sort.clone()).unwrap();
egraph.add_arcsort(dummy_sort).unwrap();

egraph
.parse_and_run_program(
None,
r#"
(function Num (i64) Math)
(function Mul (Math Math) Math)
(function Eq (Math Math) Math)
(relation universe (Math))
"#,
)
.unwrap();

egraph
.parse_and_run_program(
None,
r#"
(rule
((universe ?e))
((union ?e (ite (Eq ?e (Num 1)) (Mul ?e ?e) ?e)))
)
"#,
)
.unwrap();

egraph
.parse_and_run_program(
None,
r#"
(universe (Mul (Num 1) (Num 1)))
(universe (Num 1))
(universe (Num 2))
"#,
)
.unwrap();

egraph.parse_and_run_program(None, "(run 1000)").unwrap();

egraph
.parse_and_run_program(None, "(check (= (Mul (Num 1) (Num 1)) (Num 1)))")
.unwrap();

egraph
.parse_and_run_program(None, "(fail (check (= (Mul (Num 2) (Num 2)) (Num 2))))")
.unwrap();
}
}
66 changes: 29 additions & 37 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use egglog::{EGraph, SerializeConfig};
use ruler::enumo::Pattern;
use ruler::{HashMap, HashSet, ValidationResult};
use utils::TERM_PLACEHOLDER;
use utils::{TERM_PLACEHOLDER, UNIVERSAL_RELATION};

use std::fmt::Debug;
use std::hash::Hash;
use std::str::FromStr;

use ruler::enumo::{Sexp, Workload};

use log::info;
use ruler::enumo::{Sexp, Workload};

pub mod ite;
pub mod utils;

pub type Constant<R> = <R as Chomper>::Constant;
Expand Down Expand Up @@ -71,11 +71,6 @@ pub trait Chomper {
for current_size in 0..MAX_SIZE {
info!("adding programs of size {}:", current_size);

// let mut filter = Filter::MetricEq(Metric::Atoms, current_size);
// if current_size > 4 {
// filter = Filter::And(vec![filter, Filter::Excludes(self.constant_pattern())]);
// }

info!("finding eclass term map...");
let eclass_term_map = self
.reset_eclass_term_map(egraph)
Expand Down Expand Up @@ -388,35 +383,31 @@ pub trait Chomper {
.unwrap();
}

fn add_conditional_rewrite(
&mut self,
_egraph: &mut EGraph,
_cond: Sexp,
_lhs: Sexp,
_rhs: Sexp,
) {
fn add_conditional_rewrite(&mut self, egraph: &mut EGraph, cond: Sexp, lhs: Sexp, rhs: Sexp) {
// TODO: @ninehusky: let's brainstorm ways to encode conditional equality with respect to a
// specific condition (see #20).
// let _pred = self.make_string_not_bad(cond.to_string().as_str());
// let term1 = self.make_string_not_bad(lhs.to_string().as_str());
// let term2 = self.make_string_not_bad(rhs.to_string().as_str());
// info!(
// "adding conditional rewrite: {} -> {} if {}",
// term1, term2, _pred
// );
// info!("term2 has cvec: {:?}", self.interpret_term(&rhs));
// egraph
// .parse_and_run_program(
// None,
// format!(
// r#"
// (cond-equal {term1} {term2})
// (cond-equal {term2} {term1})
// "#
// )
// .as_str(),
// )
// .unwrap();
let cond = self.make_string_not_bad(cond.to_string().as_str());
let term1 = self.make_string_not_bad(lhs.to_string().as_str());
let term2 = self.make_string_not_bad(rhs.to_string().as_str());

info!(
"adding conditional rewrite: if {} then {} -> {}",
cond, term1, term2
);

let cond_rewrite_prog = format!(
r#"
(rule
(({UNIVERSAL_RELATION} {term1}))
((union {term1} (ite {cond} {term2} {term1}))))
"#
);

println!("cond rewrite prog: {}", cond_rewrite_prog);

egraph
.parse_and_run_program(None, &cond_rewrite_prog)
.unwrap();
}

fn has_var(&self, term: &Sexp) -> bool {
Expand All @@ -433,13 +424,14 @@ pub trait Chomper {
}
}

fn language_name() -> String;
fn productions(&self) -> Workload;
fn atoms(&self) -> Workload;
fn make_preds(&self) -> Workload;
fn get_env(&self) -> &HashMap<String, CVec<Self>>;
fn validate_rule(&self, rule: &Rule) -> ValidationResult;
fn interpret_term(&mut self, term: &ruler::enumo::Sexp) -> CVec<Self>;
fn interpret_pred(&mut self, term: &ruler::enumo::Sexp) -> Vec<bool>;
fn interpret_term(&self, term: &ruler::enumo::Sexp) -> CVec<Self>;
fn interpret_pred(&self, term: &ruler::enumo::Sexp) -> Vec<bool>;
fn constant_pattern(&self) -> Pattern;
fn matches_var_pattern(&self, term: &Sexp) -> bool;
}
1 change: 1 addition & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use ruler::enumo::Sexp;
// Atoms with this name will `not` count in a production
// toward its size.
pub const TERM_PLACEHOLDER: &str = "?term";
pub const UNIVERSAL_RELATION: &str = "universe";

pub fn get_production_size(term: &Sexp) -> usize {
get_size(term, true)
Expand Down
44 changes: 22 additions & 22 deletions tests/egglog/halide.egg
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
;;; Halide language definition.

(datatype Expr
(Lit i64)
(Var String)
(Lt Expr Expr)
(Leq Expr Expr)
(Eq Expr Expr)
(Neq Expr Expr)
(Implies Expr Expr)
(Not Expr)
(Neg Expr)
(And Expr Expr)
(Or Expr Expr)
(Xor Expr Expr)
(Add Expr Expr)
(Sub Expr Expr)
(Mul Expr Expr)
(Div Expr Expr)
(Min Expr Expr)
(Max Expr Expr)
(Select Expr Expr Expr)
)
(function Lit (i64) HalideExpr)
(function Var (String) HalideExpr)
(function Lt (HalideExpr HalideExpr) HalideExpr)
(function Leq (HalideExpr HalideExpr) HalideExpr)
(function Eq (HalideExpr HalideExpr) HalideExpr)
(function Neq (HalideExpr HalideExpr) HalideExpr)
(function Implies (HalideExpr HalideExpr) HalideExpr)
(function Not (HalideExpr) HalideExpr)
(function Neg (HalideExpr) HalideExpr)
(function And (HalideExpr HalideExpr) HalideExpr)
(function Or (HalideExpr HalideExpr) HalideExpr)
(function Xor (HalideExpr HalideExpr) HalideExpr)
(function Add (HalideExpr HalideExpr) HalideExpr)
(function Sub (HalideExpr HalideExpr) HalideExpr)
(function Mul (HalideExpr HalideExpr) HalideExpr)
(function Div (HalideExpr HalideExpr) HalideExpr)
(function Min (HalideExpr HalideExpr) HalideExpr)
(function Max (HalideExpr HalideExpr) HalideExpr)
(function Select (HalideExpr HalideExpr HalideExpr) HalideExpr)


(function eclass (Expr) i64 :merge (min old new))
(function eclass (HalideExpr) i64 :merge (min old new))

(relation universe (HalideExpr))

(ruleset eclass-report)
(ruleset non-cond-rewrites)
Expand Down
Loading

0 comments on commit 40bed9d

Please sign in to comment.