diff --git a/Cargo.lock b/Cargo.lock index e8ea9802..428ae86b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -387,9 +387,7 @@ dependencies = [ "lazy_static", "libtest-mimic", "log", - "num-integer", - "num-rational", - "num-traits", + "num", "ordered-float", "rustc-hash", "serde_json", @@ -716,6 +714,20 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8452105ba047068f40ff7093dd1d9da90898e63dd61736462e9cdda6a90ad3c3" +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -726,6 +738,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.46" @@ -735,6 +756,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-rational" version = "0.4.2" diff --git a/Cargo.toml b/Cargo.toml index b1f366d8..67d2ec1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,9 +34,7 @@ rustc-hash = "1.1" symbol_table = { version = "0.4.0", features = ["global"] } thiserror = "1" lazy_static = "1.4" -num-integer = "0.1.45" -num-rational = "0.4.1" -num-traits = "0.2.15" +num = "0.4.3" smallvec = "1.11" generic_symbolic_expressions = "5.0.4" diff --git a/src/sort/bigint.rs b/src/sort/bigint.rs new file mode 100644 index 00000000..e4ef63f3 --- /dev/null +++ b/src/sort/bigint.rs @@ -0,0 +1,102 @@ +use num::BigInt; +use std::ops::{Shl, Shr}; +use std::sync::Mutex; + +type Z = BigInt; +use crate::{ast::Literal, util::IndexSet}; + +use super::*; + +lazy_static! { + static ref BIG_INT_SORT_NAME: Symbol = "BigInt".into(); + static ref INTS: Mutex> = Default::default(); +} + +#[derive(Debug)] +pub struct BigIntSort; + +impl Sort for BigIntSort { + fn name(&self) -> Symbol { + *BIG_INT_SORT_NAME + } + + fn as_arc_any(self: Arc) -> Arc { + self + } + + #[rustfmt::skip] + fn register_primitives(self: Arc, eg: &mut TypeInfo) { + type Opt = Option; + + add_primitives!(eg, "bigint" = |a: i64| -> Z { a.into() }); + + add_primitives!(eg, "+" = |a: Z, b: Z| -> Z { a + b }); + add_primitives!(eg, "-" = |a: Z, b: Z| -> Z { a - b }); + add_primitives!(eg, "*" = |a: Z, b: Z| -> Z { a * b }); + add_primitives!(eg, "/" = |a: Z, b: Z| -> Opt { (b != BigInt::ZERO).then(|| a / b) }); + add_primitives!(eg, "%" = |a: Z, b: Z| -> Opt { (b != BigInt::ZERO).then(|| a % b) }); + + add_primitives!(eg, "&" = |a: Z, b: Z| -> Z { a & b }); + add_primitives!(eg, "|" = |a: Z, b: Z| -> Z { a | b }); + add_primitives!(eg, "^" = |a: Z, b: Z| -> Z { a ^ b }); + add_primitives!(eg, "<<" = |a: Z, b: i64| -> Z { a.shl(b) }); + add_primitives!(eg, ">>" = |a: Z, b: i64| -> Z { a.shr(b) }); + add_primitives!(eg, "not-Z" = |a: Z| -> Z { !a }); + + add_primitives!(eg, "bits" = |a: Z| -> Z { a.bits().into() }); + + add_primitives!(eg, "<" = |a: Z, b: Z| -> Opt { (a < b).then_some(()) }); + add_primitives!(eg, ">" = |a: Z, b: Z| -> Opt { (a > b).then_some(()) }); + add_primitives!(eg, "<=" = |a: Z, b: Z| -> Opt { (a <= b).then_some(()) }); + add_primitives!(eg, ">=" = |a: Z, b: Z| -> Opt { (a >= b).then_some(()) }); + + add_primitives!(eg, "bool-=" = |a: Z, b: Z| -> bool { a == b }); + add_primitives!(eg, "bool-<" = |a: Z, b: Z| -> bool { a < b }); + add_primitives!(eg, "bool->" = |a: Z, b: Z| -> bool { a > b }); + add_primitives!(eg, "bool-<=" = |a: Z, b: Z| -> bool { a <= b }); + add_primitives!(eg, "bool->=" = |a: Z, b: Z| -> bool { a >= b }); + + add_primitives!(eg, "min" = |a: Z, b: Z| -> Z { a.min(b) }); + add_primitives!(eg, "max" = |a: Z, b: Z| -> Z { a.max(b) }); + + add_primitives!(eg, "to-string" = |a: Z| -> Symbol { a.to_string().into() }); + add_primitives!(eg, "from-string" = |a: Symbol| -> Opt { a.as_str().parse::().ok() }); + } + + fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) { + #[cfg(debug_assertions)] + debug_assert_eq!(value.tag, self.name()); + + let bigint = Z::load(self, &value); + ( + 1, + Expr::call_no_span( + "from-string", + vec![GenericExpr::Lit( + DUMMY_SPAN.clone(), + Literal::String(bigint.to_string().into()), + )], + ), + ) + } +} + +impl FromSort for Z { + type Sort = BigIntSort; + fn load(_sort: &Self::Sort, value: &Value) -> Self { + let i = value.bits as usize; + INTS.lock().unwrap().get_index(i).unwrap().clone() + } +} + +impl IntoSort for Z { + type Sort = BigIntSort; + fn store(self, _sort: &Self::Sort) -> Option { + let (i, _) = INTS.lock().unwrap().insert_full(self); + Some(Value { + #[cfg(debug_assertions)] + tag: BigIntSort.name(), + bits: i as u64, + }) + } +} diff --git a/src/sort/bigrat.rs b/src/sort/bigrat.rs new file mode 100644 index 00000000..8e847d40 --- /dev/null +++ b/src/sort/bigrat.rs @@ -0,0 +1,155 @@ +use num::traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, Signed, ToPrimitive, Zero}; +use num::{rational::BigRational, BigInt}; +use std::sync::Mutex; + +type Z = BigInt; +type Q = BigRational; +use crate::{ast::Literal, util::IndexSet}; + +use super::*; + +lazy_static! { + static ref BIG_RAT_SORT_NAME: Symbol = "BigRat".into(); + static ref RATS: Mutex> = Default::default(); +} + +#[derive(Debug)] +pub struct BigRatSort; + +impl Sort for BigRatSort { + fn name(&self) -> Symbol { + *BIG_RAT_SORT_NAME + } + + fn as_arc_any(self: Arc) -> Arc { + self + } + + #[rustfmt::skip] + fn register_primitives(self: Arc, eg: &mut TypeInfo) { + type Opt = Option; + + add_primitives!(eg, "+" = |a: Q, b: Q| -> Opt { a.checked_add(&b) }); + add_primitives!(eg, "-" = |a: Q, b: Q| -> Opt { a.checked_sub(&b) }); + add_primitives!(eg, "*" = |a: Q, b: Q| -> Opt { a.checked_mul(&b) }); + add_primitives!(eg, "/" = |a: Q, b: Q| -> Opt { a.checked_div(&b) }); + + add_primitives!(eg, "min" = |a: Q, b: Q| -> Q { a.min(b) }); + add_primitives!(eg, "max" = |a: Q, b: Q| -> Q { a.max(b) }); + add_primitives!(eg, "neg" = |a: Q| -> Q { -a }); + add_primitives!(eg, "abs" = |a: Q| -> Q { a.abs() }); + add_primitives!(eg, "floor" = |a: Q| -> Q { a.floor() }); + add_primitives!(eg, "ceil" = |a: Q| -> Q { a.ceil() }); + add_primitives!(eg, "round" = |a: Q| -> Q { a.round() }); + add_primitives!(eg, "bigrat" = |a: Z, b: Z| -> Q { Q::new(a, b) }); + add_primitives!(eg, "numer" = |a: Q| -> Z { a.numer().clone() }); + add_primitives!(eg, "denom" = |a: Q| -> Z { a.denom().clone() }); + + add_primitives!(eg, "to-f64" = |a: Q| -> f64 { a.to_f64().unwrap() }); + + add_primitives!(eg, "pow" = |a: Q, b: Q| -> Option { + if a.is_zero() { + if b.is_positive() { + Some(Q::zero()) + } else { + None + } + } else if b.is_zero() { + Some(Q::one()) + } else if let Some(b) = b.to_i64() { + if let Ok(b) = usize::try_from(b) { + num::traits::checked_pow(a, b) + } else { + // TODO handle negative powers + None + } + } else { + None + } + }); + add_primitives!(eg, "log" = |a: Q| -> Option { + if a.is_one() { + Some(Q::zero()) + } else { + todo!() + } + }); + add_primitives!(eg, "sqrt" = |a: Q| -> Option { + if a.numer().is_positive() && a.denom().is_positive() { + let s1 = a.numer().sqrt(); + let s2 = a.denom().sqrt(); + let is_perfect = &(s1.clone() * s1.clone()) == a.numer() && &(s2.clone() * s2.clone()) == a.denom(); + if is_perfect { + Some(Q::new(s1, s2)) + } else { + None + } + } else { + None + } + }); + add_primitives!(eg, "cbrt" = |a: Q| -> Option { + if a.is_one() { + Some(Q::one()) + } else { + todo!() + } + }); + + add_primitives!(eg, "<" = |a: Q, b: Q| -> Opt { if a < b {Some(())} else {None} }); + add_primitives!(eg, ">" = |a: Q, b: Q| -> Opt { if a > b {Some(())} else {None} }); + add_primitives!(eg, "<=" = |a: Q, b: Q| -> Opt { if a <= b {Some(())} else {None} }); + add_primitives!(eg, ">=" = |a: Q, b: Q| -> Opt { if a >= b {Some(())} else {None} }); + } + + fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) { + #[cfg(debug_assertions)] + debug_assert_eq!(value.tag, self.name()); + + let rat = Q::load(self, &value); + let numer = rat.numer(); + let denom = rat.denom(); + ( + 1, + Expr::call_no_span( + "bigrat", + vec![ + Expr::call_no_span( + "from-string", + vec![GenericExpr::Lit( + DUMMY_SPAN.clone(), + Literal::String(numer.to_string().into()), + )], + ), + Expr::call_no_span( + "from-string", + vec![GenericExpr::Lit( + DUMMY_SPAN.clone(), + Literal::String(denom.to_string().into()), + )], + ), + ], + ), + ) + } +} + +impl FromSort for Q { + type Sort = BigRatSort; + fn load(_sort: &Self::Sort, value: &Value) -> Self { + let i = value.bits as usize; + RATS.lock().unwrap().get_index(i).unwrap().clone() + } +} + +impl IntoSort for Q { + type Sort = BigRatSort; + fn store(self, _sort: &Self::Sort) -> Option { + let (i, _) = RATS.lock().unwrap().insert_full(self); + Some(Value { + #[cfg(debug_assertions)] + tag: BigRatSort.name(), + bits: i as u64, + }) + } +} diff --git a/src/sort/i64.rs b/src/sort/i64.rs index 81d623f5..dfc797a8 100644 --- a/src/sort/i64.rs +++ b/src/sort/i64.rs @@ -45,10 +45,10 @@ impl Sort for I64Sort { add_primitives!(typeinfo, "log2" = |a: i64| -> i64 { (a as i64).ilog2() as i64 }); - add_primitives!(typeinfo, "<" = |a: i64, b: i64| -> Opt { (a < b).then(|| ()) }); - add_primitives!(typeinfo, ">" = |a: i64, b: i64| -> Opt { (a > b).then(|| ()) }); - add_primitives!(typeinfo, "<=" = |a: i64, b: i64| -> Opt { (a <= b).then(|| ()) }); - add_primitives!(typeinfo, ">=" = |a: i64, b: i64| -> Opt { (a >= b).then(|| ()) }); + add_primitives!(typeinfo, "<" = |a: i64, b: i64| -> Opt { (a < b).then_some(()) }); + add_primitives!(typeinfo, ">" = |a: i64, b: i64| -> Opt { (a > b).then_some(()) }); + add_primitives!(typeinfo, "<=" = |a: i64, b: i64| -> Opt { (a <= b).then_some(()) }); + add_primitives!(typeinfo, ">=" = |a: i64, b: i64| -> Opt { (a >= b).then_some(()) }); add_primitives!(typeinfo, "bool-=" = |a: i64, b: i64| -> bool { a == b }); add_primitives!(typeinfo, "bool-<" = |a: i64, b: i64| -> bool { a < b }); diff --git a/src/sort/macros.rs b/src/sort/macros.rs index e629c8d9..882e9a74 100644 --- a/src/sort/macros.rs +++ b/src/sort/macros.rs @@ -22,7 +22,7 @@ macro_rules! add_primitives { &self, span: &Span ) -> Box { - let sorts = vec![$(self.$param.clone(),)* self.__out.clone() as ArcSort]; + let sorts = vec![$(self.$param.clone() as ArcSort,)* self.__out.clone() as ArcSort]; SimpleTypeConstraint::new(self.name(), sorts, span.clone()).into_box() } diff --git a/src/sort/mod.rs b/src/sort/mod.rs index 2ab0f60a..50d854ae 100644 --- a/src/sort/mod.rs +++ b/src/sort/mod.rs @@ -4,6 +4,10 @@ use lazy_static::lazy_static; use std::fmt::Debug; use std::{any::Any, sync::Arc}; +mod bigint; +pub use bigint::*; +mod bigrat; +pub use bigrat::*; mod bool; pub use self::bool::*; mod rational; diff --git a/src/sort/rational.rs b/src/sort/rational.rs index fc6cba06..d56c820d 100644 --- a/src/sort/rational.rs +++ b/src/sort/rational.rs @@ -1,8 +1,8 @@ -use num_integer::Roots; -use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, Signed, ToPrimitive, Zero}; +use num::integer::Roots; +use num::traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, Signed, ToPrimitive, Zero}; use std::sync::Mutex; -type R = num_rational::Rational64; +type R = num::rational::Rational64; use crate::{ast::Literal, util::IndexSet}; use super::*; @@ -59,7 +59,7 @@ impl Sort for RationalSort { Some(R::one()) } else if let Some(b) = b.to_i64() { if let Ok(b) = usize::try_from(b) { - num_traits::checked_pow(a, b) + num::traits::checked_pow(a, b) } else { // TODO handle negative powers None diff --git a/src/typechecking.rs b/src/typechecking.rs index 113c0062..fd72cf20 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -42,6 +42,8 @@ impl Default for TypeInfo { res.add_sort(I64Sort, DUMMY_SPAN.clone()).unwrap(); res.add_sort(F64Sort, DUMMY_SPAN.clone()).unwrap(); res.add_sort(RationalSort, DUMMY_SPAN.clone()).unwrap(); + res.add_sort(BigIntSort, DUMMY_SPAN.clone()).unwrap(); + res.add_sort(BigRatSort, DUMMY_SPAN.clone()).unwrap(); res.add_presort::(DUMMY_SPAN.clone()).unwrap(); res.add_presort::(DUMMY_SPAN.clone()).unwrap(); diff --git a/tests/bignum.egg b/tests/bignum.egg new file mode 100644 index 00000000..94c17324 --- /dev/null +++ b/tests/bignum.egg @@ -0,0 +1,13 @@ + +(let x (bigint -1234)) +(let y (from-string "2")) +(let z (bigrat x y)) +(check (= (to-string (numer z)) "-617")) + +(function bignums (BigInt BigInt) BigRat) +(set (bignums x y) z) +(check + (= (bignums a b) c) + (= (numer c) (>> a 1)) + (= (denom c) (>> b 1)) +)