From 95b0eff041fdcb082c4b0144b5cf8f4787840e27 Mon Sep 17 00:00:00 2001 From: Yijun Zhao Date: Sat, 24 Feb 2024 17:42:08 +0800 Subject: [PATCH] support decimal for quantile_cont --- src/query/expression/src/types.rs | 5 +- src/query/expression/src/types/any.rs | 6 +- src/query/expression/src/types/array.rs | 12 +- src/query/expression/src/types/binary.rs | 6 +- src/query/expression/src/types/bitmap.rs | 6 +- src/query/expression/src/types/boolean.rs | 6 +- src/query/expression/src/types/date.rs | 6 +- src/query/expression/src/types/decimal.rs | 26 ++- src/query/expression/src/types/empty_array.rs | 6 +- src/query/expression/src/types/empty_map.rs | 6 +- src/query/expression/src/types/generic.rs | 6 +- src/query/expression/src/types/geometry.rs | 6 +- src/query/expression/src/types/map.rs | 13 +- src/query/expression/src/types/null.rs | 6 +- src/query/expression/src/types/nullable.rs | 14 +- src/query/expression/src/types/number.rs | 5 +- src/query/expression/src/types/string.rs | 6 +- src/query/expression/src/types/timestamp.rs | 6 +- src/query/expression/src/types/variant.rs | 6 +- .../src/aggregates/aggregate_quantile_cont.rs | 217 ++++++++++++++++++ .../src/aggregates/aggregate_unary.rs | 60 +++-- .../02_0000_function_aggregate_mix.test | 10 + 22 files changed, 379 insertions(+), 61 deletions(-) diff --git a/src/query/expression/src/types.rs b/src/query/expression/src/types.rs index 8a249f0cb05a7..8ac7f84f9a9e4 100755 --- a/src/query/expression/src/types.rs +++ b/src/query/expression/src/types.rs @@ -345,7 +345,10 @@ pub trait ValueType: Debug + Clone + PartialEq + Sized + 'static { fn try_downcast_owned_builder(builder: ColumnBuilder) -> Option; - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option; + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + decimal_size: Option, + ) -> Option; fn upcast_scalar(scalar: Self::Scalar) -> Scalar; fn upcast_column(col: Self::Column) -> Column; diff --git a/src/query/expression/src/types/any.rs b/src/query/expression/src/types/any.rs index cbbff225b433a..98d5698db2c35 100755 --- a/src/query/expression/src/types/any.rs +++ b/src/query/expression/src/types/any.rs @@ -16,6 +16,7 @@ use std::cmp::Ordering; use std::ops::Range; use crate::property::Domain; +use crate::types::DecimalSize; use crate::types::ValueType; use crate::values::Column; use crate::values::Scalar; @@ -67,7 +68,10 @@ impl ValueType for AnyType { Some(builder) } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(builder) } diff --git a/src/query/expression/src/types/array.rs b/src/query/expression/src/types/array.rs index aa1bfce0b8437..107c14c2f1ccb 100755 --- a/src/query/expression/src/types/array.rs +++ b/src/query/expression/src/types/array.rs @@ -22,6 +22,7 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use super::AnyType; +use super::DecimalSize; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; @@ -107,8 +108,11 @@ impl ValueType for ArrayType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { - Some(ColumnBuilder::Array(Box::new(builder.upcast()))) + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + decimal_size: Option, + ) -> Option { + Some(ColumnBuilder::Array(Box::new(builder.upcast(decimal_size)))) } fn upcast_scalar(scalar: Self::Scalar) -> Scalar { @@ -395,9 +399,9 @@ impl ArrayColumnBuilder { ) } - pub fn upcast(self) -> ArrayColumnBuilder { + pub fn upcast(self, decimal_size: Option) -> ArrayColumnBuilder { ArrayColumnBuilder { - builder: T::try_upcast_column_builder(self.builder).unwrap(), + builder: T::try_upcast_column_builder(self.builder, decimal_size).unwrap(), offsets: self.offsets, } } diff --git a/src/query/expression/src/types/binary.rs b/src/query/expression/src/types/binary.rs index c78459d1b75f5..e0d4e9dae0bc2 100644 --- a/src/query/expression/src/types/binary.rs +++ b/src/query/expression/src/types/binary.rs @@ -25,6 +25,7 @@ use serde::Serialize; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::utils::arrow::buffer_into_mut; @@ -87,7 +88,10 @@ impl ValueType for BinaryType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::Binary(builder)) } diff --git a/src/query/expression/src/types/bitmap.rs b/src/query/expression/src/types/bitmap.rs index 1bc3859f76a78..388b7bd894c5a 100644 --- a/src/query/expression/src/types/bitmap.rs +++ b/src/query/expression/src/types/bitmap.rs @@ -20,6 +20,7 @@ use super::binary::BinaryIterator; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::values::Column; @@ -73,7 +74,10 @@ impl ValueType for BitmapType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::Bitmap(builder)) } diff --git a/src/query/expression/src/types/boolean.rs b/src/query/expression/src/types/boolean.rs index d02723b56bc83..35216d652f0a9 100644 --- a/src/query/expression/src/types/boolean.rs +++ b/src/query/expression/src/types/boolean.rs @@ -20,6 +20,7 @@ use databend_common_arrow::arrow::bitmap::MutableBitmap; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::utils::arrow::bitmap_into_mut; @@ -80,7 +81,10 @@ impl ValueType for BooleanType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::Boolean(builder)) } diff --git a/src/query/expression/src/types/date.rs b/src/query/expression/src/types/date.rs index 92ff02ab0af1a..0ae6657d4d587 100644 --- a/src/query/expression/src/types/date.rs +++ b/src/query/expression/src/types/date.rs @@ -28,6 +28,7 @@ use crate::date_helper::DateConverter; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::utils::arrow::buffer_into_mut; @@ -108,7 +109,10 @@ impl ValueType for DateType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::Date(builder)) } diff --git a/src/query/expression/src/types/decimal.rs b/src/query/expression/src/types/decimal.rs index 814d6300288ae..bff297bc08f2d 100644 --- a/src/query/expression/src/types/decimal.rs +++ b/src/query/expression/src/types/decimal.rs @@ -94,8 +94,14 @@ impl ValueType for DecimalType { Num::try_downcast_owned_builder(builder) } - fn try_upcast_column_builder(_builder: Self::ColumnBuilder) -> Option { - None + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + decimal_size: Option, + ) -> Option { + Some(ColumnBuilder::Decimal(Num::upcast_builder( + builder, + decimal_size.unwrap(), + ))) } fn upcast_scalar(scalar: Self::Scalar) -> Scalar { @@ -360,6 +366,7 @@ pub trait Decimal: fn upcast_scalar(scalar: Self, size: DecimalSize) -> Scalar; fn upcast_column(col: Buffer, size: DecimalSize) -> Column; fn upcast_domain(domain: SimpleDomain, size: DecimalSize) -> Domain; + fn upcast_builder(builder: Vec, size: DecimalSize) -> DecimalColumnBuilder; fn data_type() -> DataType; const MIN: Self; const MAX: Self; @@ -574,6 +581,10 @@ impl Decimal for i128 { Domain::Decimal(DecimalDomain::Decimal128(domain, size)) } + fn upcast_builder(builder: Vec, size: DecimalSize) -> DecimalColumnBuilder { + DecimalColumnBuilder::Decimal128(builder, size) + } + fn data_type() -> DataType { DataType::Decimal(DecimalDataType::Decimal128(DecimalSize { precision: MAX_DECIMAL128_PRECISION, @@ -733,6 +744,10 @@ impl Decimal for i256 { Domain::Decimal(DecimalDomain::Decimal256(domain, size)) } + fn upcast_builder(builder: Vec, size: DecimalSize) -> DecimalColumnBuilder { + DecimalColumnBuilder::Decimal256(builder, size) + } + fn data_type() -> DataType { DataType::Decimal(DecimalDataType::Decimal256(DecimalSize { precision: MAX_DECIMAL256_PRECISION, @@ -1080,6 +1095,13 @@ impl DecimalColumnBuilder { } }) } + + pub fn decimal_size(&self) -> DecimalSize { + match self { + DecimalColumnBuilder::Decimal128(_, size) + | DecimalColumnBuilder::Decimal256(_, size) => *size, + } + } } impl PartialOrd for DecimalScalar { diff --git a/src/query/expression/src/types/empty_array.rs b/src/query/expression/src/types/empty_array.rs index 8caf8030071dd..7db1c2bfceea0 100644 --- a/src/query/expression/src/types/empty_array.rs +++ b/src/query/expression/src/types/empty_array.rs @@ -17,6 +17,7 @@ use std::ops::Range; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::values::Column; @@ -81,7 +82,10 @@ impl ValueType for EmptyArrayType { } } - fn try_upcast_column_builder(len: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + len: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::EmptyArray { len }) } diff --git a/src/query/expression/src/types/empty_map.rs b/src/query/expression/src/types/empty_map.rs index 0b79e7be259b5..3ade20fff78f0 100644 --- a/src/query/expression/src/types/empty_map.rs +++ b/src/query/expression/src/types/empty_map.rs @@ -17,6 +17,7 @@ use std::ops::Range; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::values::Column; @@ -81,7 +82,10 @@ impl ValueType for EmptyMapType { } } - fn try_upcast_column_builder(len: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + len: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::EmptyMap { len }) } diff --git a/src/query/expression/src/types/generic.rs b/src/query/expression/src/types/generic.rs index c944d2298bb3b..88a9fa1033e6a 100755 --- a/src/query/expression/src/types/generic.rs +++ b/src/query/expression/src/types/generic.rs @@ -17,6 +17,7 @@ use std::ops::Range; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::values::Column; @@ -69,7 +70,10 @@ impl ValueType for GenericType { Some(builder) } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(builder) } diff --git a/src/query/expression/src/types/geometry.rs b/src/query/expression/src/types/geometry.rs index 4b65f06751db1..d9013c9f5fd89 100644 --- a/src/query/expression/src/types/geometry.rs +++ b/src/query/expression/src/types/geometry.rs @@ -25,6 +25,7 @@ use super::binary::BinaryIterator; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::values::Column; @@ -86,7 +87,10 @@ impl ValueType for GeometryType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::Geometry(builder)) } diff --git a/src/query/expression/src/types/map.rs b/src/query/expression/src/types/map.rs index dd4308f4a75e0..c8ea28e7c1028 100755 --- a/src/query/expression/src/types/map.rs +++ b/src/query/expression/src/types/map.rs @@ -18,6 +18,7 @@ use std::ops::Range; use databend_common_arrow::arrow::trusted_len::TrustedLen; use super::ArrayType; +use super::DecimalSize; use crate::property::Domain; use crate::types::array::ArrayColumn; use crate::types::ArgType; @@ -92,7 +93,10 @@ impl ValueType for KvPair { None } - fn try_upcast_column_builder(_builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + _builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { None } @@ -358,8 +362,11 @@ impl ValueType for MapType { as ValueType>::try_downcast_owned_builder(builder) } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { - as ValueType>::try_upcast_column_builder(builder) + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + decimal_size: Option, + ) -> Option { + as ValueType>::try_upcast_column_builder(builder, decimal_size) } fn upcast_scalar(scalar: Self::Scalar) -> Scalar { diff --git a/src/query/expression/src/types/null.rs b/src/query/expression/src/types/null.rs index 9970a74025522..312f3e95a003f 100644 --- a/src/query/expression/src/types/null.rs +++ b/src/query/expression/src/types/null.rs @@ -18,6 +18,7 @@ use super::nullable::NullableDomain; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::values::Column; @@ -85,7 +86,10 @@ impl ValueType for NullType { } } - fn try_upcast_column_builder(len: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + len: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::Null { len }) } diff --git a/src/query/expression/src/types/nullable.rs b/src/query/expression/src/types/nullable.rs index c6eca90939cda..8ade8206d7c3c 100755 --- a/src/query/expression/src/types/nullable.rs +++ b/src/query/expression/src/types/nullable.rs @@ -20,6 +20,7 @@ use databend_common_arrow::arrow::bitmap::MutableBitmap; use databend_common_arrow::arrow::trusted_len::TrustedLen; use super::AnyType; +use super::DecimalSize; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; @@ -121,8 +122,13 @@ impl ValueType for NullableType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { - Some(ColumnBuilder::Nullable(Box::new(builder.upcast()))) + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + decimal_size: Option, + ) -> Option { + Some(ColumnBuilder::Nullable(Box::new( + builder.upcast(decimal_size), + ))) } fn upcast_scalar(scalar: Self::Scalar) -> Scalar { @@ -377,9 +383,9 @@ impl NullableColumnBuilder { } } - pub fn upcast(self) -> NullableColumnBuilder { + pub fn upcast(self, decimal_size: Option) -> NullableColumnBuilder { NullableColumnBuilder { - builder: T::try_upcast_column_builder(self.builder).unwrap(), + builder: T::try_upcast_column_builder(self.builder, decimal_size).unwrap(), validity: self.validity, } } diff --git a/src/query/expression/src/types/number.rs b/src/query/expression/src/types/number.rs index dbd794e15849a..c06c191ec9aaa 100644 --- a/src/query/expression/src/types/number.rs +++ b/src/query/expression/src/types/number.rs @@ -142,7 +142,10 @@ impl ValueType for NumberType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Num::try_upcast_column_builder(builder) } diff --git a/src/query/expression/src/types/string.rs b/src/query/expression/src/types/string.rs index 3a7ece60a5e57..49e1595eb46b8 100644 --- a/src/query/expression/src/types/string.rs +++ b/src/query/expression/src/types/string.rs @@ -28,6 +28,7 @@ use super::binary::BinaryIterator; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::utils::arrow::buffer_into_mut; @@ -86,7 +87,10 @@ impl ValueType for StringType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::String(builder)) } diff --git a/src/query/expression/src/types/timestamp.rs b/src/query/expression/src/types/timestamp.rs index 808cb9f7c2de0..d40f299a6d118 100644 --- a/src/query/expression/src/types/timestamp.rs +++ b/src/query/expression/src/types/timestamp.rs @@ -27,6 +27,7 @@ use super::number::SimpleDomain; use crate::property::Domain; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::utils::arrow::buffer_into_mut; @@ -115,7 +116,10 @@ impl ValueType for TimestampType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::Timestamp(builder)) } diff --git a/src/query/expression/src/types/variant.rs b/src/query/expression/src/types/variant.rs index a074336dbc58d..e50b121431afb 100644 --- a/src/query/expression/src/types/variant.rs +++ b/src/query/expression/src/types/variant.rs @@ -31,6 +31,7 @@ use crate::types::map::KvPair; use crate::types::AnyType; use crate::types::ArgType; use crate::types::DataType; +use crate::types::DecimalSize; use crate::types::GenericMap; use crate::types::ValueType; use crate::values::Column; @@ -95,7 +96,10 @@ impl ValueType for VariantType { } } - fn try_upcast_column_builder(builder: Self::ColumnBuilder) -> Option { + fn try_upcast_column_builder( + builder: Self::ColumnBuilder, + _decimal_size: Option, + ) -> Option { Some(ColumnBuilder::Variant(builder)) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs index 6ea6461e66fe2..7819e3b1d5819 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs @@ -20,6 +20,9 @@ use borsh::BorshSerialize; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::type_check::check_number; +use databend_common_expression::types::array::ArrayColumnBuilder; +use databend_common_expression::types::decimal::Decimal; +use databend_common_expression::types::decimal::DecimalType; use databend_common_expression::types::number::*; use databend_common_expression::types::*; use databend_common_expression::with_number_mapped_type; @@ -28,6 +31,7 @@ use databend_common_expression::Expr; use databend_common_expression::FunctionContext; use databend_common_expression::Scalar; use databend_common_expression::ScalarRef; +use ethnum::i256; use num_traits::AsPrimitive; use ordered_float::OrderedFloat; @@ -143,6 +147,152 @@ where } } +#[derive(BorshDeserialize, BorshSerialize)] +pub struct DecimalQuantileContState +where + T: ValueType, + T::Scalar: Decimal + BorshSerialize + BorshDeserialize, +{ + pub value: Vec, +} + +impl Default for DecimalQuantileContState +where + T: ValueType, + T::Scalar: BorshDeserialize + BorshSerialize + Decimal, +{ + fn default() -> Self { + Self { value: vec![] } + } +} + +impl DecimalQuantileContState +where + T: ValueType, + T::Scalar: Decimal + BorshSerialize + BorshDeserialize, +{ + fn compute_result(&mut self, whole: usize, frac: f64, value_len: usize) -> Result { + self.value.as_mut_slice().select_nth_unstable(whole); + let value = *self.value.get(whole).unwrap(); + let value1 = if whole + 1 >= value_len { + value + } else { + self.value.as_mut_slice().select_nth_unstable(whole + 1); + *self.value.get(whole + 1).unwrap() + }; + + let result = value1 + .checked_sub(value) + .and_then(|sub_result| sub_result.checked_mul(Decimal::from_float(frac))) + .and_then(|mul_result| value.checked_add(mul_result)); + + match result { + Some(r) => Ok(r), + None => Err(ErrorCode::Overflow("Decimal overflow when interpolate")), + } + } +} + +impl UnaryState> for DecimalQuantileContState +where + T: ValueType, + T::Scalar: Decimal + BorshSerialize + BorshDeserialize, +{ + fn add(&mut self, other: T::ScalarRef<'_>) -> Result<()> { + self.value.push(T::to_owned_scalar(other)); + Ok(()) + } + + fn merge(&mut self, rhs: &Self) -> Result<()> { + self.value.extend( + rhs.value + .iter() + .map(|v| T::to_owned_scalar(T::to_scalar_ref(v))), + ); + Ok(()) + } + + fn merge_result( + &mut self, + builder: &mut ArrayColumnBuilder, + function_data: Option<&dyn FunctionData>, + ) -> Result<()> { + let value_len = self.value.len(); + let quantile_cont_data = unsafe { + function_data + .unwrap() + .as_any() + .downcast_ref_unchecked::() + }; + + if quantile_cont_data.levels.len() > 1 { + let indices = quantile_cont_data + .levels + .iter() + .map(|level| libm::modf((value_len - 1) as f64 * (*level))) + .collect::>(); + + for (frac, whole) in indices { + let whole = whole as usize; + if whole >= value_len { + builder.push_default(); + } else { + let n = self.compute_result(whole, frac, value_len)?; + builder.put_item(T::to_scalar_ref(&n)); + } + } + builder.commit_row(); + } + + Ok(()) + } +} + +impl UnaryState for DecimalQuantileContState +where + T: ValueType, + T::Scalar: Decimal + BorshSerialize + BorshDeserialize, +{ + fn add(&mut self, other: T::ScalarRef<'_>) -> Result<()> { + self.value.push(T::to_owned_scalar(other)); + Ok(()) + } + + fn merge(&mut self, rhs: &Self) -> Result<()> { + self.value.extend( + rhs.value + .iter() + .map(|v| T::to_owned_scalar(T::to_scalar_ref(v))), + ); + Ok(()) + } + + fn merge_result( + &mut self, + builder: &mut T::ColumnBuilder, + function_data: Option<&dyn FunctionData>, + ) -> Result<()> { + let value_len = self.value.len(); + let quantile_cont_data = unsafe { + function_data + .unwrap() + .as_any() + .downcast_ref_unchecked::() + }; + + let (frac, whole) = libm::modf((value_len - 1) as f64 * quantile_cont_data.levels[0]); + let whole = whole as usize; + if whole >= value_len { + T::push_default(builder); + } else { + let n = self.compute_result(whole, frac, value_len)?; + T::push_item(builder, T::to_scalar_ref(&n)); + } + + Ok(()) + } +} + pub(crate) fn get_levels(params: &Vec) -> Result> { let levels = if params.len() == 1 { let level: F64 = check_number( @@ -242,6 +392,73 @@ pub fn try_create_aggregate_quantile_cont_function( } } + DataType::Decimal(DecimalDataType::Decimal128(s)) => { + let decimal_size = DecimalSize { + precision: s.precision, + scale: s.scale, + }; + let data_type = DataType::Decimal(DecimalDataType::from_size(decimal_size)?); + if params.len() > 1 { + let func = AggregateUnaryFunction::< + DecimalQuantileContState>, + DecimalType, + ArrayType>, + >::try_create( + display_name, + DataType::Array(Box::new(data_type)), + params, + arguments[0].clone(), + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) + } else { + let func = AggregateUnaryFunction::< + DecimalQuantileContState>, + DecimalType, + DecimalType, + >::try_create( + display_name, data_type, params, arguments[0].clone() + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) + } + } + DataType::Decimal(DecimalDataType::Decimal256(s)) => { + let decimal_size = DecimalSize { + precision: s.precision, + scale: s.scale, + }; + let data_type = DataType::Decimal(DecimalDataType::from_size(decimal_size)?); + if params.len() > 1 { + let func = AggregateUnaryFunction::< + DecimalQuantileContState>, + DecimalType, + ArrayType>, + >::try_create( + display_name, + DataType::Array(Box::new(data_type)), + params, + arguments[0].clone(), + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) + } else { + let func = AggregateUnaryFunction::< + DecimalQuantileContState>, + DecimalType, + DecimalType, + >::try_create( + display_name, data_type, params, arguments[0].clone() + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) + } + } + _ => Err(ErrorCode::BadDataValueType(format!( "{} does not support type '{:?}'", display_name, arguments[0] diff --git a/src/query/functions/src/aggregates/aggregate_unary.rs b/src/query/functions/src/aggregates/aggregate_unary.rs index c68c95c3bf4f9..cb6b5dcee657c 100644 --- a/src/query/functions/src/aggregates/aggregate_unary.rs +++ b/src/query/functions/src/aggregates/aggregate_unary.rs @@ -14,7 +14,6 @@ use std::alloc::Layout; use std::any::Any; -use std::any::TypeId; use std::fmt::Display; use std::fmt::Formatter; use std::marker::PhantomData; @@ -23,10 +22,8 @@ use std::sync::Arc; use databend_common_arrow::arrow::bitmap::Bitmap; use databend_common_base::base::take_mut; use databend_common_exception::Result; -use databend_common_expression::types::decimal::Decimal128Type; -use databend_common_expression::types::decimal::Decimal256Type; -use databend_common_expression::types::decimal::DecimalColumnBuilder; use databend_common_expression::types::DataType; +use databend_common_expression::types::DecimalSize; use databend_common_expression::types::ValueType; use databend_common_expression::AggregateFunction; use databend_common_expression::AggregateFunctionRef; @@ -133,39 +130,36 @@ where } fn do_merge_result(&self, state: &mut S, builder: &mut ColumnBuilder) -> Result<()> { - match builder { - // current decimal implementation hard do upcast_builder, we do downcast manually. - ColumnBuilder::Decimal(b) => match b { - DecimalColumnBuilder::Decimal128(_, _) => { - debug_assert!(TypeId::of::() == TypeId::of::()); - let builder = R::try_downcast_builder(builder).unwrap(); - state.merge_result(builder, self.function_data.as_deref()) - } - DecimalColumnBuilder::Decimal256(_, _) => { - debug_assert!(TypeId::of::() == TypeId::of::()); - let builder = R::try_downcast_builder(builder).unwrap(); - state.merge_result(builder, self.function_data.as_deref()) - } - }, - _ => { - // some `ValueType` like `NullableType` need ownership to downcast builder, - // so here we using an unsafe way to take the ownership of builder. - // See [`take_mut`] for details. - if let Some(builder) = R::try_downcast_builder(builder) { - state.merge_result(builder, self.function_data.as_deref()) - } else { - take_mut(builder, |builder| { - let mut builder = R::try_downcast_owned_builder(builder).unwrap(); - let res = state.merge_result(&mut builder, self.function_data.as_deref()); - - (res, R::try_upcast_column_builder(builder).unwrap()) - }) - } - } + let decimal_size = check_decimal(builder); + // some `ValueType` like `NullableType` need ownership to downcast builder, + // so here we using an unsafe way to take the ownership of builder. + // See [`take_mut`] for details. + if let Some(builder) = R::try_downcast_builder(builder) { + state.merge_result(builder, self.function_data.as_deref()) + } else { + take_mut(builder, |builder| { + let mut builder = R::try_downcast_owned_builder(builder).unwrap(); + let res = state.merge_result(&mut builder, self.function_data.as_deref()); + + ( + res, + R::try_upcast_column_builder(builder, decimal_size).unwrap(), + ) + }) } } } +fn check_decimal(builder: &ColumnBuilder) -> Option { + match builder { + ColumnBuilder::Decimal(b) => Some(b.decimal_size()), + ColumnBuilder::Array(box b) => check_decimal(&b.builder), + ColumnBuilder::Nullable(box b) => check_decimal(&b.builder), + ColumnBuilder::Map(box b) => check_decimal(&b.builder), + _ => None, + } +} + impl AggregateFunction for AggregateUnaryFunction where S: UnaryState + 'static, diff --git a/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix.test b/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix.test index 37222be32635b..3899932413745 100644 --- a/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix.test +++ b/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix.test @@ -232,6 +232,11 @@ SELECT quantile_cont(0, 0.5, 0.6, 1)(number) from numbers_mt(10000); ---- [0.0,4999.5,5999.4,9999.0] +query T +SELECT quantile_cont(0, 0.5, 0.6, 1)(number::decimal(10,2)) from numbers_mt(10000); +---- +[0.00,4999.00,5999.00,9999.00] + statement error 1010 SELECT quantile_cont(5)(number) from numbers_mt(10000) @@ -240,6 +245,11 @@ SELECT quantile_disc(0, 0.5, 0.6, 1)(number) from numbers_mt(10000); ---- [0,4999,5999,9999] +query T +SELECT quantile_disc(0, 0.5, 0.6, 1)(number::decimal(10,2)) from numbers_mt(10000); +---- +[0.00,4999.00,5999.00,9999.00] + query F SELECT quantile_tdigest(0.6)(number) from numbers_mt(10000) ----