Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(expr): support sign and distinguish is_positive vs is_sign_positive #10819

Merged
merged 3 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clippy.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ disallowed-methods = [
{ path = "risingwave_common::array::JsonbVal::from_serde", reason = "Please add dedicated methods as part of `JsonbRef`/`JsonbVal`, rather than take inner `serde_json::Value` out, process, and put back." },
{ path = "std::panic::catch_unwind", reason = "Please use `risingwave_common::util::panic::rw_catch_unwind` instead." },
{ path = "futures::FutureExt::catch_unwind", reason = "Please use `risingwave_common::util::panic::FutureCatchUnwindExt::rw_catch_unwind` instead." },
{ path = "num_traits::sign::Signed::is_positive", reason = "This returns true for 0.0 but false for 0." },
{ path = "num_traits::sign::Signed::is_negative", reason = "This returns true for -0.0 but false for 0." },
{ path = "num_traits::sign::Signed::signum", reason = "This returns 1.0 for 0.0 but 0 for 0." },
]
disallowed-types = [
{ path = "num_traits::AsPrimitive", reason = "Please use `From` or `TryFrom` with `OrderedFloat` instead." },
Expand Down
25 changes: 25 additions & 0 deletions e2e_test/batch/functions/abs.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,28 @@ query I
SELECT abs(2134)
----
2134

query TRRRR
with t(s) as (values
('-inf'),
('-2'),
('-0'),
('0'),
('2'),
('inf'),
('nan')
) select
s,
sign(s::decimal),
sign(s::double precision),
abs(s::decimal),
abs(s::double precision)
from t;
----
-inf -1 -1 Infinity Infinity
-2 -1 -1 2 2
-0 0 0 0 0
0 0 0 0 0
2 1 1 2 2
inf 1 1 Infinity Infinity
nan NaN 0 NaN NaN
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ message ExprNode {
LN = 272;
LOG10 = 273;
CBRT = 274;
SIGN = 275;

// Boolean comparison
IS_TRUE = 301;
Expand Down
58 changes: 12 additions & 46 deletions src/common/src/types/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use std::ops::{Add, Div, Mul, Neg, Rem, Sub};

use bytes::{BufMut, Bytes, BytesMut};
use num_traits::{
CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Signed, Zero,
CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Zero,
};
use postgres_types::{ToSql, Type};
use rust_decimal::prelude::FromStr;
Expand Down Expand Up @@ -446,10 +446,6 @@ impl Decimal {
Self::Normalized(RustDecimal::new(num, scale))
}

pub fn zero() -> Self {
Self::from(0)
}

#[must_use]
pub fn round_dp_ties_away(&self, dp: u32) -> Self {
match self {
Expand Down Expand Up @@ -553,6 +549,17 @@ impl Decimal {
}
}

pub fn sign(&self) -> Self {
match self {
Self::NaN => Self::NaN,
_ => match self.cmp(&0.into()) {
std::cmp::Ordering::Less => (-1).into(),
std::cmp::Ordering::Equal => 0.into(),
std::cmp::Ordering::Greater => 1.into(),
},
}
}

pub fn checked_exp(&self) -> Option<Decimal> {
match self {
Self::Normalized(d) => d.checked_exp().map(Self::Normalized),
Expand Down Expand Up @@ -744,47 +751,6 @@ impl Num for Decimal {
}
}

impl Signed for Decimal {
fn abs(&self) -> Self {
self.abs()
}

fn abs_sub(&self, other: &Self) -> Self {
if self <= other {
Self::zero()
} else {
*self - *other
}
}

fn signum(&self) -> Self {
match self {
Self::Normalized(d) => Self::Normalized(d.signum()),
Self::NaN => Self::NaN,
Self::PositiveInf => Self::Normalized(RustDecimal::one()),
Self::NegativeInf => Self::Normalized(-RustDecimal::one()),
}
}

fn is_positive(&self) -> bool {
match self {
Self::Normalized(d) => d.is_sign_positive(),
Self::NaN => false,
Self::PositiveInf => true,
Self::NegativeInf => false,
}
}

fn is_negative(&self) -> bool {
match self {
Self::Normalized(d) => d.is_sign_negative(),
Self::NaN => false,
Self::PositiveInf => false,
Self::NegativeInf => true,
}
}
}

impl From<RustDecimal> for Decimal {
fn from(d: RustDecimal) -> Self {
Self::Normalized(d)
Expand Down
7 changes: 0 additions & 7 deletions src/common/src/types/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ use regex::Regex;
use risingwave_pb::data::PbInterval;
use rust_decimal::prelude::Decimal;

use super::ops::IsNegative;
use super::to_binary::ToBinary;
use super::*;
use crate::error::{ErrorCode, Result, RwError};
Expand Down Expand Up @@ -974,12 +973,6 @@ impl Zero for Interval {
}
}

impl IsNegative for Interval {
fn is_negative(&self) -> bool {
self < &Self::from_month_day_usec(0, 0, 0)
}
}

impl Neg for Interval {
type Output = Self;

Expand Down
35 changes: 1 addition & 34 deletions src/common/src/types/num256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::str::FromStr;
use bytes::{BufMut, Bytes};
use ethnum::{i256, u256, AsI256};
use num_traits::{
CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Signed, Zero,
CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Zero,
};
use risingwave_pb::data::ArrayType;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -323,32 +323,6 @@ impl Num for Int256 {
}
}

impl Signed for Int256 {
fn abs(&self) -> Self {
self.0.abs().into()
}

fn abs_sub(&self, other: &Self) -> Self {
if self <= other {
Self::zero()
} else {
self.abs()
}
}

fn signum(&self) -> Self {
self.0.signum().into()
}

fn is_positive(&self) -> bool {
self.0.is_positive()
}

fn is_negative(&self) -> bool {
self.0.is_negative()
}
}

impl From<arrow_buffer::i256> for Int256 {
fn from(value: arrow_buffer::i256) -> Self {
let buffer = value.to_be_bytes();
Expand Down Expand Up @@ -443,13 +417,6 @@ mod tests {
assert_eq!(-Int256::from(0), Int256::from(0));
}

#[test]
fn test_abs() {
assert_eq!(Int256::from(-1).abs(), Int256::from(1));
assert_eq!(Int256::from(1).abs(), Int256::from(1));
assert_eq!(Int256::from(0).abs(), Int256::from(0));
}

#[test]
fn test_float64() {
let vs: Vec<i64> = vec![-9007199254740990, -100, -1, 0, 1, 100, 9007199254740991];
Expand Down
8 changes: 5 additions & 3 deletions src/common/src/types/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ impl<T: num_traits::CheckedAdd> CheckedAdd for T {
}

/// A simplified version of [`num_traits::Signed`].
pub trait IsNegative: Zero {
/// Unlike `Signed::is_negative` or `f64::is_sign_negative`, this returns `false` for `-0.0` to keep
/// consistency among integers, decimals and floats.
pub trait IsNegative: Zero + Ord {
fn is_negative(&self) -> bool;
}

impl<T: num_traits::Signed> IsNegative for T {
impl<T: Zero + Ord> IsNegative for T {
fn is_negative(&self) -> bool {
num_traits::Signed::is_negative(self)
self < &Self::zero()
}
}
32 changes: 3 additions & 29 deletions src/common/src/types/ordered_float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use core::str::FromStr;
pub use num_traits::Float;
use num_traits::{
Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Num, One, Pow,
Signed, Zero,
Zero,
};

// masks for the parts of the IEEE 754 float
Expand Down Expand Up @@ -470,32 +470,6 @@ impl<'a, T: Float + Product + 'a> Product<&'a OrderedFloat<T>> for OrderedFloat<
}
}

impl<T: Float + Signed> Signed for OrderedFloat<T> {
#[inline]
fn abs(&self) -> Self {
OrderedFloat(self.0.abs())
}

fn abs_sub(&self, other: &Self) -> Self {
OrderedFloat(Signed::abs_sub(&self.0, &other.0))
}

#[inline]
fn signum(&self) -> Self {
OrderedFloat(self.0.signum())
}

#[inline]
fn is_positive(&self) -> bool {
self.0.is_positive()
}

#[inline]
fn is_negative(&self) -> bool {
self.0.is_negative()
}
}

impl<T: Bounded> Bounded for OrderedFloat<T> {
#[inline]
fn min_value() -> Self {
Expand Down Expand Up @@ -1042,8 +1016,8 @@ mod tests {
let nan = OrderedFloat::<f64>::from(nan_prim);
assert_eq!(nan, nan);

use num_traits::Signed as _;
assert_eq!(nan.abs(), nan.abs());
use crate::types::FloatExt as _;
assert_eq!(nan.round(), nan.round());
}

fn test_into_f32(expected: [u8; 4], v: impl Into<OrderedFloat<f32>>) {
Expand Down
31 changes: 25 additions & 6 deletions src/expr/src/vector_op/arithmetic_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ use std::convert::TryInto;
use std::fmt::Debug;

use chrono::{Duration, NaiveDateTime};
use num_traits::{CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Signed, Zero};
use num_traits::{CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, Zero};
use risingwave_common::types::{
CheckedAdd, Date, Decimal, FloatExt, Interval, Time, Timestamp, F64,
CheckedAdd, Date, Decimal, FloatExt, Interval, IsNegative, Time, Timestamp, F64,
};
use risingwave_expr_macro::function;
use rust_decimal::MathematicalOps;
Expand Down Expand Up @@ -128,20 +128,24 @@ where
}

#[function("abs(*int) -> auto")]
#[function("abs(*float) -> auto")]
pub fn general_abs<T1: Signed + CheckedNeg>(expr: T1) -> Result<T1> {
pub fn general_abs<T1: IsNegative + CheckedNeg>(expr: T1) -> Result<T1> {
if expr.is_negative() {
general_neg(expr)
} else {
Ok(expr)
}
}

#[function("abs(*float) -> auto")]
pub fn float_abs<F: num_traits::Float, T1: FloatExt<F>>(expr: T1) -> T1 {
expr.abs()
}

#[function("abs(int256) -> int256")]
pub fn int256_abs<TRef, T>(expr: TRef) -> Result<T>
where
TRef: Into<T> + Debug,
T: Signed + CheckedNeg + Debug,
T: IsNegative + CheckedNeg + Debug,
{
let expr = expr.into();
if expr.is_negative() {
Expand Down Expand Up @@ -376,7 +380,7 @@ pub fn sqrt_f64(expr: F64) -> Result<F64> {
});
}
// Edge cases: nan, inf, negative zero should return itself.
match expr.is_nan() || expr == f64::INFINITY || expr.is_negative() {
match expr.is_nan() || expr == f64::INFINITY || expr == -0.0 {
true => Ok(expr),
false => Ok(expr.sqrt()),
}
Expand Down Expand Up @@ -405,6 +409,21 @@ pub fn cbrt_f64(expr: F64) -> F64 {
expr.cbrt()
}

#[function("sign(float64) -> float64")]
pub fn sign_f64(input: F64) -> F64 {
match input.0.partial_cmp(&0.) {
Some(std::cmp::Ordering::Less) => (-1).into(),
Some(std::cmp::Ordering::Equal) => 0.into(),
Some(std::cmp::Ordering::Greater) => 1.into(),
None => 0.into(),
}
}

#[function("sign(decimal) -> decimal")]
pub fn sign_dec(input: Decimal) -> Decimal {
input.sign()
}

#[cfg(test)]
mod tests {
use std::str::FromStr;
Expand Down
1 change: 1 addition & 0 deletions src/expr/src/vector_op/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use num_traits::Zero;
use risingwave_common::types::{Decimal, F64};
use risingwave_expr_macro::function;

Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,7 @@ impl Binder {
("radians", raw_call(ExprType::Radians)),
("sqrt", raw_call(ExprType::Sqrt)),
("cbrt", raw_call(ExprType::Cbrt)),
("sign", raw_call(ExprType::Sign)),

(
"to_timestamp",
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/expr/pure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ impl ExprVisitor<bool> for ImpureAnalyzer {
| expr_node::Type::Atan2
| expr_node::Type::Sqrt
| expr_node::Type::Cbrt
| expr_node::Type::Sign
| expr_node::Type::Left
| expr_node::Type::Right
| expr_node::Type::Degrees
Expand Down
2 changes: 1 addition & 1 deletion src/tests/regress/data/sql/float8.sql
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ select ceiling(f1) as ceiling_f1 from float8_tbl f;
select floor(f1) as floor_f1 from float8_tbl f;

-- sign
--@ select sign(f1) as sign_f1 from float8_tbl f;
select sign(f1) as sign_f1 from float8_tbl f;

-- avoid bit-exact output here because operations may not be bit-exact.
SET extra_float_digits = 0;
Expand Down