From d173e5dd7c03e5b59ec229fd27f1d71b7bc7f1b3 Mon Sep 17 00:00:00 2001
From: Runji Wang
Date: Tue, 27 Dec 2022 19:42:14 +0800
Subject: [PATCH] perf(expr): vectorize infallible casts (#7079)

Similar to #7055, this PR vectorizes infallible casts.

perf-infallible-cast
Click to show full results

| bench | Before time(us) | After time(us) | Change(%) | Speedup |
| --- | --- | --- | --- | --- |
| cast(int16->float32) | 4.434 | 0.146 | -96.7% | 29.3 |
| cast(int16->int32) | 4.408 | 0.154 | -96.5% | 27.7 |
| cast(float32->float64) | 4.432 | 0.187 | -95.8% | 22.7 |
| cast(int32->int64) | 4.415 | 0.192 | -95.7% | 22.0 |
| cast(int32->float64) | 4.422 | 0.194 | -95.6% | 21.8 |
| cast(int16->int64) | 4.412 | 0.212 | -95.2% | 19.8 |
| cast(timestamp->date) | 4.409 | 0.226 | -94.9% | 18.5 |
| cast(timestamp->time) | 5.443 | 0.300 | -94.5% | 17.1 |
| cast(date->timestamp) | 5.504 | 0.304 | -94.5% | 17.1 |
| cast(int16->float64) | 4.430 | 0.298 | -93.3% | 13.9 |
| cast(int32->decimal) | 5.582 | 0.592 | -89.4% | 8.4 |
| cast(time->interval) | 5.511 | 0.727 | -86.8% | 6.6 |
| cast(int64->decimal) | 5.739 | 0.766 | -86.7% | 6.5 |
| cast(int16->decimal) | 5.760 | 0.845 | -85.3% | 5.8 |
| cast(interval->time) | 5.903 | 1.289 | -78.2% | 3.6 |
| cast(float32->decimal) | 21.970 | 18.170 | -17.3% | 0.2 |
| cast(float64->decimal) | 40.131 | 36.049 | -10.2% | 0.1 |

Approved-By: BowenXiao1999 --- src/common/src/types/decimal.rs | 6 + src/common/src/types/ordered_float.rs | 19 --- src/expr/src/expr/expr_unary.rs | 19 ++- src/expr/src/vector_op/cast.rs | 184 ++++++++++++++------------ src/expr/src/vector_op/tests.rs | 4 +- 5 files changed, 120 insertions(+), 112 deletions(-) diff --git a/src/common/src/types/decimal.rs b/src/common/src/types/decimal.rs index 8a954a10cc6c9..557ebdbcf828b 100644 --- a/src/common/src/types/decimal.rs +++ b/src/common/src/types/decimal.rs @@ -27,6 +27,7 @@ use super::to_text::ToText; use super::DataType; use crate::array::ArrayResult; use crate::error::Result as RwResult; +use crate::types::ordered_float::OrderedFloat; use crate::types::Decimal::Normalized; #[derive(Debug, Copy, parse_display::Display, Clone, PartialEq, Hash, Eq, Ord, PartialOrd)] @@ -191,6 +192,11 @@ macro_rules! impl_try_from_float { $convert(value).expect("f32/f64 to decimal should not fail") } } + impl core::convert::From> for $to_ty { + fn from(value: OrderedFloat<$from_ty>) -> Self { + $convert(value.0).expect("f32/f64 to decimal should not fail") + } + } }; } diff --git a/src/common/src/types/ordered_float.rs b/src/common/src/types/ordered_float.rs index af94a27588c3a..e8296c3b4e07f 100644 --- a/src/common/src/types/ordered_float.rs +++ b/src/common/src/types/ordered_float.rs @@ -41,7 +41,6 @@ //! Wrappers for total order on Floats. See the [`OrderedFloat`] docs for details. use core::cmp::Ordering; -use core::convert::TryFrom; use core::fmt; use core::hash::{Hash, Hasher}; use core::iter::{Product, Sum}; @@ -1063,24 +1062,6 @@ mod impl_as_primitive { mod impl_from { use super::*; - macro_rules! impl_try_from_for { - ($ty:ty) => { - impl TryFrom> for $ty - where - F: 'static + Float, - Self: TryFrom, - { - type Error = >::Error; - - fn try_from(value: OrderedFloat) -> Result { - TryFrom::try_from(value.0) - } - } - }; - } - - impl_try_from_for!(crate::types::Decimal); - macro_rules! 
impl_from_for { ($ty:ty) => { impl From> for $ty diff --git a/src/expr/src/expr/expr_unary.rs b/src/expr/src/expr/expr_unary.rs index b469fa5200dcf..4595f296831ab 100644 --- a/src/expr/src/expr/expr_unary.rs +++ b/src/expr/src/expr/expr_unary.rs @@ -162,30 +162,37 @@ pub fn new_unary_expr( )), (ProstType::Cast, _, _) => { macro_rules! gen_cast_impl { - ($( { $input:ident, $cast:ident, $func:expr } ),*) => { + ($( { $input:ident, $cast:ident, $func:expr, $infallible:ident } ),*) => { match (child_expr.return_type(), return_type.clone()) { $( - ($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func), + ($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func, $infallible), )* _ => { return Err(ExprError::UnsupportedCast(child_expr.return_type(), return_type)); } } }; - (arm: $input:ident, varchar, $func:expr) => { + (arm: $input:ident, varchar, $func:expr, false) => { UnaryBytesExpression::< $input! { type_array }, _>::new( child_expr, return_type.clone(), $func ).boxed() }; - (arm: $input:ident, $cast:ident, $func:expr) => { + (arm: $input:ident, $cast:ident, $func:expr, false) => { UnaryExpression::< $input! { type_array }, $cast! { type_array }, _>::new( child_expr, return_type.clone(), $func ).boxed() }; + (arm: $input:ident, $cast:ident, $func:expr, true) => { + template_fast::UnaryExpression::new( + child_expr, + return_type.clone(), + $func + ).boxed() + }; } for_all_cast_variants! 
{ gen_cast_impl } @@ -348,12 +355,12 @@ mod tests { use super::super::*; use crate::expr::test_utils::{make_expression, make_input_ref}; - use crate::vector_op::cast::{general_cast, str_parse}; + use crate::vector_op::cast::{str_parse, try_cast}; #[test] fn test_unary() { test_unary_bool::(|x| !x, Type::Not); - test_unary_date::(|x| general_cast(x).unwrap(), Type::Cast); + test_unary_date::(|x| try_cast(x).unwrap(), Type::Cast); test_str_to_int16::(|x| str_parse(x).unwrap()); } diff --git a/src/expr/src/vector_op/cast.rs b/src/expr/src/vector_op/cast.rs index 0e711361abb99..2bc9695ad6cd7 100644 --- a/src/expr/src/vector_op/cast.rs +++ b/src/expr/src/vector_op/cast.rs @@ -321,29 +321,27 @@ pub fn dec_to_i64(elem: Decimal) -> Result { /// In `PostgreSQL`, casting from timestamp to date discards the time part. #[inline(always)] -pub fn timestamp_to_date(elem: NaiveDateTimeWrapper) -> Result { - Ok(NaiveDateWrapper(elem.0.date())) +pub fn timestamp_to_date(elem: NaiveDateTimeWrapper) -> NaiveDateWrapper { + NaiveDateWrapper(elem.0.date()) } /// In `PostgreSQL`, casting from timestamp to time discards the date part. #[inline(always)] -pub fn timestamp_to_time(elem: NaiveDateTimeWrapper) -> Result { - Ok(NaiveTimeWrapper(elem.0.time())) +pub fn timestamp_to_time(elem: NaiveDateTimeWrapper) -> NaiveTimeWrapper { + NaiveTimeWrapper(elem.0.time()) } /// In `PostgreSQL`, casting from interval to time discards the days part. 
#[inline(always)] -pub fn interval_to_time(elem: IntervalUnit) -> Result { +pub fn interval_to_time(elem: IntervalUnit) -> NaiveTimeWrapper { let ms = elem.get_ms_of_day(); let secs = (ms / 1000) as u32; let nano = (ms % 1000 * 1_000_000) as u32; - Ok(NaiveTimeWrapper::from_num_seconds_from_midnight_uncheck( - secs, nano, - )) + NaiveTimeWrapper::from_num_seconds_from_midnight_uncheck(secs, nano) } #[inline(always)] -pub fn general_cast(elem: T1) -> Result +pub fn try_cast(elem: T1) -> Result where T1: TryInto + std::fmt::Debug + Copy, >::Error: std::fmt::Display, @@ -352,6 +350,14 @@ where .map_err(|_| ExprError::CastOutOfRange(std::any::type_name::())) } +#[inline(always)] +pub fn cast(elem: T1) -> T2 +where + T1: Into, +{ + elem.into() +} + #[inline(always)] pub fn str_to_bool(input: &str) -> Result { let trimmed_input = input.trim(); @@ -401,80 +407,81 @@ pub fn bool_out(input: bool, writer: &mut dyn Write) -> Result<()> { /// * `$cast`: The cast type in that the operation will calculate /// * `$func`: The scalar function for expression, it's a generic function and specialized by the /// type of `$input, $cast` +/// * `$infallible`: Whether the cast is infallible #[macro_export] macro_rules! for_all_cast_variants { ($macro:ident) => { $macro! 
{ - { varchar, date, str_to_date }, - { varchar, time, str_to_time }, - { varchar, interval, str_parse }, - { varchar, timestamp, str_to_timestamp }, - { varchar, timestampz, str_to_timestampz }, - { varchar, int16, str_parse }, - { varchar, int32, str_parse }, - { varchar, int64, str_parse }, - { varchar, float32, str_parse }, - { varchar, float64, str_parse }, - { varchar, decimal, str_parse }, - { varchar, boolean, str_to_bool }, - { varchar, bytea, str_to_bytea }, + { varchar, date, str_to_date, false }, + { varchar, time, str_to_time, false }, + { varchar, interval, str_parse, false }, + { varchar, timestamp, str_to_timestamp, false }, + { varchar, timestampz, str_to_timestampz, false }, + { varchar, int16, str_parse, false }, + { varchar, int32, str_parse, false }, + { varchar, int64, str_parse, false }, + { varchar, float32, str_parse, false }, + { varchar, float64, str_parse, false }, + { varchar, decimal, str_parse, false }, + { varchar, boolean, str_to_bool, false }, + { varchar, bytea, str_to_bytea, false }, // `str_to_list` requires `target_elem_type` and is handled elsewhere - { boolean, varchar, bool_to_varchar }, - { int16, varchar, general_to_text }, - { int32, varchar, general_to_text }, - { int64, varchar, general_to_text }, - { float32, varchar, general_to_text }, - { float64, varchar, general_to_text }, - { decimal, varchar, general_to_text }, - { time, varchar, general_to_text }, - { interval, varchar, general_to_text }, - { date, varchar, general_to_text }, - { timestamp, varchar, general_to_text }, - { timestampz, varchar, timestampz_to_utc_string }, - { list, varchar, |x, w| general_to_text(x, w) }, - - { boolean, int32, general_cast }, - { int32, boolean, int32_to_bool }, - - { int16, int32, general_cast }, - { int16, int64, general_cast }, - { int16, float32, general_cast }, - { int16, float64, general_cast }, - { int16, decimal, general_cast }, - { int32, int16, general_cast }, - { int32, int64, general_cast }, - { int32, float32, to_f32 
}, // lossy - { int32, float64, general_cast }, - { int32, decimal, general_cast }, - { int64, int16, general_cast }, - { int64, int32, general_cast }, - { int64, float32, to_f32 }, // lossy - { int64, float64, to_f64 }, // lossy - { int64, decimal, general_cast }, - - { float32, float64, general_cast }, - { float32, decimal, general_cast }, - { float32, int16, to_i16 }, - { float32, int32, to_i32 }, - { float32, int64, to_i64 }, - { float64, decimal, general_cast }, - { float64, int16, to_i16 }, - { float64, int32, to_i32 }, - { float64, int64, to_i64 }, - { float64, float32, to_f32 }, // lossy - - { decimal, int16, dec_to_i16 }, - { decimal, int32, dec_to_i32 }, - { decimal, int64, dec_to_i64 }, - { decimal, float32, to_f32 }, - { decimal, float64, to_f64 }, - - { date, timestamp, general_cast }, - { time, interval, general_cast }, - { timestamp, date, timestamp_to_date }, - { timestamp, time, timestamp_to_time }, - { interval, time, interval_to_time } + { boolean, varchar, bool_to_varchar, false }, + { int16, varchar, general_to_text, false }, + { int32, varchar, general_to_text, false }, + { int64, varchar, general_to_text, false }, + { float32, varchar, general_to_text, false }, + { float64, varchar, general_to_text, false }, + { decimal, varchar, general_to_text, false }, + { time, varchar, general_to_text, false }, + { interval, varchar, general_to_text, false }, + { date, varchar, general_to_text, false }, + { timestamp, varchar, general_to_text, false }, + { timestampz, varchar, timestampz_to_utc_string, false }, + { list, varchar, |x, w| general_to_text(x, w), false }, + + { boolean, int32, try_cast, false }, + { int32, boolean, int32_to_bool, false }, + + { int16, int32, cast::, true }, + { int16, int64, cast::, true }, + { int16, float32, cast::, true }, + { int16, float64, cast::, true }, + { int16, decimal, cast::, true }, + { int32, int16, try_cast, false }, + { int32, int64, cast::, true }, + { int32, float32, to_f32, false }, // lossy + { int32, 
float64, cast::, true }, + { int32, decimal, cast::, true }, + { int64, int16, try_cast, false }, + { int64, int32, try_cast, false }, + { int64, float32, to_f32, false }, // lossy + { int64, float64, to_f64, false }, // lossy + { int64, decimal, cast::, true }, + + { float32, float64, cast::, true }, + { float32, decimal, cast::, true }, + { float32, int16, to_i16, false }, + { float32, int32, to_i32, false }, + { float32, int64, to_i64, false }, + { float64, decimal, cast::, true }, + { float64, int16, to_i16, false }, + { float64, int32, to_i32, false }, + { float64, int64, to_i64, false }, + { float64, float32, to_f32, false }, // lossy + + { decimal, int16, dec_to_i16, false }, + { decimal, int32, dec_to_i32, false }, + { decimal, int64, dec_to_i64, false }, + { decimal, float32, to_f32, false }, + { decimal, float64, to_f64, false }, + + { date, timestamp, cast::, true }, + { time, interval, cast::, true }, + { timestamp, date, timestamp_to_date, true }, + { timestamp, time, timestamp_to_time, true }, + { interval, time, interval_to_time, true } } }; } @@ -624,17 +631,17 @@ fn scalar_cast( ) => str_to_list(source.try_into()?, target_elem_type).map(Scalar::to_scalar_value), (source_type, target_type) => { macro_rules! gen_cast_impl { - ($( { $input:ident, $cast:ident, $func:expr } ),*) => { + ($( { $input:ident, $cast:ident, $func:expr, $infallible:ident } ),*) => { match (source_type, target_type) { $( - ($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func), + ($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func, $infallible), )* _ => { return Err(ExprError::UnsupportedCast(source_type.clone(), target_type.clone())); } } }; - (arm: $input:ident, varchar, $func:expr) => { + (arm: $input:ident, varchar, $func:expr, false) => { { let source: <$input! 
{ type_array } as Array>::RefItem<'_> = source.try_into()?; let mut writer = String::new(); @@ -642,13 +649,20 @@ fn scalar_cast( target.map(|_| Scalar::to_scalar_value(writer.into_boxed_str())) } }; - (arm: $input:ident, $cast:ident, $func:expr) => { + (arm: $input:ident, $cast:ident, $func:expr, false) => { { let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?; let target: Result<<$cast! { type_array } as Array>::OwnedItem> = $func(source); target.map(Scalar::to_scalar_value) } }; + (arm: $input:ident, $cast:ident, $func:expr, true) => { + { + let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?; + let target: Result<<$cast! { type_array } as Array>::OwnedItem> = Ok($func(source)); + target.map(Scalar::to_scalar_value) + } + }; } for_all_cast_variants!(gen_cast_impl) } @@ -749,19 +763,19 @@ mod tests { #[test] fn temporal_cast() { assert_eq!( - timestamp_to_date(str_to_timestamp("1999-01-08 04:02").unwrap()).unwrap(), + timestamp_to_date(str_to_timestamp("1999-01-08 04:02").unwrap()), str_to_date("1999-01-08").unwrap(), ); assert_eq!( - timestamp_to_time(str_to_timestamp("1999-01-08 04:02").unwrap()).unwrap(), + timestamp_to_time(str_to_timestamp("1999-01-08 04:02").unwrap()), str_to_time("04:02").unwrap(), ); assert_eq!( - interval_to_time(IntervalUnit::new(1, 2, 61003)).unwrap(), + interval_to_time(IntervalUnit::new(1, 2, 61003)), str_to_time("00:01:01.003").unwrap(), ); assert_eq!( - interval_to_time(IntervalUnit::new(0, 0, -61003)).unwrap(), + interval_to_time(IntervalUnit::new(0, 0, -61003)), str_to_time("23:58:58.997").unwrap(), ); } diff --git a/src/expr/src/vector_op/tests.rs b/src/expr/src/vector_op/tests.rs index d70215872a926..0a52cd43352fc 100644 --- a/src/expr/src/vector_op/tests.rs +++ b/src/expr/src/vector_op/tests.rs @@ -22,7 +22,7 @@ use risingwave_common::types::{ use crate::vector_op::arithmetic_op::*; use crate::vector_op::bitwise_op::*; -use crate::vector_op::cast::general_cast; +use 
crate::vector_op::cast::try_cast; use crate::vector_op::cmp::*; use crate::vector_op::conjunction::*; use crate::ExprError; @@ -239,7 +239,7 @@ fn test_conjunction() { #[test] fn test_cast() { assert_eq!( - general_cast::<_, NaiveDateTimeWrapper>(NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1)) + try_cast::<_, NaiveDateTimeWrapper>(NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1)) .unwrap(), NaiveDateTimeWrapper::new( NaiveDateTime::parse_from_str("1994-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()