diff --git a/Cargo.lock b/Cargo.lock index 3ec034c..b757bd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -205,6 +205,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num-cmp" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63335b2e2c34fae2fb0aa2cecfd9f0832a1e24b3b32ecec612c3426d46dc8aaa" + [[package]] name = "num-traits" version = "0.2.17" @@ -228,6 +234,7 @@ dependencies = [ "arbitrary", "borsh", "bytemuck", + "num-cmp", "num-traits", "proptest", "rand", diff --git a/Cargo.toml b/Cargo.toml index 1d2125f..6bd200f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ proptest = { version = "1.0.0", optional = true } speedy = { version = "0.8.3", optional = true, default-features = false } bytemuck = { version = "1.12.2", optional = true, default-features = false } borsh = { version = "1.2.0", optional = true, default-features = false } +num-cmp = { version = "0.1.0", optional = true } [dev-dependencies] serde_test = "1.0" diff --git a/src/lib.rs b/src/lib.rs index 7df9b71..930daef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,6 +79,66 @@ fn canonicalize_signed_zero(x: T) -> T { #[repr(transparent)] pub struct OrderedFloat(pub T); +#[cfg(feature = "num-cmp")] +mod impl_num_cmp { + use super::OrderedFloat; + use core::cmp::Ordering; + use num_cmp::NumCmp; + use num_traits::float::FloatCore; + + impl NumCmp for OrderedFloat + where + T: FloatCore + NumCmp, + U: Copy, + { + fn num_cmp(self, other: U) -> Option { + NumCmp::num_cmp(self.0, other) + } + + fn num_eq(self, other: U) -> bool { + NumCmp::num_eq(self.0, other) + } + + fn num_ne(self, other: U) -> bool { + NumCmp::num_ne(self.0, other) + } + + fn num_lt(self, other: U) -> bool { + NumCmp::num_lt(self.0, other) + } + + fn num_gt(self, other: U) -> bool { + NumCmp::num_gt(self.0, other) + } + + fn num_le(self, other: U) -> bool { + NumCmp::num_le(self.0, other) + } + + fn num_ge(self, other: U) -> bool { + NumCmp::num_ge(self.0, other) + } + } + + #[test] + pub fn test_num_cmp() { + let f = OrderedFloat(1.0); + + assert_eq!(NumCmp::num_cmp(f, 1.0), Some(Ordering::Equal)); + assert_eq!(NumCmp::num_cmp(f, -1.0), Some(Ordering::Greater)); + assert_eq!(NumCmp::num_cmp(f, 2.0), Some(Ordering::Less)); + + assert!(NumCmp::num_eq(f, 1)); + assert!(NumCmp::num_ne(f, -1)); + assert!(NumCmp::num_lt(f, 100)); + assert!(NumCmp::num_gt(f, 0)); + assert!(NumCmp::num_le(f, 1)); + assert!(NumCmp::num_le(f, 2)); + assert!(NumCmp::num_ge(f, 1)); + assert!(NumCmp::num_ge(f, -1)); + } +} + impl OrderedFloat { /// Get the value out. #[inline]