Skip to content

Commit

Permalink
BigInt and BigRat (#457)
Browse files Browse the repository at this point in the history
* WIP

* Add bigrat

* Add test, fix bugs

* Improve test slightly
  • Loading branch information
Alex-Fischman authored Oct 30, 2024
1 parent a38edb3 commit 225d0a1
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 15 deletions.
38 changes: 35 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ rustc-hash = "1.1"
symbol_table = { version = "0.4.0", features = ["global"] }
thiserror = "1"
lazy_static = "1.4"
num-integer = "0.1.45"
num-rational = "0.4.1"
num-traits = "0.2.15"
num = "0.4.3"
smallvec = "1.11"

generic_symbolic_expressions = "5.0.4"
Expand Down
102 changes: 102 additions & 0 deletions src/sort/bigint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use num::BigInt;
use std::ops::{Shl, Shr};
use std::sync::Mutex;

type Z = BigInt;
use crate::{ast::Literal, util::IndexSet};

use super::*;

lazy_static! {
static ref BIG_INT_SORT_NAME: Symbol = "BigInt".into();
static ref INTS: Mutex<IndexSet<Z>> = Default::default();
}

#[derive(Debug)]
pub struct BigIntSort;

impl Sort for BigIntSort {
fn name(&self) -> Symbol {
*BIG_INT_SORT_NAME
}

fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
self
}

#[rustfmt::skip]
fn register_primitives(self: Arc<Self>, eg: &mut TypeInfo) {
type Opt<T=()> = Option<T>;

add_primitives!(eg, "bigint" = |a: i64| -> Z { a.into() });

add_primitives!(eg, "+" = |a: Z, b: Z| -> Z { a + b });
add_primitives!(eg, "-" = |a: Z, b: Z| -> Z { a - b });
add_primitives!(eg, "*" = |a: Z, b: Z| -> Z { a * b });
add_primitives!(eg, "/" = |a: Z, b: Z| -> Opt<Z> { (b != BigInt::ZERO).then(|| a / b) });
add_primitives!(eg, "%" = |a: Z, b: Z| -> Opt<Z> { (b != BigInt::ZERO).then(|| a % b) });

add_primitives!(eg, "&" = |a: Z, b: Z| -> Z { a & b });
add_primitives!(eg, "|" = |a: Z, b: Z| -> Z { a | b });
add_primitives!(eg, "^" = |a: Z, b: Z| -> Z { a ^ b });
add_primitives!(eg, "<<" = |a: Z, b: i64| -> Z { a.shl(b) });
add_primitives!(eg, ">>" = |a: Z, b: i64| -> Z { a.shr(b) });
add_primitives!(eg, "not-Z" = |a: Z| -> Z { !a });

add_primitives!(eg, "bits" = |a: Z| -> Z { a.bits().into() });

add_primitives!(eg, "<" = |a: Z, b: Z| -> Opt { (a < b).then_some(()) });
add_primitives!(eg, ">" = |a: Z, b: Z| -> Opt { (a > b).then_some(()) });
add_primitives!(eg, "<=" = |a: Z, b: Z| -> Opt { (a <= b).then_some(()) });
add_primitives!(eg, ">=" = |a: Z, b: Z| -> Opt { (a >= b).then_some(()) });

add_primitives!(eg, "bool-=" = |a: Z, b: Z| -> bool { a == b });
add_primitives!(eg, "bool-<" = |a: Z, b: Z| -> bool { a < b });
add_primitives!(eg, "bool->" = |a: Z, b: Z| -> bool { a > b });
add_primitives!(eg, "bool-<=" = |a: Z, b: Z| -> bool { a <= b });
add_primitives!(eg, "bool->=" = |a: Z, b: Z| -> bool { a >= b });

add_primitives!(eg, "min" = |a: Z, b: Z| -> Z { a.min(b) });
add_primitives!(eg, "max" = |a: Z, b: Z| -> Z { a.max(b) });

add_primitives!(eg, "to-string" = |a: Z| -> Symbol { a.to_string().into() });
add_primitives!(eg, "from-string" = |a: Symbol| -> Opt<Z> { a.as_str().parse::<Z>().ok() });
}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
#[cfg(debug_assertions)]
debug_assert_eq!(value.tag, self.name());

let bigint = Z::load(self, &value);
(
1,
Expr::call_no_span(
"from-string",
vec![GenericExpr::Lit(
DUMMY_SPAN.clone(),
Literal::String(bigint.to_string().into()),
)],
),
)
}
}

impl FromSort for Z {
type Sort = BigIntSort;
fn load(_sort: &Self::Sort, value: &Value) -> Self {
let i = value.bits as usize;
INTS.lock().unwrap().get_index(i).unwrap().clone()
}
}

impl IntoSort for Z {
type Sort = BigIntSort;
fn store(self, _sort: &Self::Sort) -> Option<Value> {
let (i, _) = INTS.lock().unwrap().insert_full(self);
Some(Value {
#[cfg(debug_assertions)]
tag: BigIntSort.name(),
bits: i as u64,
})
}
}
155 changes: 155 additions & 0 deletions src/sort/bigrat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
use num::traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, Signed, ToPrimitive, Zero};
use num::{rational::BigRational, BigInt};
use std::sync::Mutex;

type Z = BigInt;
type Q = BigRational;
use crate::{ast::Literal, util::IndexSet};

use super::*;

lazy_static! {
static ref BIG_RAT_SORT_NAME: Symbol = "BigRat".into();
static ref RATS: Mutex<IndexSet<Q>> = Default::default();
}

#[derive(Debug)]
pub struct BigRatSort;

impl Sort for BigRatSort {
fn name(&self) -> Symbol {
*BIG_RAT_SORT_NAME
}

fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync + 'static> {
self
}

#[rustfmt::skip]
fn register_primitives(self: Arc<Self>, eg: &mut TypeInfo) {
type Opt<T=()> = Option<T>;

add_primitives!(eg, "+" = |a: Q, b: Q| -> Opt<Q> { a.checked_add(&b) });
add_primitives!(eg, "-" = |a: Q, b: Q| -> Opt<Q> { a.checked_sub(&b) });
add_primitives!(eg, "*" = |a: Q, b: Q| -> Opt<Q> { a.checked_mul(&b) });
add_primitives!(eg, "/" = |a: Q, b: Q| -> Opt<Q> { a.checked_div(&b) });

add_primitives!(eg, "min" = |a: Q, b: Q| -> Q { a.min(b) });
add_primitives!(eg, "max" = |a: Q, b: Q| -> Q { a.max(b) });
add_primitives!(eg, "neg" = |a: Q| -> Q { -a });
add_primitives!(eg, "abs" = |a: Q| -> Q { a.abs() });
add_primitives!(eg, "floor" = |a: Q| -> Q { a.floor() });
add_primitives!(eg, "ceil" = |a: Q| -> Q { a.ceil() });
add_primitives!(eg, "round" = |a: Q| -> Q { a.round() });
add_primitives!(eg, "bigrat" = |a: Z, b: Z| -> Q { Q::new(a, b) });
add_primitives!(eg, "numer" = |a: Q| -> Z { a.numer().clone() });
add_primitives!(eg, "denom" = |a: Q| -> Z { a.denom().clone() });

add_primitives!(eg, "to-f64" = |a: Q| -> f64 { a.to_f64().unwrap() });

add_primitives!(eg, "pow" = |a: Q, b: Q| -> Option<Q> {
if a.is_zero() {
if b.is_positive() {
Some(Q::zero())
} else {
None
}
} else if b.is_zero() {
Some(Q::one())
} else if let Some(b) = b.to_i64() {
if let Ok(b) = usize::try_from(b) {
num::traits::checked_pow(a, b)
} else {
// TODO handle negative powers
None
}
} else {
None
}
});
add_primitives!(eg, "log" = |a: Q| -> Option<Q> {
if a.is_one() {
Some(Q::zero())
} else {
todo!()
}
});
add_primitives!(eg, "sqrt" = |a: Q| -> Option<Q> {
if a.numer().is_positive() && a.denom().is_positive() {
let s1 = a.numer().sqrt();
let s2 = a.denom().sqrt();
let is_perfect = &(s1.clone() * s1.clone()) == a.numer() && &(s2.clone() * s2.clone()) == a.denom();
if is_perfect {
Some(Q::new(s1, s2))
} else {
None
}
} else {
None
}
});
add_primitives!(eg, "cbrt" = |a: Q| -> Option<Q> {
if a.is_one() {
Some(Q::one())
} else {
todo!()
}
});

add_primitives!(eg, "<" = |a: Q, b: Q| -> Opt { if a < b {Some(())} else {None} });
add_primitives!(eg, ">" = |a: Q, b: Q| -> Opt { if a > b {Some(())} else {None} });
add_primitives!(eg, "<=" = |a: Q, b: Q| -> Opt { if a <= b {Some(())} else {None} });
add_primitives!(eg, ">=" = |a: Q, b: Q| -> Opt { if a >= b {Some(())} else {None} });
}

fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) {
#[cfg(debug_assertions)]
debug_assert_eq!(value.tag, self.name());

let rat = Q::load(self, &value);
let numer = rat.numer();
let denom = rat.denom();
(
1,
Expr::call_no_span(
"bigrat",
vec![
Expr::call_no_span(
"from-string",
vec![GenericExpr::Lit(
DUMMY_SPAN.clone(),
Literal::String(numer.to_string().into()),
)],
),
Expr::call_no_span(
"from-string",
vec![GenericExpr::Lit(
DUMMY_SPAN.clone(),
Literal::String(denom.to_string().into()),
)],
),
],
),
)
}
}

impl FromSort for Q {
type Sort = BigRatSort;
fn load(_sort: &Self::Sort, value: &Value) -> Self {
let i = value.bits as usize;
RATS.lock().unwrap().get_index(i).unwrap().clone()
}
}

impl IntoSort for Q {
type Sort = BigRatSort;
fn store(self, _sort: &Self::Sort) -> Option<Value> {
let (i, _) = RATS.lock().unwrap().insert_full(self);
Some(Value {
#[cfg(debug_assertions)]
tag: BigRatSort.name(),
bits: i as u64,
})
}
}
8 changes: 4 additions & 4 deletions src/sort/i64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ impl Sort for I64Sort {

add_primitives!(typeinfo, "log2" = |a: i64| -> i64 { (a as i64).ilog2() as i64 });

add_primitives!(typeinfo, "<" = |a: i64, b: i64| -> Opt { (a < b).then(|| ()) });
add_primitives!(typeinfo, ">" = |a: i64, b: i64| -> Opt { (a > b).then(|| ()) });
add_primitives!(typeinfo, "<=" = |a: i64, b: i64| -> Opt { (a <= b).then(|| ()) });
add_primitives!(typeinfo, ">=" = |a: i64, b: i64| -> Opt { (a >= b).then(|| ()) });
add_primitives!(typeinfo, "<" = |a: i64, b: i64| -> Opt { (a < b).then_some(()) });
add_primitives!(typeinfo, ">" = |a: i64, b: i64| -> Opt { (a > b).then_some(()) });
add_primitives!(typeinfo, "<=" = |a: i64, b: i64| -> Opt { (a <= b).then_some(()) });
add_primitives!(typeinfo, ">=" = |a: i64, b: i64| -> Opt { (a >= b).then_some(()) });

add_primitives!(typeinfo, "bool-=" = |a: i64, b: i64| -> bool { a == b });
add_primitives!(typeinfo, "bool-<" = |a: i64, b: i64| -> bool { a < b });
Expand Down
2 changes: 1 addition & 1 deletion src/sort/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ macro_rules! add_primitives {
&self,
span: &Span
) -> Box<dyn TypeConstraint> {
let sorts = vec![$(self.$param.clone(),)* self.__out.clone() as ArcSort];
let sorts = vec![$(self.$param.clone() as ArcSort,)* self.__out.clone() as ArcSort];
SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box()
}

Expand Down
4 changes: 4 additions & 0 deletions src/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ use lazy_static::lazy_static;
use std::fmt::Debug;
use std::{any::Any, sync::Arc};

mod bigint;
pub use bigint::*;
mod bigrat;
pub use bigrat::*;
mod bool;
pub use self::bool::*;
mod rational;
Expand Down
8 changes: 4 additions & 4 deletions src/sort/rational.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use num_integer::Roots;
use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, Signed, ToPrimitive, Zero};
use num::integer::Roots;
use num::traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, Signed, ToPrimitive, Zero};
use std::sync::Mutex;

type R = num_rational::Rational64;
type R = num::rational::Rational64;
use crate::{ast::Literal, util::IndexSet};

use super::*;
Expand Down Expand Up @@ -59,7 +59,7 @@ impl Sort for RationalSort {
Some(R::one())
} else if let Some(b) = b.to_i64() {
if let Ok(b) = usize::try_from(b) {
num_traits::checked_pow(a, b)
num::traits::checked_pow(a, b)
} else {
// TODO handle negative powers
None
Expand Down
Loading

0 comments on commit 225d0a1

Please sign in to comment.