Skip to content

Commit

Permalink
Add Halide interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
ninehusky committed Nov 13, 2024
1 parent 3bde238 commit e199555
Show file tree
Hide file tree
Showing 4 changed files with 555 additions and 149 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ log = "0.4.22"
# same version as ruler
z3 = {version = "0.10.0", features = ["static-link-z3"]}
itertools = "0.13.0"
num = "0.3"

serde = "1.0.214"
serde_json = "1.0.132"
242 changes: 93 additions & 149 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use egglog::{EGraph, SerializeConfig};
use ruler::enumo::Pattern;
use ruler::{HashMap, HashSet};
use ruler::{HashMap, HashSet, ValidationResult};
use utils::TERM_PLACEHOLDER;

use std::fmt::Debug;
Expand Down Expand Up @@ -65,28 +65,18 @@ pub trait Chomper {
result
}

fn run_chompy(
&mut self,
egraph: &mut EGraph,
rules: Vec<Rule>,
mask_to_preds: &HashMap<Vec<bool>, HashSet<String>>,
memo: &mut HashSet<i64>,
) {
let mut found: Vec<bool> = vec![false; rules.len()];

fn run_chompy(&mut self, egraph: &mut EGraph) {
let mut max_eclass_id = 0;

let mut found_rules: HashSet<String> = HashSet::default();

for current_size in 0..MAX_SIZE {
info!("adding programs of size {}:", current_size);
println!("adding programs of size {}:", current_size);

let mut filter = Filter::MetricEq(Metric::Atoms, current_size);
if current_size > 15 {
filter = Filter::And(vec![filter, Filter::Excludes(self.constant_pattern())]);
}
// let mut filter = Filter::MetricEq(Metric::Atoms, current_size);
// if current_size > 4 {
// filter = Filter::And(vec![filter, Filter::Excludes(self.constant_pattern())]);
// }

info!("finding eclass term map...");
println!("finding eclass term map...");
let eclass_term_map = self
.reset_eclass_term_map(egraph)
.values()
Expand All @@ -100,20 +90,21 @@ pub trait Chomper {
);

let new_workload = if term_workload.force().is_empty() {
self.atoms().clone().filter(filter)
self.atoms().clone()
} else {
self.productions()
.clone()
.plug(TERM_PLACEHOLDER, &term_workload)
.filter(filter)
};

info!("new workload len: {}", new_workload.force().len());
println!("new workload len: {}", new_workload.force().len());

let atoms = self.atoms().force();

let memo = &mut HashSet::default();

for term in &new_workload.force() {
info!("term: {}", term);
// println!("term: {}", term);
let term_string = self.make_string_not_bad(term.to_string().as_str());
if !atoms.contains(term) && !self.has_var(term) {
continue;
Expand All @@ -125,7 +116,7 @@ pub trait Chomper {
r#"
{term_string}
(set (eclass {term_string}) {max_eclass_id})
"#
"#
)
.as_str(),
)
Expand All @@ -143,110 +134,68 @@ pub trait Chomper {
"#,
)
.unwrap();
info!("starting cvec match");
let vals = self.cvec_match(egraph, mask_to_preds, memo);
if vals.non_conditional.is_empty()
|| vals.non_conditional.iter().all(|x| {
found_rules.contains(format!("{:?}", self.generalize_rule(x)).as_str())
})
{
println!("starting cvec match");
let vals = self.cvec_match(egraph, memo);

if vals.non_conditional.is_empty() && vals.conditional.is_empty() {
break;
}

for (i, rule) in rules.iter().enumerate() {
let lhs = self.make_string_not_bad(rule.lhs.to_string().as_str());
let rhs = self.make_string_not_bad(rule.rhs.to_string().as_str());
if (rule.condition.is_some()
&& egraph
.parse_and_run_program(
None,
format!(
r#"
(check (cond-equal {lhs} {rhs}))
"#
)
.as_str(),
)
.is_ok())
|| (rule.condition.is_none()
&& egraph
.parse_and_run_program(
None,
format!(
r#"
(check (= {lhs} {rhs}))
"#
)
.as_str(),
)
.is_ok())
{
found[i] = true;
}
if found.iter().all(|x| *x) {
return;
println!("found {} non-conditional rules", vals.non_conditional.len());
println!("found {} conditional rules", vals.conditional.len());
for val in &vals.conditional {
let generalized = self.generalize_rule(val);
if let ValidationResult::Valid = self.validate_rule(&generalized) {
if utils::does_rule_have_good_vars(&generalized) {
let lhs =
self.make_string_not_bad(generalized.lhs.to_string().as_str());
let rhs =
self.make_string_not_bad(generalized.rhs.to_string().as_str());
let cond = generalized.condition.as_ref().unwrap();
let pred = self.make_string_not_bad(cond.to_string().as_str());
println!("Conditional rule: if {} then {} ~> {}", pred, lhs, rhs);
self.add_conditional_rewrite(
egraph,
Sexp::from_str(&pred).unwrap(),
Sexp::from_str(&lhs).unwrap(),
Sexp::from_str(&rhs).unwrap(),
);
}
}
}

for val in &vals.non_conditional {
let generalized = self.generalize_rule(val);
if !found_rules.contains(format!("{:?}", generalized).as_str())
&& utils::does_rule_have_good_vars(&generalized)
{
let lhs = self.make_string_not_bad(generalized.lhs.to_string().as_str());
let rhs = self.make_string_not_bad(generalized.rhs.to_string().as_str());
if egraph
.parse_and_run_program(
None,
format!(
r#"
{lhs}
{rhs}
(check (= {lhs} {rhs}))
"#
if let ValidationResult::Valid = self.validate_rule(&generalized) {
if utils::does_rule_have_good_vars(&generalized) {
let lhs =
self.make_string_not_bad(generalized.lhs.to_string().as_str());
let rhs =
self.make_string_not_bad(generalized.rhs.to_string().as_str());

if egraph
.parse_and_run_program(
None,
format!(r#"(check (= {} {}))"#, val.lhs, val.rhs).as_str(),
)
.as_str(),
)
.is_err()
{
let validated = self.get_validated_rule(&generalized);
if found_rules.contains(format!("{:?}", validated).as_str()) {
.is_ok()
{
continue;
}
found_rules.insert(format!("{:?}", validated));
if validated.is_none() {
continue;
}
let validated = validated.unwrap();
if validated.condition.is_none() {
info!("Rule: {} -> {}", validated.lhs, validated.rhs);
self.add_rewrite(egraph, validated.lhs, validated.rhs);
} else {
info!(
"Conditional Rule: if {} then {} -> {}",
validated.condition.clone().unwrap(),
validated.lhs,
validated.rhs
);
self.add_conditional_rewrite(
egraph,
validated.condition.unwrap(),
validated.lhs,
validated.rhs,
);
}

self.add_rewrite(
egraph,
Sexp::from_str(&lhs).unwrap(),
Sexp::from_str(&rhs).unwrap(),
);
// TODO: derivability check here
}
} else {
// println!(
// "perfect cvec match but failed validation: {} ~> {}",
// val.lhs, val.rhs
// );
}
}

for val in &vals.conditional {
self.add_conditional_rewrite(
egraph,
val.condition.clone().unwrap(),
val.lhs.clone(),
val.rhs.clone(),
);
}
}
}

Expand Down Expand Up @@ -278,9 +227,13 @@ pub trait Chomper {
let mut id_to_gen_id: HashMap<String, String> = HashMap::default();
let new_lhs = self.generalize_sexp(rule.lhs.clone(), &mut id_to_gen_id);
let new_rhs = self.generalize_sexp(rule.rhs.clone(), &mut id_to_gen_id);
let condition = match &rule.condition {
Some(cond) => Some(self.generalize_sexp(cond.clone(), &mut id_to_gen_id)),
None => None,
};
Rule {
// TODO: later
condition: None,
condition,
lhs: new_lhs,
rhs: new_rhs,
}
Expand Down Expand Up @@ -314,7 +267,6 @@ pub trait Chomper {
fn cvec_match(
&mut self,
egraph: &mut EGraph,
mask_to_preds: &HashMap<Vec<bool>, HashSet<String>>,
// keeps track of what eclass IDs we've seen.
memo: &mut HashSet<i64>,
) -> Rules {
Expand All @@ -323,12 +275,14 @@ pub trait Chomper {
conditional: vec![],
};

let mask_to_preds = self.make_mask_to_preds();

println!("hi from cvec match");
let serialized = egraph.serialize(SerializeConfig::default());
println!("eclasses in egraph: {}", serialized.classes().len());
println!("nodes in egraph: {}", serialized.nodes.len());
let eclass_term_map: HashMap<i64, Sexp> = self.reset_eclass_term_map(egraph);
// println!("eclass term map len: {}", eclass_term_map.len());
println!("eclass term map len: {}", eclass_term_map.len());
let ec_keys: Vec<&i64> = eclass_term_map.keys().collect();
for i in 0..ec_keys.len() {
let ec1 = ec_keys[i];
Expand Down Expand Up @@ -357,11 +311,6 @@ pub trait Chomper {
lhs: term1.clone(),
rhs: term2.clone(),
});
result.non_conditional.push(Rule {
condition: None,
lhs: term2.clone(),
rhs: term1.clone(),
});
} else {
if egraph
.parse_and_run_program(
Expand Down Expand Up @@ -390,32 +339,23 @@ pub trait Chomper {
continue;
}

// sufficient and necessary conditions.
// we may want to experiment with just having sufficient conditions.
let masks = mask_to_preds.keys().filter(|mask| {
mask.iter()
.zip(same_vals.iter())
.all(|(mask_val, same_val)| mask_val == same_val)
});
// if the mask is all false, then skip it.
if same_vals.iter().all(|x| !x) {
continue;
}

for mask in masks {
// if the mask is completely false, skip it.
if mask.iter().all(|x| !x) {
continue;
}
let preds = mask_to_preds.get(mask).unwrap();
for pred in preds {
result.conditional.push(Rule {
condition: Some(Sexp::from_str(pred).unwrap()),
lhs: term1.clone(),
rhs: term2.clone(),
});
result.conditional.push(Rule {
condition: Some(Sexp::from_str(pred).unwrap()),
lhs: term2.clone(),
rhs: term1.clone(),
});
}
// sufficient and necessary conditions.
if !mask_to_preds.contains_key(&same_vals) {
continue;
}
let preds = mask_to_preds.get(&same_vals).unwrap();
for pred in preds {
let rule = Rule {
condition: Some(Sexp::from_str(pred).unwrap()),
lhs: term1.clone(),
rhs: term2.clone(),
};
result.conditional.push(rule);
}
}
}
Expand All @@ -426,6 +366,10 @@ pub trait Chomper {
fn add_rewrite(&mut self, egraph: &mut EGraph, lhs: Sexp, rhs: Sexp) {
let term1 = self.make_string_not_bad(lhs.to_string().as_str());
let term2 = self.make_string_not_bad(rhs.to_string().as_str());
if term1 == "?a" {
return;
}
println!("Rule: {} ~> {}", term1, term2);
egraph
.parse_and_run_program(
None,
Expand Down Expand Up @@ -490,8 +434,8 @@ pub trait Chomper {
fn productions(&self) -> Workload;
fn atoms(&self) -> Workload;
fn make_preds(&self) -> Workload;
fn get_env(&self) -> &HashMap<String, Vec<Value<Self>>>;
fn get_validated_rule(&self, rule: &Rule) -> Option<Rule>;
fn get_env(&self) -> &HashMap<String, CVec<Self>>;
fn validate_rule(&self, rule: &Rule) -> ValidationResult;
fn interpret_term(&mut self, term: &ruler::enumo::Sexp) -> CVec<Self>;
fn interpret_pred(&mut self, term: &ruler::enumo::Sexp) -> Vec<bool>;
fn constant_pattern(&self) -> Pattern;
Expand Down
Loading

0 comments on commit e199555

Please sign in to comment.