Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort declaration cleanup #442

Merged
merged 12 commits into from
Oct 16, 2024
2 changes: 1 addition & 1 deletion src/actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ impl EGraph {
let ts = self.timestamp;
let out = &function.schema.output;
match function.decl.default.as_ref() {
None if out.name() == UNIT_SYM.into() => {
None if out.name() == UnitSort.name() => {
function.insert(values, Value::unit(), ts);
Value::unit()
}
Expand Down
4 changes: 2 additions & 2 deletions src/ast/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ pub enum GenericExpr<Head, Leaf> {
}

impl ResolvedExpr {
pub fn output_type(&self, type_info: &TypeInfo) -> ArcSort {
pub fn output_type(&self) -> ArcSort {
Alex-Fischman marked this conversation as resolved.
Show resolved Hide resolved
match self {
ResolvedExpr::Lit(_, lit) => type_info.infer_literal(lit),
ResolvedExpr::Lit(_, lit) => sort::literal_sort(lit),
ResolvedExpr::Var(_, resolved_var) => resolved_var.sort.clone(),
ResolvedExpr::Call(_, resolved_call, _) => resolved_call.output().clone(),
}
Expand Down
18 changes: 5 additions & 13 deletions src/ast/remove_globals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
//! When a globally-bound primitive value is used in the actions of a rule,
//! we add a new variable to the query bound to the primitive value.

use crate::{
core::ResolvedCall, typechecking::FuncType, FreshGen, GenericAction, GenericActions,
GenericExpr, GenericFact, GenericNCommand, GenericRule, HashMap, ResolvedAction, ResolvedExpr,
ResolvedFact, ResolvedFunctionDecl, ResolvedNCommand, ResolvedVar, Schema, SymbolGen, TypeInfo,
};
use crate::*;
use crate::{core::ResolvedCall, typechecking::FuncType};

struct GlobalRemover<'a> {
fresh: &'a mut SymbolGen,
Expand Down Expand Up @@ -45,13 +42,12 @@ struct GlobalRemover<'a> {
/// ((Add fresh_var_for_x fresh_var_for_x)))
/// ```
pub(crate) fn remove_globals(
type_info: &TypeInfo,
prog: Vec<ResolvedNCommand>,
fresh: &mut SymbolGen,
) -> Vec<ResolvedNCommand> {
let mut remover = GlobalRemover { fresh };
prog.into_iter()
.flat_map(|cmd| remover.remove_globals_cmd(type_info, cmd))
.flat_map(|cmd| remover.remove_globals_cmd(cmd))
.collect()
}

Expand Down Expand Up @@ -91,15 +87,11 @@ fn remove_globals_action(action: ResolvedAction) -> ResolvedAction {
}

impl<'a> GlobalRemover<'a> {
fn remove_globals_cmd(
&mut self,
type_info: &TypeInfo,
cmd: ResolvedNCommand,
) -> Vec<ResolvedNCommand> {
fn remove_globals_cmd(&mut self, cmd: ResolvedNCommand) -> Vec<ResolvedNCommand> {
match cmd {
GenericNCommand::CoreAction(action) => match action {
GenericAction::Let(span, name, expr) => {
let ty = expr.output_type(type_info);
let ty = expr.output_type();

let func_decl = ResolvedFunctionDecl {
name: name.name,
Expand Down
17 changes: 7 additions & 10 deletions src/constraint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl Assignment<AtomTerm, ArcSort> {
.collect();
let types: Vec<_> = args
.iter()
.map(|arg| arg.output_type(typeinfo))
.map(|arg| arg.output_type())
.chain(once(
self.get(&AtomTerm::Var(DUMMY_SPAN.clone(), *corresponding_var))
.unwrap()
Expand Down Expand Up @@ -351,8 +351,8 @@ impl Assignment<AtomTerm, ArcSort> {
let rhs = self.annotate_expr(rhs, typeinfo);
let types: Vec<_> = children
.iter()
.map(|child| child.output_type(typeinfo))
.chain(once(rhs.output_type(typeinfo)))
.map(|child| child.output_type())
.chain(once(rhs.output_type()))
.collect();
let resolved_call = ResolvedCall::from_resolution(head, &types, typeinfo);
if !matches!(resolved_call, ResolvedCall::Func(_)) {
Expand All @@ -379,10 +379,7 @@ impl Assignment<AtomTerm, ArcSort> {
.iter()
.map(|child| self.annotate_expr(child, typeinfo))
.collect();
let types: Vec<_> = children
.iter()
.map(|child| child.output_type(typeinfo))
.collect();
let types: Vec<_> = children.iter().map(|child| child.output_type()).collect();
let resolved_call =
ResolvedCall::from_resolution_func_types(head, &types, typeinfo)
.ok_or_else(|| TypeError::UnboundFunction(*head, span.clone()))?;
Expand Down Expand Up @@ -568,7 +565,7 @@ impl CoreAction {
get_literal_and_global_constraints(&[e.clone(), n.clone()], typeinfo)
.chain(once(Constraint::Assign(
n.clone(),
typeinfo.get_sort_nofail::<I64Sort>() as ArcSort,
std::sync::Arc::new(I64Sort) as ArcSort,
)))
.collect(),
)
Expand Down Expand Up @@ -684,8 +681,8 @@ fn get_literal_and_global_constraints<'a>(
AtomTerm::Var(_, _) => None,
// Literal to type constraint
AtomTerm::Literal(_, lit) => {
let typ = type_info.infer_literal(lit);
Some(Constraint::Assign(arg.clone(), typ.clone()))
let typ = crate::sort::literal_sort(lit);
Some(Constraint::Assign(arg.clone(), typ))
}
AtomTerm::Global(_, v) => {
if let Some(typ) = type_info.lookup_global(v) {
Expand Down
12 changes: 5 additions & 7 deletions src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
use std::hash::Hasher;
use std::ops::AddAssign;

use crate::HashMap;
use crate::{typechecking::FuncType, *};
use crate::{typechecking::FuncType, HashMap, *};
use typechecking::TypeError;

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -190,10 +189,10 @@ impl<Leaf: Clone> GenericAtomTerm<Leaf> {
}

impl ResolvedAtomTerm {
pub fn output(&self, typeinfo: &TypeInfo) -> ArcSort {
pub fn output(&self) -> ArcSort {
match self {
ResolvedAtomTerm::Var(_, v) => v.sort.clone(),
ResolvedAtomTerm::Literal(_, l) => typeinfo.infer_literal(l),
ResolvedAtomTerm::Literal(_, l) => literal_sort(l),
ResolvedAtomTerm::Global(_, v) => v.sort.clone(),
}
}
Expand Down Expand Up @@ -838,12 +837,11 @@ impl ResolvedRule {
fresh_gen: &mut SymbolGen,
) -> Result<ResolvedCoreRule, TypeError> {
let value_eq = &typeinfo.primitives.get(&Symbol::from("value-eq")).unwrap()[0];
let unit = typeinfo.get_sort_nofail::<UnitSort>();
self.to_canonicalized_core_rule_impl(typeinfo, fresh_gen, |at1, at2| {
ResolvedCall::Primitive(SpecializedPrimitive {
primitive: value_eq.clone(),
input: vec![at1.output(typeinfo), at2.output(typeinfo)],
output: unit.clone(),
input: vec![at1.output(), at2.output()],
output: Arc::new(UnitSort),
})
})
}
Expand Down
40 changes: 16 additions & 24 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use thiserror::Error;
use generic_symbolic_expressions::Sexp;

use ast::*;
pub use typechecking::{TypeInfo, UNIT_SYM};
pub use typechecking::TypeInfo;

use crate::core::{AtomTerm, ResolvedCall};
use actions::Program;
Expand Down Expand Up @@ -678,11 +678,11 @@ impl EGraph {

pub fn eval_lit(&self, lit: &Literal) -> Value {
match lit {
Literal::Int(i) => i.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::F64(f) => f.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::String(s) => s.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::Unit => ().store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::Bool(b) => b.store(&self.type_info.get_sort_nofail()).unwrap(),
Literal::Int(i) => i.store(&I64Sort).unwrap(),
Literal::F64(f) => f.store(&F64Sort).unwrap(),
Literal::String(s) => s.store(&StringSort).unwrap(),
Literal::Unit => ().store(&UnitSort).unwrap(),
Literal::Bool(b) => b.store(&BoolSort).unwrap(),
}
}

Expand Down Expand Up @@ -739,7 +739,7 @@ impl EGraph {
.get(&sym)
// function_to_dag should have checked this
.unwrap();
let out_is_unit = f.schema.output.name() == UNIT_SYM.into();
let out_is_unit = f.schema.output.name() == UnitSort.name();

let mut buf = String::new();
let s = &mut buf;
Expand Down Expand Up @@ -1300,7 +1300,7 @@ impl EGraph {
let mut termdag = TermDag::default();
for expr in exprs {
let value = self.eval_resolved_expr(&expr)?;
let expr_type = expr.output_type(&self.type_info);
let expr_type = expr.output_type();
let term = self.extract(value, &mut termdag, &expr_type).1;
use std::io::Write;
writeln!(f, "{}", termdag.to_string(&term))
Expand Down Expand Up @@ -1367,7 +1367,7 @@ impl EGraph {
let mut exprs: Vec<Expr> = str_buf.iter().map(|&s| parse(s)).collect();

actions.push(
if function_type.is_datatype || function_type.output.name() == UNIT_SYM.into() {
if function_type.is_datatype || function_type.output.name() == UnitSort.name() {
Action::Expr(span.clone(), Expr::Call(span.clone(), func_name, exprs))
} else {
let out = exprs.pop().unwrap();
Expand Down Expand Up @@ -1412,7 +1412,7 @@ impl EGraph {
.type_info
.typecheck_program(&mut self.symbol_gen, &program)?;

let program = remove_globals(&self.type_info, program, &mut self.symbol_gen);
let program = remove_globals(program, &mut self.symbol_gen);

Ok(program)
}
Expand Down Expand Up @@ -1476,11 +1476,6 @@ impl EGraph {
self.type_info.get_sort_by(pred)
}

/// Returns a sort based on the type
Alex-Fischman marked this conversation as resolved.
Show resolved Hide resolved
pub fn get_sort<S: Sort + Send + Sync>(&self) -> Option<Arc<S>> {
self.type_info.get_sort_by(|_| true)
}

/// Add a user-defined sort
pub fn add_arcsort(&mut self, arcsort: ArcSort) -> Result<(), TypeError> {
self.type_info.add_arcsort(arcsort, DUMMY_SPAN.clone())
Expand Down Expand Up @@ -1601,21 +1596,18 @@ mod tests {
fn test_user_defined_primitive() {
let mut egraph = EGraph::default();
egraph
.parse_and_run_program(
None,
"
(sort IntVec (Vec i64))
",
)
.parse_and_run_program(None, "(sort IntVec (Vec i64))")
.unwrap();
let i64_sort: Arc<I64Sort> = egraph.get_sort().unwrap();

let int_vec_sort: Arc<VecSort> = egraph
.get_sort_by(|s: &Arc<VecSort>| s.element_name() == i64_sort.name())
.get_sort_by(|s: &Arc<VecSort>| s.element_name() == I64Sort.name())
.unwrap();

egraph.add_primitive(InnerProduct {
ele: i64_sort,
ele: I64Sort.into(),
vec: int_vec_sort,
});

egraph
.parse_and_run_program(
None,
Expand Down
14 changes: 5 additions & 9 deletions src/sort/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@ use crate::ast::Literal;
use super::*;

#[derive(Debug)]
pub struct BoolSort {
name: Symbol,
}
pub struct BoolSort;

impl BoolSort {
pub fn new(name: Symbol) -> Self {
Self { name }
}
lazy_static! {
static ref BOOL_SORT_NAME: Symbol = "bool".into();
}

impl Sort for BoolSort {
fn name(&self) -> Symbol {
self.name
*BOOL_SORT_NAME
}

fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
Expand Down Expand Up @@ -44,7 +40,7 @@ impl IntoSort for bool {
type Sort = BoolSort;
fn store(self, sort: &Self::Sort) -> Option<Value> {
Some(Value {
tag: sort.name,
tag: sort.name(),
bits: self as u64,
})
}
Expand Down
14 changes: 5 additions & 9 deletions src/sort/f64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@ use crate::ast::Literal;
use ordered_float::OrderedFloat;

#[derive(Debug)]
pub struct F64Sort {
name: Symbol,
}
pub struct F64Sort;

impl F64Sort {
pub fn new(name: Symbol) -> Self {
Self { name }
}
lazy_static! {
static ref F64_SORT_NAME: Symbol = "f64".into();
}

impl Sort for F64Sort {
fn name(&self) -> Symbol {
self.name
*F64_SORT_NAME
}

fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
Expand Down Expand Up @@ -70,7 +66,7 @@ impl IntoSort for f64 {
type Sort = F64Sort;
fn store(self, sort: &Self::Sort) -> Option<Value> {
Some(Value {
tag: sort.name,
tag: sort.name(),
bits: self.to_bits(),
})
}
Expand Down
Loading
Loading