diff --git a/clippy.toml b/clippy.toml index 2177fdb469afb..465ccb68ced30 100644 --- a/clippy.toml +++ b/clippy.toml @@ -5,6 +5,9 @@ disallowed-methods = [ { path = "risingwave_common::array::JsonbVal::from_serde", reason = "Please add dedicated methods as part of `JsonbRef`/`JsonbVal`, rather than take inner `serde_json::Value` out, process, and put back." }, { path = "std::panic::catch_unwind", reason = "Please use `risingwave_common::util::panic::rw_catch_unwind` instead." }, { path = "futures::FutureExt::catch_unwind", reason = "Please use `risingwave_common::util::panic::FutureCatchUnwindExt::rw_catch_unwind` instead." }, + { path = "num_traits::sign::Signed::is_positive", reason = "This returns true for 0.0 but false for 0." }, + { path = "num_traits::sign::Signed::is_negative", reason = "This returns true for -0.0 but false for 0." }, + { path = "num_traits::sign::Signed::signum", reason = "This returns 1.0 for 0.0 but 0 for 0." }, ] disallowed-types = [ { path = "num_traits::AsPrimitive", reason = "Please use `From` or `TryFrom` with `OrderedFloat` instead." }, diff --git a/src/common/src/types/decimal.rs b/src/common/src/types/decimal.rs index 6fa4cd9b53c8e..26ec5e8fbde5b 100644 --- a/src/common/src/types/decimal.rs +++ b/src/common/src/types/decimal.rs @@ -18,7 +18,7 @@ use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; use bytes::{BufMut, Bytes, BytesMut}; use num_traits::{ - CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Signed, Zero, + CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Zero, }; use postgres_types::{ToSql, Type}; use rust_decimal::prelude::FromStr; @@ -446,10 +446,6 @@ impl Decimal { Self::Normalized(RustDecimal::new(num, scale)) } - pub fn zero() -> Self { - Self::from(0) - } - #[must_use] pub fn round_dp_ties_away(&self, dp: u32) -> Self { match self { @@ -553,6 +549,17 @@ impl Decimal { } } + pub fn sign(&self) -> Self { + match self { + Self::NaN => Self::NaN, + _ => match self.cmp(&0.into()) { + std::cmp::Ordering::Less => (-1).into(), + std::cmp::Ordering::Equal => 0.into(), + std::cmp::Ordering::Greater => 1.into(), + }, + } + } + pub fn checked_exp(&self) -> Option { match self { Self::Normalized(d) => d.checked_exp().map(Self::Normalized), @@ -744,47 +751,6 @@ impl Num for Decimal { } } -impl Signed for Decimal { - fn abs(&self) -> Self { - self.abs() - } - - fn abs_sub(&self, other: &Self) -> Self { - if self <= other { - Self::zero() - } else { - *self - *other - } - } - - fn signum(&self) -> Self { - match self { - Self::Normalized(d) => Self::Normalized(d.signum()), - Self::NaN => Self::NaN, - Self::PositiveInf => Self::Normalized(RustDecimal::one()), - Self::NegativeInf => Self::Normalized(-RustDecimal::one()), - } - } - - fn is_positive(&self) -> bool { - match self { - Self::Normalized(d) => d.is_sign_positive(), - Self::NaN => false, - Self::PositiveInf => true, - Self::NegativeInf => false, - } - } - - fn is_negative(&self) -> bool { - match self { - Self::Normalized(d) => d.is_sign_negative(), - Self::NaN => false, - Self::PositiveInf => false, - Self::NegativeInf => true, - } - } -} - impl From for Decimal { fn from(d: RustDecimal) -> Self { Self::Normalized(d) diff --git a/src/common/src/types/interval.rs b/src/common/src/types/interval.rs index ad9db92c7ca62..15670d4b1f04f 100644 --- a/src/common/src/types/interval.rs +++ b/src/common/src/types/interval.rs @@ -29,7 +29,6 @@ use regex::Regex; use risingwave_pb::data::PbInterval; use rust_decimal::prelude::Decimal; -use super::ops::IsNegative; use super::to_binary::ToBinary; use super::*; use crate::error::{ErrorCode, Result, RwError}; @@ -974,12 +973,6 @@ impl Zero for Interval { } } -impl IsNegative for Interval { - fn is_negative(&self) -> bool { - self < &Self::from_month_day_usec(0, 0, 0) - } -} - impl Neg for Interval { type Output = Self; diff --git a/src/common/src/types/num256.rs b/src/common/src/types/num256.rs index d1a2d88e0a81e..2882ce123c729 100644 --- a/src/common/src/types/num256.rs +++ b/src/common/src/types/num256.rs @@ -23,7 +23,7 @@ use std::str::FromStr; use bytes::{BufMut, Bytes}; use ethnum::{i256, u256, AsI256}; use num_traits::{ - CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Signed, Zero, + CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Zero, }; use risingwave_pb::data::ArrayType; use serde::{Deserialize, Serialize}; @@ -323,32 +323,6 @@ impl Num for Int256 { } } -impl Signed for Int256 { - fn abs(&self) -> Self { - self.0.abs().into() - } - - fn abs_sub(&self, other: &Self) -> Self { - if self <= other { - Self::zero() - } else { - self.abs() - } - } - - fn signum(&self) -> Self { - self.0.signum().into() - } - - fn is_positive(&self) -> bool { - self.0.is_positive() - } - - fn is_negative(&self) -> bool { - self.0.is_negative() - } -} - impl From for Int256 { fn from(value: arrow_buffer::i256) -> Self { let buffer = value.to_be_bytes(); @@ -443,13 +417,6 @@ mod tests { assert_eq!(-Int256::from(0), Int256::from(0)); } - #[test] - fn test_abs() { - assert_eq!(Int256::from(-1).abs(), Int256::from(1)); - assert_eq!(Int256::from(1).abs(), Int256::from(1)); - assert_eq!(Int256::from(0).abs(), Int256::from(0)); - } - #[test] fn test_float64() { let vs: Vec = vec![-9007199254740990, -100, -1, 0, 1, 100, 9007199254740991]; diff --git a/src/common/src/types/ops.rs b/src/common/src/types/ops.rs index a2ac37d899a56..1974e9e8320e0 100644 --- a/src/common/src/types/ops.rs +++ b/src/common/src/types/ops.rs @@ -37,12 +37,14 @@ impl CheckedAdd for T { } /// A simplified version of [`num_traits::Signed`]. -pub trait IsNegative: Zero { +/// Unlike `Signed::is_negative` or `f64::is_sign_negative`, this returns `false` for `-0.0` to keep +/// consistency among integers, decimals and floats. +pub trait IsNegative: Zero + Ord { fn is_negative(&self) -> bool; } -impl IsNegative for T { +impl IsNegative for T { fn is_negative(&self) -> bool { - num_traits::Signed::is_negative(self) + self < &Self::zero() } } diff --git a/src/common/src/types/ordered_float.rs b/src/common/src/types/ordered_float.rs index c0ef51f66872b..f6cc12140125b 100644 --- a/src/common/src/types/ordered_float.rs +++ b/src/common/src/types/ordered_float.rs @@ -53,7 +53,7 @@ use core::str::FromStr; pub use num_traits::Float; use num_traits::{ Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Pow, - Signed, Zero, + Zero, }; // masks for the parts of the IEEE 754 float @@ -470,32 +470,6 @@ impl<'a, T: Float + Product + 'a> Product<&'a OrderedFloat> for OrderedFloat< } } -impl Signed for OrderedFloat { - #[inline] - fn abs(&self) -> Self { - OrderedFloat(self.0.abs()) - } - - fn abs_sub(&self, other: &Self) -> Self { - OrderedFloat(Signed::abs_sub(&self.0, &other.0)) - } - - #[inline] - fn signum(&self) -> Self { - OrderedFloat(self.0.signum()) - } - - #[inline] - fn is_positive(&self) -> bool { - self.0.is_positive() - } - - #[inline] - fn is_negative(&self) -> bool { - self.0.is_negative() - } -} - impl Bounded for OrderedFloat { #[inline] fn min_value() -> Self { @@ -1042,8 +1016,8 @@ mod tests { let nan = OrderedFloat::::from(nan_prim); assert_eq!(nan, nan); - use num_traits::Signed as _; - assert_eq!(nan.abs(), nan.abs()); + use crate::types::FloatExt as _; + assert_eq!(nan.round(), nan.round()); } fn test_into_f32(expected: [u8; 4], v: impl Into>) { diff --git a/src/expr/src/vector_op/arithmetic_op.rs b/src/expr/src/vector_op/arithmetic_op.rs index e0bb41214fe38..da08ec1385f78 100644 --- a/src/expr/src/vector_op/arithmetic_op.rs +++ b/src/expr/src/vector_op/arithmetic_op.rs @@ -16,9 +16,9 @@ use std::convert::TryInto; use std::fmt::Debug; use chrono::{Duration, NaiveDateTime}; -use num_traits::{CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Signed, Zero}; +use num_traits::{CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Zero}; use risingwave_common::types::{ - CheckedAdd, Date, Decimal, FloatExt, Interval, Time, Timestamp, F64, + CheckedAdd, Date, Decimal, FloatExt, Interval, IsNegative, Time, Timestamp, F64, }; use risingwave_expr_macro::function; use rust_decimal::MathematicalOps; @@ -129,7 +129,7 @@ where #[function("abs(*int) -> auto")] #[function("abs(*float) -> auto")] -pub fn general_abs(expr: T1) -> Result { +pub fn general_abs(expr: T1) -> Result { if expr.is_negative() { general_neg(expr) } else { @@ -141,7 +141,7 @@ pub fn general_abs(expr: T1) -> Result { pub fn int256_abs(expr: TRef) -> Result where TRef: Into + Debug, - T: Signed + CheckedNeg + Debug, + T: IsNegative + CheckedNeg + Debug, { let expr = expr.into(); if expr.is_negative() { @@ -376,7 +376,7 @@ pub fn sqrt_f64(expr: F64) -> Result { }); } // Edge cases: nan, inf, negative zero should return itself. - match expr.is_nan() || expr == f64::INFINITY || expr.is_negative() { + match expr.is_nan() || expr == f64::INFINITY || expr == -0.0 { true => Ok(expr), false => Ok(expr.sqrt()), } @@ -417,7 +417,7 @@ pub fn sign_f64(input: F64) -> F64 { #[function("sign(decimal) -> decimal")] pub fn sign_dec(input: Decimal) -> Decimal { - input.signum() + input.sign() } #[cfg(test)] diff --git a/src/expr/src/vector_op/round.rs b/src/expr/src/vector_op/round.rs index e83d18c653299..181cb9eefd33d 100644 --- a/src/expr/src/vector_op/round.rs +++ b/src/expr/src/vector_op/round.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use num_traits::Zero; use risingwave_common::types::{Decimal, F64}; use risingwave_expr_macro::function;