Skip to content

Commit

Permalink
optimize infallible cast
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <wangrunji0408@163.com>
  • Loading branch information
wangrunji0408 committed Dec 27, 2022
1 parent 064ff7c commit 91d35f1
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 112 deletions.
6 changes: 6 additions & 0 deletions src/common/src/types/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use super::to_text::ToText;
use super::DataType;
use crate::array::ArrayResult;
use crate::error::Result as RwResult;
use crate::types::ordered_float::OrderedFloat;
use crate::types::Decimal::Normalized;

#[derive(Debug, Copy, parse_display::Display, Clone, PartialEq, Hash, Eq, Ord, PartialOrd)]
Expand Down Expand Up @@ -191,6 +192,11 @@ macro_rules! impl_try_from_float {
$convert(value).expect("f32/f64 to decimal should not fail")
}
}
// `From` conversion from an `OrderedFloat`-wrapped `$from_ty` into `$to_ty`.
impl core::convert::From<OrderedFloat<$from_ty>> for $to_ty {
// Takes the inner float out of the wrapper (`value.0`), hands it to `$convert`,
// and panics through `expect` if `$convert` does not yield a value.
fn from(value: OrderedFloat<$from_ty>) -> Self {
$convert(value.0).expect("f32/f64 to decimal should not fail")
}
}
};
}

Expand Down
19 changes: 0 additions & 19 deletions src/common/src/types/ordered_float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
//! Wrappers for total order on Floats. See the [`OrderedFloat`] docs for details.
use core::cmp::Ordering;
use core::convert::TryFrom;
use core::fmt;
use core::hash::{Hash, Hasher};
use core::iter::{Product, Sum};
Expand Down Expand Up @@ -1063,24 +1062,6 @@ mod impl_as_primitive {
mod impl_from {
use super::*;

/// Generates `TryFrom<OrderedFloat<F>>` for the given target type by
/// forwarding to the target's existing `TryFrom<F>` implementation on the
/// wrapped float, reusing that implementation's error type.
macro_rules! impl_try_from_for {
    ($target:ty) => {
        impl<F: 'static + Float> TryFrom<OrderedFloat<F>> for $target
        where
            Self: TryFrom<F>,
        {
            type Error = <Self as TryFrom<F>>::Error;

            fn try_from(value: OrderedFloat<F>) -> Result<Self, Self::Error> {
                // Strip the ordering wrapper and delegate to the plain-float conversion.
                let inner = value.0;
                <Self as TryFrom<F>>::try_from(inner)
            }
        }
    };
}

impl_try_from_for!(crate::types::Decimal);

macro_rules! impl_from_for {
($ty:ty) => {
impl<F> From<OrderedFloat<F>> for $ty
Expand Down
19 changes: 13 additions & 6 deletions src/expr/src/expr/expr_unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,30 +162,37 @@ pub fn new_unary_expr(
)),
(ProstType::Cast, _, _) => {
macro_rules! gen_cast_impl {
($( { $input:ident, $cast:ident, $func:expr } ),*) => {
($( { $input:ident, $cast:ident, $func:expr, $infallible:ident } ),*) => {
match (child_expr.return_type(), return_type.clone()) {
$(
($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func),
($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func, $infallible),
)*
_ => {
return Err(ExprError::UnsupportedCast(child_expr.return_type(), return_type));
}
}
};
(arm: $input:ident, varchar, $func:expr) => {
(arm: $input:ident, varchar, $func:expr, false) => {
UnaryBytesExpression::< $input! { type_array }, _>::new(
child_expr,
return_type.clone(),
$func
).boxed()
};
(arm: $input:ident, $cast:ident, $func:expr) => {
(arm: $input:ident, $cast:ident, $func:expr, false) => {
UnaryExpression::< $input! { type_array }, $cast! { type_array }, _>::new(
child_expr,
return_type.clone(),
$func
).boxed()
};
(arm: $input:ident, $cast:ident, $func:expr, true) => {
template_fast::UnaryExpression::new(
child_expr,
return_type.clone(),
$func
).boxed()
};
}

for_all_cast_variants! { gen_cast_impl }
Expand Down Expand Up @@ -348,12 +355,12 @@ mod tests {

use super::super::*;
use crate::expr::test_utils::{make_expression, make_input_ref};
use crate::vector_op::cast::{general_cast, str_parse};
use crate::vector_op::cast::{str_parse, try_cast};

#[test]
fn test_unary() {
test_unary_bool::<BoolArray, _>(|x| !x, Type::Not);
test_unary_date::<NaiveDateTimeArray, _>(|x| general_cast(x).unwrap(), Type::Cast);
test_unary_date::<NaiveDateTimeArray, _>(|x| try_cast(x).unwrap(), Type::Cast);
test_str_to_int16::<I16Array, _>(|x| str_parse(x).unwrap());
}

Expand Down
184 changes: 99 additions & 85 deletions src/expr/src/vector_op/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,29 +323,27 @@ pub fn dec_to_i64(elem: Decimal) -> Result<i64> {

/// In `PostgreSQL`, casting from timestamp to date discards the time part.
#[inline(always)]
pub fn timestamp_to_date(elem: NaiveDateTimeWrapper) -> Result<NaiveDateWrapper> {
Ok(NaiveDateWrapper(elem.0.date()))
pub fn timestamp_to_date(elem: NaiveDateTimeWrapper) -> NaiveDateWrapper {
NaiveDateWrapper(elem.0.date())
}

/// In `PostgreSQL`, casting from timestamp to time discards the date part.
#[inline(always)]
pub fn timestamp_to_time(elem: NaiveDateTimeWrapper) -> Result<NaiveTimeWrapper> {
Ok(NaiveTimeWrapper(elem.0.time()))
pub fn timestamp_to_time(elem: NaiveDateTimeWrapper) -> NaiveTimeWrapper {
NaiveTimeWrapper(elem.0.time())
}

/// In `PostgreSQL`, casting from interval to time discards the days part.
#[inline(always)]
pub fn interval_to_time(elem: IntervalUnit) -> Result<NaiveTimeWrapper> {
pub fn interval_to_time(elem: IntervalUnit) -> NaiveTimeWrapper {
let ms = elem.get_ms_of_day();
let secs = (ms / 1000) as u32;
let nano = (ms % 1000 * 1_000_000) as u32;
Ok(NaiveTimeWrapper::from_num_seconds_from_midnight_uncheck(
secs, nano,
))
NaiveTimeWrapper::from_num_seconds_from_midnight_uncheck(secs, nano)
}

#[inline(always)]
pub fn general_cast<T1, T2>(elem: T1) -> Result<T2>
pub fn try_cast<T1, T2>(elem: T1) -> Result<T2>
where
T1: TryInto<T2> + std::fmt::Debug + Copy,
<T1 as TryInto<T2>>::Error: std::fmt::Display,
Expand All @@ -354,6 +352,14 @@ where
.map_err(|_| ExprError::CastOutOfRange(std::any::type_name::<T2>()))
}

/// Converts `elem` into the target type `T2` through `T1`'s [`Into`]
/// implementation.
///
/// The converted value is returned directly; there is no `Result` wrapper
/// because `Into` conversions have no error path.
#[inline(always)]
pub fn cast<T1: Into<T2>, T2>(elem: T1) -> T2 {
    <T1 as Into<T2>>::into(elem)
}

#[inline(always)]
pub fn str_to_bool(input: &str) -> Result<bool> {
let trimmed_input = input.trim();
Expand Down Expand Up @@ -403,80 +409,81 @@ pub fn bool_out(input: bool, writer: &mut dyn Write) -> Result<()> {
/// * `$cast`: The target type that the operation casts to
/// * `$func`: The scalar function for expression, it's a generic function and specialized by the
/// type of `$input, $cast`
/// * `$infallible`: Whether the cast is infallible
#[macro_export]
macro_rules! for_all_cast_variants {
($macro:ident) => {
$macro! {
{ varchar, date, str_to_date },
{ varchar, time, str_to_time },
{ varchar, interval, str_parse },
{ varchar, timestamp, str_to_timestamp },
{ varchar, timestampz, str_to_timestampz },
{ varchar, int16, str_parse },
{ varchar, int32, str_parse },
{ varchar, int64, str_parse },
{ varchar, float32, str_parse },
{ varchar, float64, str_parse },
{ varchar, decimal, str_parse },
{ varchar, boolean, str_to_bool },
{ varchar, bytea, str_to_bytea },
{ varchar, date, str_to_date, false },
{ varchar, time, str_to_time, false },
{ varchar, interval, str_parse, false },
{ varchar, timestamp, str_to_timestamp, false },
{ varchar, timestampz, str_to_timestampz, false },
{ varchar, int16, str_parse, false },
{ varchar, int32, str_parse, false },
{ varchar, int64, str_parse, false },
{ varchar, float32, str_parse, false },
{ varchar, float64, str_parse, false },
{ varchar, decimal, str_parse, false },
{ varchar, boolean, str_to_bool, false },
{ varchar, bytea, str_to_bytea, false },
// `str_to_list` requires `target_elem_type` and is handled elsewhere

{ boolean, varchar, bool_to_varchar },
{ int16, varchar, general_to_text },
{ int32, varchar, general_to_text },
{ int64, varchar, general_to_text },
{ float32, varchar, general_to_text },
{ float64, varchar, general_to_text },
{ decimal, varchar, general_to_text },
{ time, varchar, general_to_text },
{ interval, varchar, general_to_text },
{ date, varchar, general_to_text },
{ timestamp, varchar, general_to_text },
{ timestampz, varchar, timestampz_to_utc_string },
{ list, varchar, |x, w| general_to_text(x, w) },

{ boolean, int32, general_cast },
{ int32, boolean, int32_to_bool },

{ int16, int32, general_cast },
{ int16, int64, general_cast },
{ int16, float32, general_cast },
{ int16, float64, general_cast },
{ int16, decimal, general_cast },
{ int32, int16, general_cast },
{ int32, int64, general_cast },
{ int32, float32, to_f32 }, // lossy
{ int32, float64, general_cast },
{ int32, decimal, general_cast },
{ int64, int16, general_cast },
{ int64, int32, general_cast },
{ int64, float32, to_f32 }, // lossy
{ int64, float64, to_f64 }, // lossy
{ int64, decimal, general_cast },

{ float32, float64, general_cast },
{ float32, decimal, general_cast },
{ float32, int16, to_i16 },
{ float32, int32, to_i32 },
{ float32, int64, to_i64 },
{ float64, decimal, general_cast },
{ float64, int16, to_i16 },
{ float64, int32, to_i32 },
{ float64, int64, to_i64 },
{ float64, float32, to_f32 }, // lossy

{ decimal, int16, dec_to_i16 },
{ decimal, int32, dec_to_i32 },
{ decimal, int64, dec_to_i64 },
{ decimal, float32, to_f32 },
{ decimal, float64, to_f64 },

{ date, timestamp, general_cast },
{ time, interval, general_cast },
{ timestamp, date, timestamp_to_date },
{ timestamp, time, timestamp_to_time },
{ interval, time, interval_to_time }
{ boolean, varchar, bool_to_varchar, false },
{ int16, varchar, general_to_text, false },
{ int32, varchar, general_to_text, false },
{ int64, varchar, general_to_text, false },
{ float32, varchar, general_to_text, false },
{ float64, varchar, general_to_text, false },
{ decimal, varchar, general_to_text, false },
{ time, varchar, general_to_text, false },
{ interval, varchar, general_to_text, false },
{ date, varchar, general_to_text, false },
{ timestamp, varchar, general_to_text, false },
{ timestampz, varchar, timestampz_to_utc_string, false },
{ list, varchar, |x, w| general_to_text(x, w), false },

{ boolean, int32, try_cast, false },
{ int32, boolean, int32_to_bool, false },

{ int16, int32, cast::<i16, i32>, true },
{ int16, int64, cast::<i16, i64>, true },
{ int16, float32, cast::<i16, OrderedF32>, true },
{ int16, float64, cast::<i16, OrderedF64>, true },
{ int16, decimal, cast::<i16, Decimal>, true },
{ int32, int16, try_cast, false },
{ int32, int64, cast::<i32, i64>, true },
{ int32, float32, to_f32, false }, // lossy
{ int32, float64, cast::<i32, OrderedF64>, true },
{ int32, decimal, cast::<i32, Decimal>, true },
{ int64, int16, try_cast, false },
{ int64, int32, try_cast, false },
{ int64, float32, to_f32, false }, // lossy
{ int64, float64, to_f64, false }, // lossy
{ int64, decimal, cast::<i64, Decimal>, true },

{ float32, float64, cast::<OrderedF32, OrderedF64>, true },
{ float32, decimal, cast::<OrderedF32, Decimal>, true },
{ float32, int16, to_i16, false },
{ float32, int32, to_i32, false },
{ float32, int64, to_i64, false },
{ float64, decimal, cast::<OrderedF64, Decimal>, true },
{ float64, int16, to_i16, false },
{ float64, int32, to_i32, false },
{ float64, int64, to_i64, false },
{ float64, float32, to_f32, false }, // lossy

{ decimal, int16, dec_to_i16, false },
{ decimal, int32, dec_to_i32, false },
{ decimal, int64, dec_to_i64, false },
{ decimal, float32, to_f32, false },
{ decimal, float64, to_f64, false },

{ date, timestamp, try_cast, false },
{ time, interval, try_cast, false },
{ timestamp, date, timestamp_to_date, true },
{ timestamp, time, timestamp_to_time, true },
{ interval, time, interval_to_time, true }
}
};
}
Expand Down Expand Up @@ -632,31 +639,38 @@ fn scalar_cast(
) => str_to_list(source.try_into()?, target_elem_type).map(Scalar::to_scalar_value),
(source_type, target_type) => {
macro_rules! gen_cast_impl {
($( { $input:ident, $cast:ident, $func:expr } ),*) => {
($( { $input:ident, $cast:ident, $func:expr, $infallible:ident } ),*) => {
match (source_type, target_type) {
$(
($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func),
($input! { type_match_pattern }, $cast! { type_match_pattern }) => gen_cast_impl!(arm: $input, $cast, $func, $infallible),
)*
_ => {
return Err(ExprError::UnsupportedCast(source_type.clone(), target_type.clone()));
}
}
};
(arm: $input:ident, varchar, $func:expr) => {
(arm: $input:ident, varchar, $func:expr, false) => {
{
let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?;
let mut writer = String::new();
let target: Result<()> = $func(source, &mut writer);
target.map(|_| Scalar::to_scalar_value(writer.into_boxed_str()))
}
};
(arm: $input:ident, $cast:ident, $func:expr) => {
(arm: $input:ident, $cast:ident, $func:expr, false) => {
{
let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?;
let target: Result<<$cast! { type_array } as Array>::OwnedItem> = $func(source);
target.map(Scalar::to_scalar_value)
}
};
(arm: $input:ident, $cast:ident, $func:expr, true) => {
{
let source: <$input! { type_array } as Array>::RefItem<'_> = source.try_into()?;
let target: Result<<$cast! { type_array } as Array>::OwnedItem> = Ok($func(source));
target.map(Scalar::to_scalar_value)
}
};
}
for_all_cast_variants!(gen_cast_impl)
}
Expand Down Expand Up @@ -757,19 +771,19 @@ mod tests {
#[test]
fn temporal_cast() {
assert_eq!(
timestamp_to_date(str_to_timestamp("1999-01-08 04:02").unwrap()).unwrap(),
timestamp_to_date(str_to_timestamp("1999-01-08 04:02").unwrap()),
str_to_date("1999-01-08").unwrap(),
);
assert_eq!(
timestamp_to_time(str_to_timestamp("1999-01-08 04:02").unwrap()).unwrap(),
timestamp_to_time(str_to_timestamp("1999-01-08 04:02").unwrap()),
str_to_time("04:02").unwrap(),
);
assert_eq!(
interval_to_time(IntervalUnit::new(1, 2, 61003)).unwrap(),
interval_to_time(IntervalUnit::new(1, 2, 61003)),
str_to_time("00:01:01.003").unwrap(),
);
assert_eq!(
interval_to_time(IntervalUnit::new(0, 0, -61003)).unwrap(),
interval_to_time(IntervalUnit::new(0, 0, -61003)),
str_to_time("23:58:58.997").unwrap(),
);
}
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/vector_op/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use risingwave_common::types::{

use crate::vector_op::arithmetic_op::*;
use crate::vector_op::bitwise_op::*;
use crate::vector_op::cast::general_cast;
use crate::vector_op::cast::try_cast;
use crate::vector_op::cmp::*;
use crate::vector_op::conjunction::*;
use crate::ExprError;
Expand Down Expand Up @@ -239,7 +239,7 @@ fn test_conjunction() {
#[test]
fn test_cast() {
assert_eq!(
general_cast::<_, NaiveDateTimeWrapper>(NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1))
try_cast::<_, NaiveDateTimeWrapper>(NaiveDateWrapper::from_ymd_uncheck(1994, 1, 1))
.unwrap(),
NaiveDateTimeWrapper::new(
NaiveDateTime::parse_from_str("1994-1-1 0:0:0", "%Y-%m-%d %H:%M:%S").unwrap()
Expand Down

0 comments on commit 91d35f1

Please sign in to comment.