Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(expr): vectorize infallible casts #7079

Merged
merged 1 commit into from
Dec 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/common/src/types/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use super::to_text::ToText;
use super::DataType;
use crate::array::ArrayResult;
use crate::error::Result as RwResult;
use crate::types::ordered_float::OrderedFloat;
use crate::types::Decimal::Normalized;

#[derive(Debug, Copy, parse_display::Display, Clone, PartialEq, Hash, Eq, Ord, PartialOrd)]
Expand Down Expand Up @@ -191,6 +192,11 @@ macro_rules! impl_try_from_float {
$convert(value).expect("f32/f64 to decimal should not fail")
}
}
impl core::convert::From<OrderedFloat<$from_ty>> for $to_ty {
fn from(value: OrderedFloat<$from_ty>) -> Self {
$convert(value.0).expect("f32/f64 to decimal should not fail")
}
}
};
}

Expand Down
19 changes: 0 additions & 19 deletions src/common/src/types/ordered_float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
//! Wrappers for total order on Floats. See the [`OrderedFloat`] docs for details.
use core::cmp::Ordering;
use core::convert::TryFrom;
use core::fmt;
use core::hash::{Hash, Hasher};
use core::iter::{Product, Sum};
Expand Down Expand Up @@ -1063,24 +1062,6 @@ mod impl_as_primitive {
mod impl_from {
use super::*;

macro_rules! impl_try_from_for {
($ty:ty) => {
impl<F> TryFrom<OrderedFloat<F>> for $ty
where
F: 'static + Float,
Self: TryFrom<F>,
{
type Error = <Self as TryFrom<F>>::Error;

fn try_from(value: OrderedFloat<F>) -> Result<Self, Self::Error> {
TryFrom::try_from(value.0)
}
}
};
}

impl_try_from_for!(crate::types::Decimal);

macro_rules! impl_from_for {
($ty:ty) => {
impl<F> From<OrderedFloat<F>> for $ty
Expand Down
19 changes: 13 additions & 6 deletions src/expr/src/expr/expr_unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,30 +162,37 @@ pub fn new_unary_expr(
)),
(ProstType::Cast, _, _) => {
macro_rules! gen_cast_impl {
($( { $input:ident, $cast:ident, $func:expr } ),*) => {
($( { $input:ident, $cast:ident, $func:expr, $infallible:ident } ),*) => {
match (child_expr.return_type(), return_type.clone()) {
$(
($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func),
($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func, $infallible),
)*
_ => {
return Err(ExprError::UnsupportedCast(child_expr.return_type(), return_type));
}
}
};
(arm: $input:ident, varchar, $func:expr) => {
(arm: $input:ident, varchar, $func:expr, false) => {
UnaryBytesExpression::< $input! { type_array }, _>::new(
child_expr,
return_type.clone(),
$func
).boxed()
};
(arm: $input:ident, $cast:ident, $func:expr) => {
(arm: $input:ident, $cast:ident, $func:expr, false) => {
UnaryExpression::< $input! { type_array }, $cast! { type_array }, _>::new(
child_expr,
return_type.clone(),
$func
).boxed()
};
(arm: $input:ident, $cast:ident, $func:expr, true) => {
template_fast::UnaryExpression::new(
child_expr,
return_type.clone(),
$func
).boxed()
};
}

for_all_cast_variants! { gen_cast_impl }
Expand Down Expand Up @@ -348,12 +355,12 @@ mod tests {

use super::super::*;
use crate::expr::test_utils::{make_expression, make_input_ref};
use crate::vector_op::cast::{general_cast, str_parse};
use crate::vector_op::cast::{str_parse, try_cast};

#[test]
fn test_unary() {
test_unary_bool::<BoolArray, _>(|x| !x, Type::Not);
test_unary_date::<NaiveDateTimeArray, _>(|x| general_cast(x).unwrap(), Type::Cast);
test_unary_date::<NaiveDateTimeArray, _>(|x| try_cast(x).unwrap(), Type::Cast);
test_str_to_int16::<I16Array, _>(|x| str_parse(x).unwrap());
}

Expand Down
184 changes: 99 additions & 85 deletions src/expr/src/vector_op/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,29 +323,27 @@ pub fn dec_to_i64(elem: Decimal) -> Result<i64> {

/// In `PostgreSQL`, casting from timestamp to date discards the time part.
#[inline(always)]
pub fn timestamp_to_date(elem: NaiveDateTimeWrapper) -> Result<NaiveDateWrapper> {
Ok(NaiveDateWrapper(elem.0.date()))
pub fn timestamp_to_date(elem: NaiveDateTimeWrapper) -> NaiveDateWrapper {
NaiveDateWrapper(elem.0.date())
}

/// In `PostgreSQL`, casting from timestamp to time discards the date part.
#[inline(always)]
pub fn timestamp_to_time(elem: NaiveDateTimeWrapper) -> Result<NaiveTimeWrapper> {
Ok(NaiveTimeWrapper(elem.0.time()))
pub fn timestamp_to_time(elem: NaiveDateTimeWrapper) -> NaiveTimeWrapper {
NaiveTimeWrapper(elem.0.time())
}

/// In `PostgreSQL`, casting from interval to time discards the days part.
#[inline(always)]
pub fn interval_to_time(elem: IntervalUnit) -> Result<NaiveTimeWrapper> {
pub fn interval_to_time(elem: IntervalUnit) -> NaiveTimeWrapper {
let ms = elem.get_ms_of_day();
let secs = (ms / 1000) as u32;
let nano = (ms % 1000 * 1_000_000) as u32;
Ok(NaiveTimeWrapper::from_num_seconds_from_midnight_uncheck(
secs, nano,
))
NaiveTimeWrapper::from_num_seconds_from_midnight_uncheck(secs, nano)
}

#[inline(always)]
pub fn general_cast<T1, T2>(elem: T1) -> Result<T2>
pub fn try_cast<T1, T2>(elem: T1) -> Result<T2>
where
T1: TryInto<T2> + std::fmt::Debug + Copy,
<T1 as TryInto<T2>>::Error: std::fmt::Display,
Expand All @@ -354,6 +352,14 @@ where
.map_err(|_| ExprError::CastOutOfRange(std::any::type_name::<T2>()))
}

/// Converts `elem` into the target type `T2` through `T1`'s [`Into`] implementation.
///
/// The conversion cannot fail, so the converted value is returned directly instead of
/// being wrapped in a `Result`.
#[inline(always)]
pub fn cast<T1, T2>(elem: T1) -> T2
where
    T1: Into<T2>,
{
    // Fully-qualified call makes the source/target pair explicit at the call site.
    <T1 as Into<T2>>::into(elem)
}

#[inline(always)]
pub fn str_to_bool(input: &str) -> Result<bool> {
let trimmed_input = input.trim();
Expand Down Expand Up @@ -403,80 +409,81 @@ pub fn bool_out(input: bool, writer: &mut dyn Write) -> Result<()> {
/// * `$cast`: The target type that the operation casts to
/// * `$func`: The scalar function for the expression; it is a generic function specialized by the
/// types of `$input` and `$cast`
/// * `$infallible`: Whether the cast is infallible
#[macro_export]
macro_rules! for_all_cast_variants {
($macro:ident) => {
$macro! {
{ varchar, date, str_to_date },
{ varchar, time, str_to_time },
{ varchar, interval, str_parse },
{ varchar, timestamp, str_to_timestamp },
{ varchar, timestampz, str_to_timestampz },
{ varchar, int16, str_parse },
{ varchar, int32, str_parse },
{ varchar, int64, str_parse },
{ varchar, float32, str_parse },
{ varchar, float64, str_parse },
{ varchar, decimal, str_parse },
{ varchar, boolean, str_to_bool },
{ varchar, bytea, str_to_bytea },
{ varchar, date, str_to_date, false },
{ varchar, time, str_to_time, false },
{ varchar, interval, str_parse, false },
{ varchar, timestamp, str_to_timestamp, false },
{ varchar, timestampz, str_to_timestampz, false },
{ varchar, int16, str_parse, false },
{ varchar, int32, str_parse, false },
{ varchar, int64, str_parse, false },
{ varchar, float32, str_parse, false },
{ varchar, float64, str_parse, false },
{ varchar, decimal, str_parse, false },
{ varchar, boolean, str_to_bool, false },
{ varchar, bytea, str_to_bytea, false },
// `str_to_list` requires `target_elem_type` and is handled elsewhere

{ boolean, varchar, bool_to_varchar },
{ int16, varchar, general_to_text },
{ int32, varchar, general_to_text },
{ int64, varchar, general_to_text },
{ float32, varchar, general_to_text },
{ float64, varchar, general_to_text },
{ decimal, varchar, general_to_text },
{ time, varchar, general_to_text },
{ interval, varchar, general_to_text },
{ date, varchar, general_to_text },
{ timestamp, varchar, general_to_text },
{ timestampz, varchar, timestampz_to_utc_string },
{ list, varchar, |x, w| general_to_text(x, w) },

{ boolean, int32, general_cast },
{ int32, boolean, int32_to_bool },

{ int16, int32, general_cast },
{ int16, int64, general_cast },
{ int16, float32, general_cast },
{ int16, float64, general_cast },
{ int16, decimal, general_cast },
{ int32, int16, general_cast },
{ int32, int64, general_cast },
{ int32, float32, to_f32 }, // lossy
{ int32, float64, general_cast },
{ int32, decimal, general_cast },
{ int64, int16, general_cast },
{ int64, int32, general_cast },
{ int64, float32, to_f32 }, // lossy
{ int64, float64, to_f64 }, // lossy
{ int64, decimal, general_cast },

{ float32, float64, general_cast },
{ float32, decimal, general_cast },
{ float32, int16, to_i16 },
{ float32, int32, to_i32 },
{ float32, int64, to_i64 },
{ float64, decimal, general_cast },
{ float64, int16, to_i16 },
{ float64, int32, to_i32 },
{ float64, int64, to_i64 },
{ float64, float32, to_f32 }, // lossy

{ decimal, int16, dec_to_i16 },
{ decimal, int32, dec_to_i32 },
{ decimal, int64, dec_to_i64 },
{ decimal, float32, to_f32 },
{ decimal, float64, to_f64 },

{ date, timestamp, general_cast },
{ time, interval, general_cast },
{ timestamp, date, timestamp_to_date },
{ timestamp, time, timestamp_to_time },
{ interval, time, interval_to_time }
{ boolean, varchar, bool_to_varchar, false },
{ int16, varchar, general_to_text, false },
{ int32, varchar, general_to_text, false },
{ int64, varchar, general_to_text, false },
{ float32, varchar, general_to_text, false },
{ float64, varchar, general_to_text, false },
{ decimal, varchar, general_to_text, false },
{ time, varchar, general_to_text, false },
{ interval, varchar, general_to_text, false },
{ date, varchar, general_to_text, false },
{ timestamp, varchar, general_to_text, false },
{ timestampz, varchar, timestampz_to_utc_string, false },
{ list, varchar, |x, w| general_to_text(x, w), false },

{ boolean, int32, try_cast, false },
{ int32, boolean, int32_to_bool, false },

{ int16, int32, cast::<i16, i32>, true },
{ int16, int64, cast::<i16, i64>, true },
{ int16, float32, cast::<i16, OrderedF32>, true },
{ int16, float64, cast::<i16, OrderedF64>, true },
{ int16, decimal, cast::<i16, Decimal>, true },
{ int32, int16, try_cast, false },
{ int32, int64, cast::<i32, i64>, true },
{ int32, float32, to_f32, false }, // lossy
{ int32, float64, cast::<i32, OrderedF64>, true },
{ int32, decimal, cast::<i32, Decimal>, true },
{ int64, int16, try_cast, false },
{ int64, int32, try_cast, false },
{ int64, float32, to_f32, false }, // lossy
{ int64, float64, to_f64, false }, // lossy
{ int64, decimal, cast::<i64, Decimal>, true },

{ float32, float64, cast::<OrderedF32, OrderedF64>, true },
{ float32, decimal, cast::<OrderedF32, Decimal>, true },
{ float32, int16, to_i16, false },
{ float32, int32, to_i32, false },
{ float32, int64, to_i64, false },
{ float64, decimal, cast::<OrderedF64, Decimal>, true },
{ float64, int16, to_i16, false },
{ float64, int32, to_i32, false },
{ float64, int64, to_i64, false },
{ float64, float32, to_f32, false }, // lossy

{ decimal, int16, dec_to_i16, false },
{ decimal, int32, dec_to_i32, false },
{ decimal, int64, dec_to_i64, false },
{ decimal, float32, to_f32, false },
{ decimal, float64, to_f64, false },

{ date, timestamp, cast::<NaiveDateWrapper, NaiveDateTimeWrapper>, true },
{ time, interval, cast::<NaiveTimeWrapper, IntervalUnit>, true },
{ timestamp, date, timestamp_to_date, true },
{ timestamp, time, timestamp_to_time, true },
{ interval, time, interval_to_time, true }
}
};
}
Expand Down Expand Up @@ -632,31 +639,38 @@ fn scalar_cast(
) => str_to_list(source.try_into()?, target_elem_type).map(Scalar::to_scalar_value),
(source_type, target_type) => {
macro_rules! gen_cast_impl {
($( { $input:ident, $cast:ident, $func:expr } ),*) => {
($( { $input:ident, $cast:ident, $func:expr, $infallible:ident } ),*) => {
match (source_type, target_type) {
$(
($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func),
($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func, $infallible),
)*
_ => {
return Err(ExprError::UnsupportedCast(source_type.clone(), target_type.clone()));
}
}
};
(arm: $input:ident, varchar, $func:expr) => {
(arm: $input:ident, varchar, $func:expr, false) => {
{
let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?;
let mut writer = String::new();
let target: Result<()> = $func(source, &mut writer);
target.map(|_| Scalar::to_scalar_value(writer.into_boxed_str()))
}
};
(arm: $input:ident, $cast:ident, $func:expr) => {
(arm: $input:ident, $cast:ident, $func:expr, false) => {
{
let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?;
let target: Result<<$cast! { type_array } as Array>::OwnedItem> = $func(source);
target.map(Scalar::to_scalar_value)
}
};
(arm: $input:ident, $cast:ident, $func:expr, true) => {
{
let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?;
let target: Result<<$cast! { type_array } as Array>::OwnedItem> = Ok($func(source));
target.map(Scalar::to_scalar_value)
}
};
}
for_all_cast_variants!(gen_cast_impl)
}
Expand Down Expand Up @@ -757,19 +771,19 @@ mod tests {
#[test]
fn temporal_cast() {
assert_eq!(
timestamp_to_date(str_to_timestamp("1999-01-08 04:02").unwrap()).unwrap(),
timestamp_to_date(str_to_timestamp("1999-01-08 04:02").unwrap()),
str_to_date("1999-01-08").unwrap(),
);
assert_eq!(
timestamp_to_time(str_to_timestamp("1999-01-08 04:02").unwrap()).unwrap(),
timestamp_to_time(str_to_timestamp("1999-01-08 04:02").unwrap()),
str_to_time("04:02").unwrap(),
);
assert_eq!(
interval_to_time(IntervalUnit::new(1, 2, 61003)).unwrap(),
interval_to_time(IntervalUnit::new(1, 2, 61003)),
str_to_time("00:01:01.003").unwrap(),
);
assert_eq!(
interval_to_time(IntervalUnit::new(0, 0, -61003)).unwrap(),
interval_to_time(IntervalUnit::new(0, 0, -61003)),
str_to_time("23:58:58.997").unwrap(),
);
}
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/vector_op/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use risingwave_common::types::{

use crate::vector_op::arithmetic_op::*;
use crate::vector_op::bitwise_op::*;
use crate::vector_op::cast::general_cast;
use crate::vector_op::cast::try_cast;
use crate::vector_op::cmp::*;
use crate::vector_op::conjunction::*;
use crate::ExprError;
Expand Down Expand Up @@ -239,7 +239,7 @@ fn test_conjunction() {
#[test]
fn test_cast() {
assert_eq!(
general_cast::<_, NaiveDateTimeWrapper>(NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1))
try_cast::<_, NaiveDateTimeWrapper>(NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1))
.unwrap(),
NaiveDateTimeWrapper::new(
NaiveDateTime::parse_from_str("1994-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
Expand Down