Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply conditional rewrites, instead of just leaving them be :p #38

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 202 additions & 0 deletions src/ite.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
use std::{str::FromStr, sync::Arc};

use log::info;

use std::fmt::Debug;

use egglog::{
ast::{Span, Symbol},
constraint::{SimpleTypeConstraint, TypeConstraint},
sort::Sort,
ArcSort, PrimitiveLike,
};
use ruler::enumo::Sexp;

pub trait PredicateInterpreter: Debug + Send + Sync {
fn interp_cond(&self, sexp: &Sexp) -> bool;
}

#[derive(Debug)]
pub struct DummySort {
// the language that a condition will operate on
pub sort: ArcSort,
pub interpreter: Arc<dyn PredicateInterpreter>,
}

impl Sort for DummySort {
fn name(&self) -> Symbol {
"dummy".into()
}

fn as_arc_any(
self: std::sync::Arc<Self>,
) -> std::sync::Arc<dyn std::any::Any + Send + Sync + 'static> {
self
}

fn make_expr(
&self,
_egraph: &egglog::EGraph,
_value: egglog::Value,
) -> (usize, egglog::ast::Expr) {
(0, egglog::ast::Expr::lit_no_span(Symbol::from("dummy")))
}

fn register_primitives(self: std::sync::Arc<Self>, info: &mut egglog::TypeInfo) {
info.add_primitive(Ite {
name: "ite".into(),
sort: self.sort.clone(),
interpreter: self.interpreter.clone(),
});
}
}

pub struct Ite {
name: Symbol,
sort: Arc<dyn Sort>,
interpreter: Arc<dyn PredicateInterpreter>,
}

// (ite pred_expr expr expr) -> expr.
// will evaluate to first expr if pred_expr = true (according to interpreter semantics), else the other expr.
impl PrimitiveLike for Ite {
fn name(&self) -> Symbol {
self.name
}

fn get_type_constraints(&self, span: &Span) -> Box<dyn TypeConstraint> {
SimpleTypeConstraint::new(
self.name(),
vec![
self.sort.clone(),
self.sort.clone(),
self.sort.clone(),
self.sort.clone(),
],
span.clone(),
)
.into_box()
}

fn apply(
&self,
values: &[egglog::Value],
egraph: Option<&mut egglog::EGraph>,
) -> Option<egglog::Value> {
let egraph = egraph.unwrap();
let sexp = Sexp::from_str(&egraph.extract_value_to_string(values[0])).unwrap();

info!("apply on {}", sexp);

if self.interpreter.interp_cond(&sexp) {
Some(values[1])
} else {
Some(values[2])
}
}
}

// idk why clippy complains about the two use statements below.
#[allow(unused_imports)]
pub mod tests {
use super::*;
use egglog::sort::EqSort;

#[test]
fn test_ite_create() {
#[derive(Debug)]
struct MathInterpreter;

impl PredicateInterpreter for MathInterpreter {
fn interp_cond(&self, sexp: &Sexp) -> bool {
fn interp_internal(sexp: &Sexp) -> i64 {
match sexp {
Sexp::Atom(atom) => panic!("Unexpected atom: {}", atom),
Sexp::List(l) => {
if let Sexp::Atom(op) = &l[0] {
match op.as_str() {
"Eq" => {
let a = interp_internal(&l[1]);
let b = interp_internal(&l[2]);
if a == b {
1
} else {
0
}
}
"Mul" => interp_internal(&l[1]) * interp_internal(&l[2]),
"Num" => l[1].to_string().parse().unwrap(),
_ => panic!("Unexpected operator: {:?}", op),
}
} else {
panic!("Unexpected list operator: {:?}", l[0]);
}
}
}
}

interp_internal(sexp) == 1
}
}

let math_sort = Arc::new(EqSort {
name: "Math".into(),
});
let dummy_sort = Arc::new(DummySort {
sort: math_sort.clone(),
interpreter: Arc::new(MathInterpreter),
});

let mut egraph = egglog::EGraph::default();

egraph.add_arcsort(math_sort.clone()).unwrap();
egraph.add_arcsort(dummy_sort).unwrap();

egraph
.parse_and_run_program(
None,
r#"
(function Num (i64) Math)
(function Mul (Math Math) Math)
(function Eq (Math Math) Math)

(relation universe (Math))
"#,
)
.unwrap();

egraph
.parse_and_run_program(
None,
r#"
(rule
((universe ?e))
((union ?e (ite (Eq ?e (Num 1)) (Mul ?e ?e) ?e)))
)

"#,
)
.unwrap();

egraph
.parse_and_run_program(
None,
r#"
(universe (Mul (Num 1) (Num 1)))
(universe (Num 1))
(universe (Num 2))
"#,
)
.unwrap();

egraph.parse_and_run_program(None, "(run 1000)").unwrap();

egraph
.parse_and_run_program(None, "(check (= (Mul (Num 1) (Num 1)) (Num 1)))")
.unwrap();

egraph
.parse_and_run_program(None, "(fail (check (= (Mul (Num 2) (Num 2)) (Num 2))))")
.unwrap();
}
}
66 changes: 29 additions & 37 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use egglog::{EGraph, SerializeConfig};
use ruler::enumo::Pattern;
use ruler::{HashMap, HashSet, ValidationResult};
use utils::TERM_PLACEHOLDER;
use utils::{TERM_PLACEHOLDER, UNIVERSAL_RELATION};

use std::fmt::Debug;
use std::hash::Hash;
use std::str::FromStr;

use ruler::enumo::{Sexp, Workload};

use log::info;
use ruler::enumo::{Sexp, Workload};

pub mod ite;
pub mod utils;

pub type Constant<R> = <R as Chomper>::Constant;
Expand Down Expand Up @@ -71,11 +71,6 @@ pub trait Chomper {
for current_size in 0..MAX_SIZE {
info!("adding programs of size {}:", current_size);

// let mut filter = Filter::MetricEq(Metric::Atoms, current_size);
// if current_size > 4 {
// filter = Filter::And(vec![filter, Filter::Excludes(self.constant_pattern())]);
// }

info!("finding eclass term map...");
let eclass_term_map = self
.reset_eclass_term_map(egraph)
Expand Down Expand Up @@ -388,35 +383,31 @@ pub trait Chomper {
.unwrap();
}

fn add_conditional_rewrite(
&mut self,
_egraph: &mut EGraph,
_cond: Sexp,
_lhs: Sexp,
_rhs: Sexp,
) {
fn add_conditional_rewrite(&mut self, egraph: &mut EGraph, cond: Sexp, lhs: Sexp, rhs: Sexp) {
// TODO: @ninehusky: let's brainstorm ways to encode conditional equality with respect to a
// specific condition (see #20).
// let _pred = self.make_string_not_bad(cond.to_string().as_str());
// let term1 = self.make_string_not_bad(lhs.to_string().as_str());
// let term2 = self.make_string_not_bad(rhs.to_string().as_str());
// info!(
// "adding conditional rewrite: {} -> {} if {}",
// term1, term2, _pred
// );
// info!("term2 has cvec: {:?}", self.interpret_term(&rhs));
// egraph
// .parse_and_run_program(
// None,
// format!(
// r#"
// (cond-equal {term1} {term2})
// (cond-equal {term2} {term1})
// "#
// )
// .as_str(),
// )
// .unwrap();
let cond = self.make_string_not_bad(cond.to_string().as_str());
let term1 = self.make_string_not_bad(lhs.to_string().as_str());
let term2 = self.make_string_not_bad(rhs.to_string().as_str());

info!(
"adding conditional rewrite: if {} then {} -> {}",
cond, term1, term2
);

let cond_rewrite_prog = format!(
r#"
(rule
(({UNIVERSAL_RELATION} {term1}))
((union {term1} (ite {cond} {term2} {term1}))))
"#
);

println!("cond rewrite prog: {}", cond_rewrite_prog);

egraph
.parse_and_run_program(None, &cond_rewrite_prog)
.unwrap();
}

fn has_var(&self, term: &Sexp) -> bool {
Expand All @@ -433,13 +424,14 @@ pub trait Chomper {
}
}

fn language_name() -> String;
fn productions(&self) -> Workload;
fn atoms(&self) -> Workload;
fn make_preds(&self) -> Workload;
fn get_env(&self) -> &HashMap<String, CVec<Self>>;
fn validate_rule(&self, rule: &Rule) -> ValidationResult;
fn interpret_term(&mut self, term: &ruler::enumo::Sexp) -> CVec<Self>;
fn interpret_pred(&mut self, term: &ruler::enumo::Sexp) -> Vec<bool>;
fn interpret_term(&self, term: &ruler::enumo::Sexp) -> CVec<Self>;
fn interpret_pred(&self, term: &ruler::enumo::Sexp) -> Vec<bool>;
fn constant_pattern(&self) -> Pattern;
fn matches_var_pattern(&self, term: &Sexp) -> bool;
}
1 change: 1 addition & 0 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use ruler::enumo::Sexp;
// Atoms with this name will `not` count in a production
// toward its size.
pub const TERM_PLACEHOLDER: &str = "?term";
pub const UNIVERSAL_RELATION: &str = "universe";

pub fn get_production_size(term: &Sexp) -> usize {
get_size(term, true)
Expand Down
44 changes: 22 additions & 22 deletions tests/egglog/halide.egg
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
;;; Halide language definition.

(datatype Expr
(Lit i64)
(Var String)
(Lt Expr Expr)
(Leq Expr Expr)
(Eq Expr Expr)
(Neq Expr Expr)
(Implies Expr Expr)
(Not Expr)
(Neg Expr)
(And Expr Expr)
(Or Expr Expr)
(Xor Expr Expr)
(Add Expr Expr)
(Sub Expr Expr)
(Mul Expr Expr)
(Div Expr Expr)
(Min Expr Expr)
(Max Expr Expr)
(Select Expr Expr Expr)
)
(function Lit (i64) HalideExpr)
(function Var (String) HalideExpr)
(function Lt (HalideExpr HalideExpr) HalideExpr)
(function Leq (HalideExpr HalideExpr) HalideExpr)
(function Eq (HalideExpr HalideExpr) HalideExpr)
(function Neq (HalideExpr HalideExpr) HalideExpr)
(function Implies (HalideExpr HalideExpr) HalideExpr)
(function Not (HalideExpr) HalideExpr)
(function Neg (HalideExpr) HalideExpr)
(function And (HalideExpr HalideExpr) HalideExpr)
(function Or (HalideExpr HalideExpr) HalideExpr)
(function Xor (HalideExpr HalideExpr) HalideExpr)
(function Add (HalideExpr HalideExpr) HalideExpr)
(function Sub (HalideExpr HalideExpr) HalideExpr)
(function Mul (HalideExpr HalideExpr) HalideExpr)
(function Div (HalideExpr HalideExpr) HalideExpr)
(function Min (HalideExpr HalideExpr) HalideExpr)
(function Max (HalideExpr HalideExpr) HalideExpr)
(function Select (HalideExpr HalideExpr HalideExpr) HalideExpr)


(function eclass (Expr) i64 :merge (min old new))
(function eclass (HalideExpr) i64 :merge (min old new))

(relation universe (HalideExpr))

(ruleset eclass-report)
(ruleset non-cond-rewrites)
Expand Down
Loading
Loading