From 5bf456803b89244390c127753c967700cf5606f8 Mon Sep 17 00:00:00 2001 From: Yijun Zhao Date: Mon, 20 Nov 2023 20:09:04 +0800 Subject: [PATCH 1/5] refactor kurtosis --- .../functions/src/aggregates/aggregate_avg.rs | 2 +- .../src/aggregates/aggregate_kurtosis.rs | 208 ++++-------------- .../src/aggregates/aggregate_skewness.rs | 29 +-- .../functions/src/aggregates/aggregate_sum.rs | 22 +- .../tests/it/aggregates/testdata/agg.txt | 2 +- .../it/aggregates/testdata/agg_group_by.txt | 4 +- .../tests/it/scalars/testdata/array.txt | 20 +- .../02_0000_function_aggregate_mix.test | 6 +- 8 files changed, 65 insertions(+), 228 deletions(-) diff --git a/src/query/functions/src/aggregates/aggregate_avg.rs b/src/query/functions/src/aggregates/aggregate_avg.rs index 04c7bcfe41af1..f61cf3bec00a5 100644 --- a/src/query/functions/src/aggregates/aggregate_avg.rs +++ b/src/query/functions/src/aggregates/aggregate_avg.rs @@ -86,7 +86,7 @@ where fn merge_result( &mut self, - builder: &mut ::ColumnBuilder, + builder: &mut Vec, _function_data: Option<&dyn FunctionData>, ) -> Result<()> { let value = self.value.as_() / (self.count as f64); diff --git a/src/query/functions/src/aggregates/aggregate_kurtosis.rs b/src/query/functions/src/aggregates/aggregate_kurtosis.rs index 131f34317664b..7e72b2ff78ab3 100644 --- a/src/query/functions/src/aggregates/aggregate_kurtosis.rs +++ b/src/query/functions/src/aggregates/aggregate_kurtosis.rs @@ -12,20 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; -use std::fmt::Display; -use std::fmt::Formatter; -use std::marker::PhantomData; -use std::sync::Arc; - -use common_arrow::arrow::bitmap::Bitmap; use common_exception::ErrorCode; use common_exception::Result; use common_expression::types::number::*; use common_expression::types::*; use common_expression::with_number_mapped_type; -use common_expression::Column; -use common_expression::ColumnBuilder; use common_expression::Scalar; use num_traits::AsPrimitive; use serde::Deserialize; @@ -33,11 +24,12 @@ use serde::Serialize; use super::deserialize_state; use super::serialize_state; +use super::AggregateUnaryFunction; +use super::FunctionData; +use super::UnaryState; use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; use crate::aggregates::assert_unary_arguments; -use crate::aggregates::AggregateFunction; use crate::aggregates::AggregateFunctionRef; -use crate::aggregates::StateAddr; #[derive(Default, Serialize, Deserialize)] struct KurtosisState { @@ -48,44 +40,46 @@ struct KurtosisState { pub sum_four: f64, } -impl KurtosisState { - fn new() -> Self { - Self::default() - } - - #[inline(always)] - fn add(&mut self, other: f64) { +impl UnaryState for KurtosisState +where + T: ValueType + Sync + Send, + T::Scalar: AsPrimitive, +{ + fn add(&mut self, other: T::ScalarRef<'_>) -> Result<()> { + let other = T::to_owned_scalar(other).as_(); self.n += 1; self.sum += other; self.sum_sqr += other.powi(2); self.sum_cub += other.powi(3); self.sum_four += other.powi(4); + Ok(()) } - fn merge(&mut self, rhs: &Self) { + fn merge(&mut self, rhs: &Self) -> Result<()> { if rhs.n == 0 { - return; + return Ok(()); } self.n += rhs.n; self.sum += rhs.sum; self.sum_sqr += rhs.sum_sqr; self.sum_cub += rhs.sum_cub; self.sum_four += rhs.sum_four; + Ok(()) } - fn merge_result(&mut self, builder: &mut ColumnBuilder) -> Result<()> { - let builder = match builder { - ColumnBuilder::Nullable(box b) => b, - _ => unreachable!(), - }; + fn merge_result( + &mut self, + builder: &mut Vec, + _function_data: Option<&dyn FunctionData>, + ) -> Result<()> { if self.n <= 3 { - builder.push_null(); + builder.push(F64::from(0_f64)); return Ok(()); } let n = self.n as f64; let temp = 1.0 / n; if self.sum_sqr - self.sum * self.sum * temp == 0.0 { - builder.push_null(); + builder.push(F64::from(0_f64)); return Ok(()); } let m4 = temp @@ -93,152 +87,26 @@ impl KurtosisState { + 6.0 * self.sum_sqr * self.sum * self.sum * temp * temp - 3.0 * self.sum.powi(4) * temp.powi(3)); let m2 = temp * (self.sum_sqr - self.sum * self.sum * temp); + if m2 <= 0.0 || (n - 2.0) * (n - 3.0) == 0.0 { + builder.push(F64::from(0_f64)); + return Ok(()); + } let value = (n - 1.0) * ((n + 1.0) * m4 / (m2 * m2) - 3.0 * (n - 1.0)) / ((n - 2.0) * (n - 3.0)); if value.is_infinite() || value.is_nan() { - builder.push_null(); + return Err(ErrorCode::SemanticError("Kurtosis is out of range!")); } else { - builder.push(Float64Type::upcast_scalar(value.into()).as_ref()); - } - Ok(()) - } -} - -#[derive(Clone)] -pub struct AggregateKurtosisFunction { - display_name: String, - return_type: DataType, - _arguments: Vec, - _t: PhantomData, -} - -impl Display for AggregateKurtosisFunction -where T: Number + AsPrimitive -{ - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.display_name) - } -} - -impl AggregateFunction for AggregateKurtosisFunction -where T: Number + AsPrimitive -{ - fn name(&self) -> &str { - "AggregateKurtosisFunction" - } - - fn return_type(&self) -> Result { - Ok(self.return_type.clone()) - } - - fn init_state(&self, place: StateAddr) { - place.write(KurtosisState::new) - } - - fn state_layout(&self) -> Layout { - Layout::new::() - } - - fn accumulate( - &self, - place: StateAddr, - columns: &[Column], - validity: Option<&Bitmap>, - _input_rows: usize, - ) -> Result<()> { - let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); - let state = place.get::(); - match validity { - Some(bitmap) => { - for (value, is_valid) in column.iter().zip(bitmap.iter()) { - if is_valid { - state.add(value.as_()); - } - } - } - None => { - for value in column.iter() { - state.add(value.as_()); - } - } + builder.push(F64::from(value)); } - - Ok(()) - } - - fn accumulate_row(&self, place: StateAddr, columns: &[Column], row: usize) -> Result<()> { - let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); - - let state = place.get::(); - let v: f64 = column[row].as_(); - state.add(v); - Ok(()) - } - - fn accumulate_keys( - &self, - places: &[StateAddr], - offset: usize, - columns: &[Column], - _input_rows: usize, - ) -> Result<()> { - let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); - - column.iter().zip(places.iter()).for_each(|(value, place)| { - let place = place.next(offset); - let state = place.get::(); - let v: f64 = value.as_(); - state.add(v); - }); Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - serialize_state(writer, state) + fn serialize(&self, writer: &mut Vec) -> Result<()> { + serialize_state(writer, self) } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let rhs: KurtosisState = deserialize_state(reader)?; - state.merge(&rhs); - Ok(()) - } - - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let state = place.get::(); - let other = rhs.get::(); - state.merge(other); - Ok(()) - } - - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { - let state = place.get::(); - state.merge_result(builder) - } - - fn need_manual_drop_state(&self) -> bool { - false - } -} - -impl AggregateKurtosisFunction -where T: Number + AsPrimitive -{ - fn try_create( - display_name: &str, - return_type: DataType, - _params: Vec, - arguments: Vec, - ) -> Result> { - let func = AggregateKurtosisFunction:: { - display_name: display_name.to_string(), - return_type, - _arguments: arguments, - _t: PhantomData, - }; - - Ok(Arc::new(func)) + fn deserialize(reader: &mut &[u8]) -> Result { + deserialize_state::(reader) } } @@ -251,14 +119,12 @@ pub fn try_create_aggregate_kurtosis_function( with_number_mapped_type!(|NUM_TYPE| match &arguments[0] { DataType::Number(NumberDataType::NUM_TYPE) => { - let return_type = - DataType::Nullable(Box::new(DataType::Number(NumberDataType::Float64))); - AggregateKurtosisFunction::::try_create( - display_name, - return_type, - params, - arguments, - ) + let return_type = DataType::Number(NumberDataType::Float64); + AggregateUnaryFunction::< + KurtosisState, + NumberType, + Float64Type, + >::try_create_unary(display_name, return_type, params, arguments[0].clone()) } _ => Err(ErrorCode::BadDataValueType(format!( diff --git a/src/query/functions/src/aggregates/aggregate_skewness.rs b/src/query/functions/src/aggregates/aggregate_skewness.rs index 71dac47ec0b3e..c92fc6fbf95dd 100644 --- a/src/query/functions/src/aggregates/aggregate_skewness.rs +++ b/src/query/functions/src/aggregates/aggregate_skewness.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::marker::PhantomData; - use common_exception::ErrorCode; use common_exception::Result; use common_expression::types::number::*; @@ -33,32 +31,15 @@ use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; use crate::aggregates::aggregate_unary::AggregateUnaryFunction; use crate::aggregates::aggregate_unary::UnaryState; -#[derive(Serialize, Deserialize)] -pub struct SkewnessStateV2 { +#[derive(Default, Serialize, Deserialize)] +pub struct SkewnessStateV2 { pub n: u64, pub sum: f64, pub sum_sqr: f64, pub sum_cub: f64, - _ph: PhantomData, -} - -impl Default for SkewnessStateV2 -where - T: ValueType + Sync + Send, - T::Scalar: AsPrimitive, -{ - fn default() -> Self { - Self { - n: 0, - sum: 0.0, - sum_sqr: 0.0, - sum_cub: 0.0, - _ph: PhantomData, - } - } } -impl UnaryState for SkewnessStateV2 +impl UnaryState for SkewnessStateV2 where T: ValueType + Sync + Send, T::Scalar: AsPrimitive, @@ -85,7 +66,7 @@ where fn merge_result( &mut self, - builder: &mut ::ColumnBuilder, + builder: &mut Vec, _function_data: Option<&dyn FunctionData>, ) -> Result<()> { if self.n <= 2 { @@ -135,7 +116,7 @@ pub fn try_create_aggregate_skewness_function( DataType::Number(NumberDataType::NUM) => { let return_type = DataType::Number(NumberDataType::Float64); AggregateUnaryFunction::< - SkewnessStateV2>, + SkewnessStateV2, NumberType, Float64Type, >::try_create_unary(display_name, return_type, params, arguments[0].clone()) diff --git a/src/query/functions/src/aggregates/aggregate_sum.rs b/src/query/functions/src/aggregates/aggregate_sum.rs index 9261b7d776fdb..d814feff025a3 100644 --- a/src/query/functions/src/aggregates/aggregate_sum.rs +++ b/src/query/functions/src/aggregates/aggregate_sum.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::marker::PhantomData; - use common_arrow::arrow::bitmap::Bitmap; use common_exception::ErrorCode; use common_exception::Result; @@ -68,30 +66,25 @@ pub trait SumState: Serialize + DeserializeOwned + Send + Sync + Default + 'stat } #[derive(Deserialize, Serialize)] -pub struct NumberSumState +pub struct NumberSumState where R: ValueType { pub value: R::Scalar, - #[serde(skip)] - _t: PhantomData, } -impl Default for NumberSumState +impl Default for NumberSumState where - T: ValueType + Sync + Send, R: ValueType, - T::Scalar: Number + AsPrimitive, R::Scalar: Number + AsPrimitive + Serialize + DeserializeOwned + std::ops::AddAssign, { fn default() -> Self { - NumberSumState:: { + NumberSumState:: { value: R::Scalar::default(), - _t: PhantomData, } } } -impl UnaryState for NumberSumState +impl UnaryState for NumberSumState where T: ValueType + Sync + Send, R: ValueType, @@ -124,10 +117,7 @@ where fn deserialize(reader: &mut &[u8]) -> Result { let value = deserialize_state(reader)?; - Ok(Self { - value, - _t: PhantomData, - }) + Ok(Self { value }) } } @@ -212,7 +202,7 @@ pub fn try_create_aggregate_sum_function( type TSum = ::Sum; let return_type = NumberType::::data_type(); AggregateUnaryFunction::< - NumberSumState, NumberType>, + NumberSumState>, NumberType, NumberType, >::try_create_unary(display_name, return_type, params, arguments[0].clone()) diff --git a/src/query/functions/tests/it/aggregates/testdata/agg.txt b/src/query/functions/tests/it/aggregates/testdata/agg.txt index 86eed667c24bf..e7dd5320493f2 100644 --- a/src/query/functions/tests/it/aggregates/testdata/agg.txt +++ b/src/query/functions/tests/it/aggregates/testdata/agg.txt @@ -720,7 +720,7 @@ evaluation (internal): | Column | Data | +--------+-------------------------------------------------------------------------+ | x_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0011] } | -| Output | NullableColumn { column: Float64([0]), validity: [0b_______0] } | +| Output | NullableColumn { column: Float64([0]), validity: [0b_______1] } | +--------+-------------------------------------------------------------------------+ diff --git a/src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt b/src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt index 90d0dc339fe01..6d535e908566d 100644 --- a/src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt +++ b/src/query/functions/tests/it/aggregates/testdata/agg_group_by.txt @@ -710,7 +710,7 @@ evaluation (internal): | Column | Data | +--------+--------------------------------------------------------------------+ | a | Int64([4, 3, 2, 1]) | -| Output | NullableColumn { column: Float64([0, 0]), validity: [0b______00] } | +| Output | NullableColumn { column: Float64([0, 0]), validity: [0b______11] } | +--------+--------------------------------------------------------------------+ @@ -720,7 +720,7 @@ evaluation (internal): | Column | Data | +--------+-------------------------------------------------------------------------+ | x_null | NullableColumn { column: UInt64([1, 2, 3, 4]), validity: [0b____0011] } | -| Output | NullableColumn { column: Float64([0, 0]), validity: [0b______00] } | +| Output | NullableColumn { column: Float64([0, 0]), validity: [0b______11] } | +--------+-------------------------------------------------------------------------+ diff --git a/src/query/functions/tests/it/scalars/testdata/array.txt b/src/query/functions/tests/it/scalars/testdata/array.txt index b255b8a5b1f59..e77e2f46ed889 100644 --- a/src/query/functions/tests/it/scalars/testdata/array.txt +++ b/src/query/functions/tests/it/scalars/testdata/array.txt @@ -1983,19 +1983,19 @@ output : NULL ast : array_kurtosis([1, 2, 3]) raw expr : array_kurtosis(array(1, 2, 3)) checked expr : array_kurtosis(array(1_u8, 2_u8, 3_u8)) -optimized expr : NULL +optimized expr : 0_f64 output type : Float64 NULL -output domain : {NULL} -output : NULL +output domain : {0..=0} +output : 0 ast : array_kurtosis([NULL, 3, 2, 1]) raw expr : array_kurtosis(array(NULL, 3, 2, 1)) checked expr : array_kurtosis(array(CAST(NULL AS UInt8 NULL), CAST(3_u8 AS UInt8 NULL), CAST(2_u8 AS UInt8 NULL), CAST(1_u8 AS UInt8 NULL))) -optimized expr : NULL +optimized expr : 0_f64 output type : Float64 NULL -output domain : {NULL} -output : NULL +output domain : {0..=0} +output : 0 ast : array_kurtosis([a, b, c, d]) @@ -2034,9 +2034,9 @@ evaluation: | Type | UInt64 NULL | UInt64 NULL | UInt64 NULL | UInt64 NULL | Float64 NULL | | Domain | {0..=4} ∪ {NULL} | {0..=6} ∪ {NULL} | {3..=9} | {0..=6} ∪ {NULL} | Unknown | | Row 0 | 1 | 2 | 3 | 4 | -1.2 | -| Row 1 | 2 | NULL | 7 | 6 | NULL | -| Row 2 | NULL | 5 | 8 | 5 | NULL | -| Row 3 | 4 | 6 | 9 | NULL | NULL | +| Row 1 | 2 | NULL | 7 | 6 | 0 | +| Row 2 | NULL | 5 | 8 | 5 | 0 | +| Row 3 | 4 | 6 | 9 | NULL | 0 | +--------+------------------+------------------+-------------+------------------+--------------+ evaluation (internal): +--------+-----------------------------------------------------------------------------+ @@ -2046,7 +2046,7 @@ evaluation (internal): | b | NullableColumn { column: UInt64([2, 0, 5, 6]), validity: [0b____1101] } | | c | NullableColumn { column: UInt64([3, 7, 8, 9]), validity: [0b____1111] } | | d | NullableColumn { column: UInt64([4, 6, 5, 0]), validity: [0b____0111] } | -| Output | NullableColumn { column: Float64([-1.2, 0, 0, 0]), validity: [0b____0001] } | +| Output | NullableColumn { column: Float64([-1.2, 0, 0, 0]), validity: [0b____1111] } | +--------+-----------------------------------------------------------------------------+ diff --git a/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix.test b/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix.test index 7b41b2ecc1e8d..16f6390719de2 100644 --- a/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix.test +++ b/tests/sqllogictests/suites/query/02_function/02_0000_function_aggregate_mix.test @@ -278,7 +278,7 @@ SELECT array_sort(list(id), 'asc'), array_sort(list(arr), 'asc') FROM t2; query I select kurtosis(10) from numbers(5) ---- -NULL +0.0 statement ok create table aggr(k int, v int, v2 int null); @@ -305,8 +305,8 @@ select kurtosis(k), kurtosis(v), kurtosis(v2) from aggr; query I select kurtosis(v2) from aggr group by v order by v; ---- -NULL -NULL +0.0 +0.0 NULL -3.9775993237531697 From 346a322837779b9b4c6ac51730fb266e0fd66136 Mon Sep 17 00:00:00 2001 From: Yijun Zhao Date: Thu, 23 Nov 2023 17:28:46 +0800 Subject: [PATCH 2/5] add quantile cont --- .../src/aggregates/aggregate_quantile_cont.rs | 359 +++++++----------- 1 file changed, 136 insertions(+), 223 deletions(-) diff --git a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs index cb62d0cad7a3b..e8fb4c4821e2f 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs @@ -12,13 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; -use std::fmt::Display; -use std::fmt::Formatter; -use std::marker::PhantomData; +use std::any::Any; use std::sync::Arc; -use common_arrow::arrow::bitmap::Bitmap; use common_exception::ErrorCode; use common_exception::Result; use common_expression::type_check::check_number; @@ -26,7 +22,6 @@ use common_expression::types::number::*; use common_expression::types::*; use common_expression::with_number_mapped_type; use common_expression::Column; -use common_expression::ColumnBuilder; use common_expression::Expr; use common_expression::FunctionContext; use common_expression::Scalar; @@ -38,30 +33,57 @@ use serde::Serialize; use super::deserialize_state; use super::serialize_state; +use super::AggregateUnaryFunction; +use super::FunctionData; +use super::UnaryState; use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; use crate::aggregates::assert_params; use crate::aggregates::assert_unary_arguments; -use crate::aggregates::AggregateFunction; use crate::aggregates::AggregateFunctionRef; -use crate::aggregates::StateAddr; use crate::BUILTIN_FUNCTIONS; const MEDIAN: u8 = 0; const QUANTILE_CONT: u8 = 1; +struct QuantileContData { + pub levels: Vec, +} + +impl FunctionData for QuantileContData { + fn as_any(&self) -> &dyn Any { + self + } +} #[derive(Default, Serialize, Deserialize)] struct QuantileContState { pub value: Vec>, } impl QuantileContState { - fn new() -> Self { - Self::default() + fn compute_result(&mut self, whole: usize, frac: f64, value_len: usize) -> f64 { + self.value.as_mut_slice().select_nth_unstable(whole); + let value = self.value.get(whole).unwrap().0; + let value1 = if whole + 1 >= value_len { + value + } else { + self.value.as_mut_slice().select_nth_unstable(whole + 1); + self.value.get(whole + 1).unwrap().0 + }; + + value + (value1 - value) * frac } +} - #[inline(always)] - fn add(&mut self, other: f64) { +impl UnaryState for QuantileContState +where + T: ValueType, + T::Scalar: Number + AsPrimitive, + R: ValueType, +{ + fn add(&mut self, other: T::ScalarRef<'_>) -> Result<()> { + let other = T::to_owned_scalar(other).as_(); self.value.push(other.into()); + Ok(()) } fn merge(&mut self, rhs: &Self) -> Result<()> { @@ -69,247 +91,121 @@ impl QuantileContState { Ok(()) } - fn merge_result(&mut self, builder: &mut ColumnBuilder, levels: Vec) -> Result<()> { + fn merge_result( + &mut self, + builder: &mut R::ColumnBuilder, + function_data: Option<&dyn FunctionData>, + ) -> Result<()> { let value_len = self.value.len(); - if levels.len() > 1 { - let builder = match builder { - ColumnBuilder::Array(box b) => b, - _ => unreachable!(), - }; - let indices = levels + let quantile_cont_data = unsafe { + function_data + .unwrap() + .as_any() + .downcast_ref_unchecked::() + }; + if quantile_cont_data.levels.len() > 1 { + let indices = quantile_cont_data + .levels .iter() .map(|level| libm::modf((value_len - 1) as f64 * (*level))) .collect::>(); + // as we already know the return type `R` here is `ArrayType` + // we should has a inner builder to build `Float64Type::Column` and + // provide to `R::ColumnBuilder` when `push_item` + let mut inner_column_builder = NumberColumnBuilder::Float64(Vec::new()); for (frac, whole) in indices { let whole = whole as usize; if whole >= value_len { - builder.push_default(); + R::push_default(builder); } else { let n = self.compute_result(whole, frac, value_len); - builder.put_item(ScalarRef::Number(NumberScalar::Float64(n.into()))); + inner_column_builder.push(NumberScalar::Float64(n.into())); } } - builder.commit_row(); + let float64_column = inner_column_builder.build(); + R::push_item( + builder, + R::try_downcast_scalar(&ScalarRef::Array(Column::Number(float64_column))).unwrap(), + ) } else { - let builder = NumberType::::try_downcast_builder(builder).unwrap(); - let (frac, whole) = libm::modf((value_len - 1) as f64 * levels[0]); + let (frac, whole) = libm::modf((value_len - 1) as f64 * quantile_cont_data.levels[0]); let whole = whole as usize; if whole >= value_len { - builder.push(0_f64.into()); + R::push_default(builder); } else { let n = self.compute_result(whole, frac, value_len); - builder.push(n.into()); + R::push_item( + builder, + R::try_downcast_scalar(&ScalarRef::Number(NumberScalar::Float64(n.into()))) + .unwrap(), + ); } } Ok(()) } - fn compute_result(&mut self, whole: usize, frac: f64, value_len: usize) -> f64 { - self.value.as_mut_slice().select_nth_unstable(whole); - let value = self.value.get(whole).unwrap().0; - let value1 = if whole + 1 >= value_len { - value - } else { - self.value.as_mut_slice().select_nth_unstable(whole + 1); - self.value.get(whole + 1).unwrap().0 - }; - - value + (value1 - value) * frac + fn serialize(&self, writer: &mut Vec) -> Result<()> { + serialize_state(writer, self) } -} - -#[derive(Clone)] -pub struct AggregateQuantileContFunction { - display_name: String, - return_type: DataType, - levels: Vec, - _arguments: Vec, - _t: PhantomData, -} -impl Display for AggregateQuantileContFunction -where T: Number + AsPrimitive -{ - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.display_name) + fn deserialize(reader: &mut &[u8]) -> Result + where Self: Sized { + deserialize_state(reader) } } -impl AggregateFunction for AggregateQuantileContFunction -where T: Number + AsPrimitive -{ - fn name(&self) -> &str { - "AggregateQuantileContFunction" - } - - fn return_type(&self) -> Result { - Ok(self.return_type.clone()) - } - - fn init_state(&self, place: StateAddr) { - place.write(QuantileContState::new) - } - - fn state_layout(&self) -> Layout { - Layout::new::() - } - - fn accumulate( - &self, - place: StateAddr, - columns: &[Column], - validity: Option<&Bitmap>, - _input_rows: usize, - ) -> Result<()> { - let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); - let state = place.get::(); - match validity { - Some(bitmap) => { - for (value, is_valid) in column.iter().zip(bitmap.iter()) { - if is_valid { - state.add(value.as_()); - } - } - } - None => { - for value in column.iter() { - state.add(value.as_()); - } - } +fn get_levels(params: &Vec) -> Result> { + let levels = if params.len() == 1 { + let level: F64 = check_number( + None, + &FunctionContext::default(), + &Expr::::Constant { + span: None, + scalar: params[0].clone(), + data_type: params[0].as_ref().infer_data_type(), + }, + &BUILTIN_FUNCTIONS, + )?; + let level = level.0; + if !(0.0..=1.0).contains(&level) { + return Err(ErrorCode::BadDataValueType(format!( + "level range between [0, 1], got: {:?}", + level + ))); } - - Ok(()) - } - - fn accumulate_row(&self, place: StateAddr, columns: &[Column], row: usize) -> Result<()> { - let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); - - let state = place.get::(); - let v: f64 = column[row].as_(); - state.add(v); - Ok(()) - } - - fn accumulate_keys( - &self, - places: &[StateAddr], - offset: usize, - columns: &[Column], - _input_rows: usize, - ) -> Result<()> { - let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); - - column.iter().zip(places.iter()).for_each(|(value, place)| { - let place = place.next(offset); - let state = place.get::(); - let v: f64 = value.as_(); - state.add(v); - }); - Ok(()) - } - - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - serialize_state(writer, state) - } - - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let rhs: QuantileContState = deserialize_state(reader)?; - state.merge(&rhs) - } - - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let state = place.get::(); - let other = rhs.get::(); - state.merge(other) - } - - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { - let state = place.get::(); - state.merge_result(builder, self.levels.clone()) - } - - fn need_manual_drop_state(&self) -> bool { - true - } - - unsafe fn drop_state(&self, place: StateAddr) { - let state = place.get::(); - std::ptr::drop_in_place(state); - } -} - -impl AggregateQuantileContFunction -where T: Number + AsPrimitive -{ - fn try_create( - display_name: &str, - return_type: DataType, - params: Vec, - arguments: Vec, - ) -> Result> { - let levels = if params.len() == 1 { + vec![level] + } else if params.is_empty() { + vec![0.5f64] + } else { + let mut levels = Vec::with_capacity(params.len()); + for param in params { let level: F64 = check_number( None, &FunctionContext::default(), - &Expr::::Constant { + &Expr::::Cast { span: None, - scalar: params[0].clone(), - data_type: params[0].as_ref().infer_data_type(), + is_try: false, + expr: Box::new(Expr::Constant { + span: None, + scalar: param.clone(), + data_type: param.as_ref().infer_data_type(), + }), + dest_type: DataType::Number(NumberDataType::Float64), }, &BUILTIN_FUNCTIONS, )?; let level = level.0; if !(0.0..=1.0).contains(&level) { return Err(ErrorCode::BadDataValueType(format!( - "level range between [0, 1], got: {:?}", + "level range between [0, 1], got: {:?} in levels", level ))); } - vec![level] - } else if params.is_empty() { - vec![0.5f64] - } else { - let mut levels = Vec::with_capacity(params.len()); - for param in params { - let level: F64 = check_number( - None, - &FunctionContext::default(), - &Expr::::Cast { - span: None, - is_try: false, - expr: Box::new(Expr::Constant { - span: None, - scalar: param.clone(), - data_type: param.as_ref().infer_data_type(), - }), - dest_type: DataType::Number(NumberDataType::Float64), - }, - &BUILTIN_FUNCTIONS, - )?; - let level = level.0; - if !(0.0..=1.0).contains(&level) { - return Err(ErrorCode::BadDataValueType(format!( - "level range between [0, 1], got: {:?} in levels", - level - ))); - } - levels.push(level); - } - levels - }; - - let func = AggregateQuantileContFunction:: { - display_name: display_name.to_string(), - return_type, - levels, - _arguments: arguments, - _t: PhantomData, - }; - - Ok(Arc::new(func)) - } + levels.push(level); + } + levels + }; + Ok(levels) } pub fn try_create_aggregate_quantile_cont_function( @@ -323,19 +219,36 @@ pub fn try_create_aggregate_quantile_cont_function( assert_unary_arguments(display_name, arguments.len())?; + let levels = get_levels(¶ms)?; + with_number_mapped_type!(|NUM_TYPE| match &arguments[0] { DataType::Number(NumberDataType::NUM_TYPE) => { - let return_type = if params.len() > 1 { - DataType::Array(Box::new(DataType::Number(NumberDataType::Float64))) + if params.len() > 1 { + let return_type = + DataType::Array(Box::new(DataType::Number(NumberDataType::Float64))); + let func = AggregateUnaryFunction::< + QuantileContState, + NumberType, + ArrayType, + >::try_create( + display_name, return_type, params, arguments[0].clone() + ) + .with_function_data(Box::new(QuantileContData { levels })); + + Ok(Arc::new(func)) } else { - DataType::Number(NumberDataType::Float64) - }; - AggregateQuantileContFunction::::try_create( - display_name, - return_type, - params, - arguments, - ) + let return_type = DataType::Number(NumberDataType::Float64); + let func = AggregateUnaryFunction::< + QuantileContState, + NumberType, + Float64Type, + >::try_create( + display_name, return_type, params, arguments[0].clone() + ) + .with_function_data(Box::new(QuantileContData { levels })); + + Ok(Arc::new(func)) + } } _ => Err(ErrorCode::BadDataValueType(format!( From de15246a19e9596a712bd44a0d124e9272d571b6 Mon Sep 17 00:00:00 2001 From: Yijun Zhao Date: Fri, 24 Nov 2023 11:15:36 +0800 Subject: [PATCH 3/5] add quantile disc --- .../src/aggregates/aggregate_quantile_cont.rs | 16 +- .../src/aggregates/aggregate_quantile_disc.rs | 416 +++++++----------- 2 files changed, 164 insertions(+), 268 deletions(-) diff --git a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs index e8fb4c4821e2f..8108f458d62bd 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_cont.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_cont.rs @@ -45,11 +45,11 @@ use crate::BUILTIN_FUNCTIONS; const MEDIAN: u8 = 0; const QUANTILE_CONT: u8 = 1; -struct QuantileContData { - pub levels: Vec, +pub(crate) struct QuantileData { + pub(crate) levels: Vec, } -impl FunctionData for QuantileContData { +impl FunctionData for QuantileData { fn as_any(&self) -> &dyn Any { self } @@ -101,7 +101,7 @@ where function_data .unwrap() .as_any() - .downcast_ref_unchecked::() + .downcast_ref_unchecked::() }; if quantile_cont_data.levels.len() > 1 { let indices = quantile_cont_data @@ -154,7 +154,7 @@ where } } -fn get_levels(params: &Vec) -> Result> { +pub(crate) fn get_levels(params: &Vec) -> Result> { let levels = if params.len() == 1 { let level: F64 = check_number( None, @@ -233,7 +233,8 @@ pub fn try_create_aggregate_quantile_cont_function( >::try_create( display_name, return_type, params, arguments[0].clone() ) - .with_function_data(Box::new(QuantileContData { levels })); + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); Ok(Arc::new(func)) } else { @@ -245,7 +246,8 @@ pub fn try_create_aggregate_quantile_cont_function( >::try_create( display_name, return_type, params, arguments[0].clone() ) - .with_function_data(Box::new(QuantileContData { levels })); + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); Ok(Arc::new(func)) } diff --git a/src/query/functions/src/aggregates/aggregate_quantile_disc.rs b/src/query/functions/src/aggregates/aggregate_quantile_disc.rs index 6ebf7cfc02650..2a0bf4938b7b6 100644 --- a/src/query/functions/src/aggregates/aggregate_quantile_disc.rs +++ b/src/query/functions/src/aggregates/aggregate_quantile_disc.rs @@ -12,24 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; -use std::fmt::Display; -use std::fmt::Formatter; -use std::marker::PhantomData; use std::sync::Arc; -use common_arrow::arrow::bitmap::Bitmap; use common_exception::ErrorCode; use common_exception::Result; -use common_expression::type_check::check_number; +use common_expression::types::array::ArrayColumnBuilder; use common_expression::types::decimal::*; use common_expression::types::number::*; use common_expression::types::*; use common_expression::with_number_mapped_type; -use common_expression::Column; -use common_expression::ColumnBuilder; -use common_expression::Expr; -use common_expression::FunctionContext; use common_expression::Scalar; use ethnum::i256; use serde::de::DeserializeOwned; @@ -37,24 +28,17 @@ use serde::Deserialize; use serde::Serialize; use super::deserialize_state; +use super::get_levels; use super::serialize_state; +use super::AggregateUnaryFunction; +use super::FunctionData; +use super::QuantileData; +use super::UnaryState; use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; use crate::aggregates::assert_unary_arguments; -use crate::aggregates::AggregateFunction; use crate::aggregates::AggregateFunctionRef; -use crate::aggregates::StateAddr; use crate::with_simple_no_number_mapped_type; -use crate::BUILTIN_FUNCTIONS; -pub trait QuantileStateFunc: - Serialize + DeserializeOwned + Send + Sync + 'static -{ - fn new() -> Self; - fn add(&mut self, other: T::ScalarRef<'_>); - fn add_batch(&mut self, column: &T::Column, validity: Option<&Bitmap>) -> Result<()>; - fn merge(&mut self, rhs: &Self) -> Result<()>; - fn merge_result(&mut self, builder: &mut ColumnBuilder, levels: Vec) -> Result<()>; -} #[derive(Serialize, Deserialize)] struct QuantileState where @@ -64,6 +48,7 @@ where #[serde(bound(deserialize = "T::Scalar: DeserializeOwned"))] pub value: Vec, } + impl Default for QuantileState where T: ValueType, @@ -73,39 +58,17 @@ where Self { value: vec![] } } } -impl QuantileStateFunc for QuantileState + +impl UnaryState> for QuantileState where - T: ValueType, - T::Scalar: Serialize + DeserializeOwned + Send + Sync + Ord, + T: ValueType + Sync + Send, + T::Scalar: Serialize + DeserializeOwned + Sync + Send + Ord, { - fn new() -> Self { - Self::default() - } - fn add(&mut self, other: T::ScalarRef<'_>) { + fn add(&mut self, other: T::ScalarRef<'_>) -> Result<()> { self.value.push(T::to_owned_scalar(other)); - } - fn add_batch(&mut self, column: &T::Column, validity: Option<&Bitmap>) -> Result<()> { - let column_len = T::column_len(column); - if column_len == 0 { - return Ok(()); - } - let column_iter = T::iter_column(column); - if let Some(validity) = validity { - if validity.unset_bits() == column_len { - return Ok(()); - } - for (data, valid) in column_iter.zip(validity.iter()) { - if !valid { - continue; - } - self.add(data.clone()); - } - } else { - self.value - .extend(column_iter.map(|data| T::to_owned_scalar(data))); - } Ok(()) } + fn merge(&mut self, rhs: &Self) -> Result<()> { self.value.extend( rhs.value @@ -114,14 +77,22 @@ where ); Ok(()) } - fn merge_result(&mut self, builder: &mut ColumnBuilder, levels: Vec) -> Result<()> { + + fn merge_result( + &mut self, + builder: &mut ArrayColumnBuilder, + function_data: Option<&dyn FunctionData>, + ) -> Result<()> { let value_len = self.value.len(); - if levels.len() > 1 { - let builder = match builder { - ColumnBuilder::Array(box b) => b, - _ => unreachable!(), - }; - let indices = levels + let quantile_disc_data = unsafe { + function_data + .unwrap() + .as_any() + .downcast_ref_unchecked::() + }; + if quantile_disc_data.levels.len() > 1 { + let indices = quantile_disc_data + .levels .iter() .map(|level| ((value_len - 1) as f64 * (*level)).floor() as usize) .collect::>(); @@ -129,200 +100,80 @@ where if idx < value_len { self.value.as_mut_slice().select_nth_unstable(idx); let value = self.value.get(idx).unwrap(); - builder.put_item(T::upcast_scalar(value.clone()).as_ref()); + builder.put_item(T::to_scalar_ref(value)); } else { builder.push_default(); } } builder.commit_row(); - } else { - let builder = T::try_downcast_builder(builder).unwrap(); - let idx = ((value_len - 1) as f64 * levels[0]).floor() as usize; - if idx >= value_len { - T::push_default(builder); - } else { - self.value.as_mut_slice().select_nth_unstable(idx); - let value = self.value.get(idx).unwrap(); - T::push_item(builder, T::to_scalar_ref(value)); - } } Ok(()) } -} -#[derive(Clone)] -pub struct AggregateQuantileDiscFunction { - display_name: String, - return_type: DataType, - levels: Vec, - _arguments: Vec, - _t: PhantomData, - _state: PhantomData, -} -impl Display for AggregateQuantileDiscFunction -where - State: QuantileStateFunc, - T: Send + Sync + ValueType, -{ - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.display_name) + + fn serialize(&self, writer: &mut Vec) -> Result<()> { + serialize_state(writer, self) + } + + fn deserialize(reader: &mut &[u8]) -> Result + where Self: Sized { + deserialize_state(reader) } } -impl AggregateFunction for AggregateQuantileDiscFunction + +impl UnaryState for QuantileState where - T: ValueType + Send + Sync, - State: QuantileStateFunc, + T: ArgType + Sync + Send, + T::Scalar: Serialize + DeserializeOwned + Sync + Send + Ord, { - fn name(&self) -> &str { - "AggregateQuantileDiscFunction" - } - fn return_type(&self) -> Result { - Ok(self.return_type.clone()) - } - fn init_state(&self, place: StateAddr) { - place.write(|| State::new()) - } - fn state_layout(&self) -> Layout { - Layout::new::() - } - fn accumulate( - &self, - place: StateAddr, - columns: &[Column], - validity: Option<&Bitmap>, - _input_rows: usize, - ) -> Result<()> { - let column = T::try_downcast_column(&columns[0]).unwrap(); - let state = place.get::(); - state.add_batch(&column, validity) - } - fn accumulate_row(&self, place: StateAddr, columns: &[Column], row: usize) -> Result<()> { - let column = T::try_downcast_column(&columns[0]).unwrap(); - let v = T::index_column(&column, row); - if let Some(v) = v { - let state = place.get::(); - state.add(v) - } + fn add(&mut self, other: T::ScalarRef<'_>) -> Result<()> { + self.value.push(T::to_owned_scalar(other)); Ok(()) } - fn accumulate_keys( - &self, - places: &[StateAddr], - offset: usize, - columns: &[Column], - _input_rows: usize, - ) -> Result<()> { - let column = T::try_downcast_column(&columns[0]).unwrap(); - let column_iter = T::iter_column(&column); - column_iter.zip(places.iter()).for_each(|(v, place)| { - let addr = place.next(offset); - let state = addr.get::(); - state.add(v.clone()) - }); + + fn merge(&mut self, rhs: &Self) -> Result<()> { + self.value.extend( + rhs.value + .iter() + .map(|v| T::to_owned_scalar(T::to_scalar_ref(v))), + ); Ok(()) } - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - serialize_state(writer, state) - } - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let rhs: State = deserialize_state(reader)?; - state.merge(&rhs) - } + fn merge_result( + &mut self, + builder: &mut T::ColumnBuilder, + function_data: Option<&dyn FunctionData>, + ) -> Result<()> { + let value_len = self.value.len(); + let quantile_disc_data = unsafe { + function_data + .unwrap() + .as_any() + .downcast_ref_unchecked::() + }; - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let state = place.get::(); - let other = rhs.get::(); - state.merge(other) - } + let idx = ((value_len - 1) as f64 * quantile_disc_data.levels[0]).floor() as usize; + if idx >= value_len { + T::push_default(builder); + } else { + self.value.as_mut_slice().select_nth_unstable(idx); + let value = self.value.get(idx).unwrap(); + T::push_item(builder, T::to_scalar_ref(value)); + } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { - let state = place.get::(); - state.merge_result(builder, self.levels.clone()) + Ok(()) } - fn need_manual_drop_state(&self) -> bool { - true + fn serialize(&self, writer: &mut Vec) -> Result<()> { + serialize_state(writer, self) } - unsafe fn drop_state(&self, place: StateAddr) { - let state = place.get::(); - std::ptr::drop_in_place(state); - } -} -impl AggregateQuantileDiscFunction -where - State: QuantileStateFunc, - T: Send + Sync + ValueType, -{ - fn try_create( - display_name: &str, - return_type: DataType, - params: Vec, - arguments: Vec, - ) -> Result> { - let levels = if params.len() == 1 { - let level: F64 = check_number( - None, - &FunctionContext::default(), - &Expr::::Constant { - span: None, - scalar: params[0].clone(), - data_type: params[0].as_ref().infer_data_type(), - }, - &BUILTIN_FUNCTIONS, - )?; - let level = level.0; - if !(0.0..=1.0).contains(&level) { - return Err(ErrorCode::BadDataValueType(format!( - "level range between [0, 1], got: {:?}", - level - ))); - } - vec![level] - } else if params.is_empty() { - vec![0.5f64] - } else { - let mut levels = Vec::with_capacity(params.len()); - for param in params { - let level: F64 = check_number( - None, - &FunctionContext::default(), - &Expr::::Cast { - span: None, - is_try: false, - expr: Box::new(Expr::Constant { - span: None, - scalar: param.clone(), - data_type: param.as_ref().infer_data_type(), - }), - dest_type: DataType::Number(NumberDataType::Float64), - }, - &BUILTIN_FUNCTIONS, - )?; - let level = level.0; - if !(0.0..=1.0).contains(&level) { - return Err(ErrorCode::BadDataValueType(format!( - "level range between [0, 1], got: {:?} in levels", - level - ))); - } - levels.push(level); - } - levels - }; - let func = AggregateQuantileDiscFunction:: { - display_name: display_name.to_string(), - return_type, - levels, - _arguments: arguments, - _t: PhantomData, - _state: PhantomData, - }; - Ok(Arc::new(func)) + fn deserialize(reader: &mut &[u8]) -> Result + where Self: Sized { + deserialize_state(reader) } } + pub fn try_create_aggregate_quantile_disc_function( display_name: &str, params: Vec, @@ -330,22 +181,37 @@ pub fn try_create_aggregate_quantile_disc_function( ) -> Result { assert_unary_arguments(display_name, arguments.len())?; let data_type = arguments[0].clone(); + let levels = get_levels(¶ms)?; with_simple_no_number_mapped_type!(|T| match data_type { DataType::Number(num_type) => { - with_number_mapped_type!(|NUM| match num_type { - NumberDataType::NUM => { - type State = QuantileState>; - let return_type = if params.len() > 1 { - DataType::Array(Box::new(data_type)) + with_number_mapped_type!(|NUM_TYPE| match num_type { + NumberDataType::NUM_TYPE => { + if params.len() > 1 { + let func = AggregateUnaryFunction::< + QuantileState>, + NumberType, + ArrayType>, + >::try_create( + display_name, + DataType::Array(Box::new(data_type)), + params, + arguments[0].clone(), + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) } else { - data_type - }; - AggregateQuantileDiscFunction::, State>::try_create( - display_name, - return_type, - params, - arguments, - ) + let func = AggregateUnaryFunction::< + QuantileState>, + NumberType, + NumberType, + >::try_create( + display_name, data_type, params, arguments[0].clone() + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) + } } }) } @@ -355,18 +221,32 @@ pub fn try_create_aggregate_quantile_disc_function( scale: s.scale, }; let data_type = DataType::Decimal(DecimalDataType::from_size(decimal_size)?); - let return_type = if params.len() > 1 { - DataType::Array(Box::new(data_type)) + if params.len() > 1 { + let func = AggregateUnaryFunction::< + QuantileState>, + DecimalType, + ArrayType>, + >::try_create( + display_name, + DataType::Array(Box::new(data_type)), + params, + arguments[0].clone(), + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) } else { - data_type - }; - type State = QuantileState>; - AggregateQuantileDiscFunction::, State>::try_create( - display_name, - return_type, - params, - arguments, - ) + let func = AggregateUnaryFunction::< + QuantileState>, + DecimalType, + DecimalType, + >::try_create( + display_name, data_type, params, arguments[0].clone() + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) + } } DataType::Decimal(DecimalDataType::Decimal256(s)) => { let decimal_size = DecimalSize { @@ -374,18 +254,32 @@ pub fn try_create_aggregate_quantile_disc_function( scale: s.scale, }; let data_type = DataType::Decimal(DecimalDataType::from_size(decimal_size)?); - let return_type = if params.len() > 1 { - DataType::Array(Box::new(data_type)) + if params.len() > 1 { + let func = AggregateUnaryFunction::< + QuantileState>, + DecimalType, + ArrayType>, + >::try_create( + display_name, + DataType::Array(Box::new(data_type)), + params, + arguments[0].clone(), + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) } else { - data_type - }; - type State = QuantileState>; - AggregateQuantileDiscFunction::, State>::try_create( - display_name, - return_type, - params, - arguments, - ) + let func = AggregateUnaryFunction::< + QuantileState>, + DecimalType, + DecimalType, + >::try_create( + display_name, data_type, params, arguments[0].clone() + ) + .with_function_data(Box::new(QuantileData { levels })) + .with_need_drop(true); + Ok(Arc::new(func)) + } } _ => Err(ErrorCode::BadDataValueType(format!( "{} does not support type '{:?}'", From 5074a760c4c9db5adc38688afe825dcad1fd4348 Mon Sep 17 00:00:00 2001 From: Yijun Zhao Date: Fri, 24 Nov 2023 15:06:41 +0800 Subject: [PATCH 4/5] add stddev --- .../src/aggregates/aggregate_stddev.rs | 181 ++++-------------- 1 file changed, 36 insertions(+), 145 deletions(-) diff --git a/src/query/functions/src/aggregates/aggregate_stddev.rs b/src/query/functions/src/aggregates/aggregate_stddev.rs index 98249fc686735..36aac989279db 100644 --- a/src/query/functions/src/aggregates/aggregate_stddev.rs +++ b/src/query/functions/src/aggregates/aggregate_stddev.rs @@ -12,23 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; -use std::fmt; -use std::marker::PhantomData; use std::sync::Arc; -use common_arrow::arrow::bitmap::Bitmap; use common_exception::ErrorCode; use common_exception::Result; use common_expression::types::number::Number; use common_expression::types::number::F64; use common_expression::types::DataType; +use common_expression::types::Float64Type; use common_expression::types::NumberDataType; use common_expression::types::NumberType; use common_expression::types::ValueType; use common_expression::with_number_mapped_type; -use common_expression::Column; -use common_expression::ColumnBuilder; use common_expression::Scalar; use num_traits::AsPrimitive; use serde::Deserialize; @@ -36,43 +31,48 @@ use serde::Serialize; use super::deserialize_state; use super::serialize_state; -use super::StateAddr; +use super::AggregateUnaryFunction; +use super::FunctionData; +use super::UnaryState; use crate::aggregates::aggregate_function_factory::AggregateFunctionDescription; use crate::aggregates::aggregator_common::assert_unary_arguments; use crate::aggregates::AggregateFunction; -use crate::aggregates::AggregateFunctionRef; const POP: u8 = 0; const SAMP: u8 = 1; -#[derive(Serialize, Deserialize)] -struct AggregateStddevState { +#[derive(Default, Serialize, Deserialize)] +struct AggregateStddevState { pub sum: f64, pub count: u64, pub variance: f64, } -impl AggregateStddevState { - #[inline(always)] - fn add(&mut self, value: f64) { +impl UnaryState for AggregateStddevState +where + T: ValueType, + T::Scalar: Number + AsPrimitive, +{ + fn add(&mut self, other: T::ScalarRef<'_>) -> Result<()> { + let value = T::to_owned_scalar(other).as_(); self.sum += value; self.count += 1; if self.count > 1 { let t = self.count as f64 * value - self.sum; self.variance += (t * t) / (self.count * (self.count - 1)) as f64; } + Ok(()) } - #[inline(always)] - fn merge(&mut self, other: &Self) { + fn merge(&mut self, other: &Self) -> Result<()> { if other.count == 0 { - return; + return Ok(()); } if self.count == 0 { self.count = other.count; self.sum = other.sum; self.variance = other.variance; - return; + return Ok(()); } let t = (other.count as f64 / self.count as f64) * self.sum - other.sum; @@ -82,156 +82,47 @@ impl AggregateStddevState { * t; self.count += other.count; self.sum += other.sum; - } -} - -#[derive(Clone)] -pub struct AggregateStddevFunction { - display_name: String, - _arguments: Vec, - t: PhantomData, -} - -impl AggregateFunction for AggregateStddevFunction -where T: Number + AsPrimitive -{ - fn name(&self) -> &str { - "AggregateStddevPopFunction" - } - - fn return_type(&self) -> Result { - Ok(DataType::Number(NumberDataType::Float64)) - } - - fn init_state(&self, place: StateAddr) { - place.write(|| AggregateStddevState { - sum: 0.0, - count: 0, - variance: 0.0, - }); - } - - fn state_layout(&self) -> Layout { - Layout::new::() - } - - fn accumulate( - &self, - place: StateAddr, - columns: &[Column], - validity: Option<&Bitmap>, - _input_rows: usize, - ) -> Result<()> { - let state = place.get::(); - let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); - match validity { - Some(bitmap) => { - for (value, is_valid) in column.iter().zip(bitmap.iter()) { - if is_valid { - state.add(value.as_()); - } - } - } - None => { - for value in column.iter() { - state.add(value.as_()); - } - } - } - Ok(()) } - fn accumulate_keys( - &self, - places: &[StateAddr], - offset: usize, - columns: &[Column], - _input_rows: usize, + fn merge_result( + &mut self, + builder: &mut Vec, + _function_data: Option<&dyn FunctionData>, ) -> Result<()> { - let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); - - column.iter().zip(places.iter()).for_each(|(value, place)| { - let place = place.next(offset); - let state = place.get::(); - let v: f64 = value.as_(); - state.add(v); - }); - Ok(()) - } - - fn accumulate_row(&self, place: StateAddr, columns: &[Column], row: usize) -> Result<()> { - let column = NumberType::::try_downcast_column(&columns[0]).unwrap(); - - let state = place.get::(); - let v: f64 = column[row].as_(); - state.add(v); - Ok(()) - } - - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { - let state = place.get::(); - serialize_state(writer, state) - } - - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - let state = place.get::(); - let rhs: AggregateStddevState = deserialize_state(reader)?; - state.merge(&rhs); - Ok(()) - } - - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let state = place.get::(); - let other = rhs.get::(); - state.merge(other); - Ok(()) - } - - #[allow(unused_mut)] - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { - let state = place.get::(); - let builder = NumberType::::try_downcast_builder(builder).unwrap(); - let variance = state.variance / (state.count - TYPE as u64) as f64; + let variance = self.variance / (self.count - TYPE as u64) as f64; builder.push(variance.sqrt().into()); Ok(()) } -} -impl fmt::Display for AggregateStddevFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.display_name) + fn serialize(&self, writer: &mut Vec) -> Result<()> { + serialize_state(writer, self) } -} -impl AggregateStddevFunction -where T: Number + AsPrimitive -{ - pub fn try_create( - display_name: &str, - arguments: Vec, - ) -> Result { - Ok(Arc::new(Self { - display_name: display_name.to_string(), - _arguments: arguments, - t: PhantomData, - })) + fn deserialize(reader: &mut &[u8]) -> Result + where Self: Sized { + deserialize_state(reader) } } pub fn try_create_aggregate_stddev_pop_function( display_name: &str, - _params: Vec, + params: Vec, arguments: Vec, ) -> Result> { assert_unary_arguments(display_name, arguments.len())?; with_number_mapped_type!(|NUM_TYPE| match &arguments[0] { DataType::Number(NumberDataType::NUM_TYPE) => { - AggregateStddevFunction::::try_create(display_name, arguments) + let return_type = DataType::Number(NumberDataType::Float64); + AggregateUnaryFunction::< + AggregateStddevState, + NumberType, + Float64Type, + >::try_create_unary(display_name, return_type, params, arguments[0].clone()) } _ => Err(ErrorCode::BadDataValueType(format!( - "AggregateStddevPopFunction does not support type '{:?}'", - arguments[0] + "{} does not support type '{:?}'", + display_name, arguments[0] ))), }) } From b9a02c15ca31e34b994e89d2025130215165f16f Mon Sep 17 00:00:00 2001 From: Yijun Zhao Date: Fri, 24 Nov 2023 20:39:48 +0800 Subject: [PATCH 5/5] add approx count distinct --- .../aggregate_approx_count_distinct.rs | 218 ++++++++---------- 1 file changed, 96 insertions(+), 122 deletions(-) diff --git a/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs b/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs index dce39319557d0..6269602a4bbbd 100644 --- a/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs +++ b/src/query/functions/src/aggregates/aggregate_approx_count_distinct.rs @@ -12,13 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::alloc::Layout; -use std::fmt; use std::hash::Hash; -use std::marker::PhantomData; use std::sync::Arc; -use common_arrow::arrow::bitmap::Bitmap; use common_exception::Result; use common_expression::types::AnyType; use common_expression::types::DataType; @@ -27,10 +23,9 @@ use common_expression::types::NumberDataType; use common_expression::types::NumberType; use common_expression::types::StringType; use common_expression::types::TimestampType; +use common_expression::types::UInt64Type; use common_expression::types::ValueType; use common_expression::with_number_mapped_type; -use common_expression::Column; -use common_expression::ColumnBuilder; use common_expression::Scalar; use streaming_algorithms::HyperLogLog; @@ -38,150 +33,135 @@ use super::aggregate_function::AggregateFunction; use super::aggregate_function_factory::AggregateFunctionDescription; use super::deserialize_state; use super::serialize_state; -use super::AggregateFunctionRef; -use super::StateAddr; +use super::AggregateUnaryFunction; +use super::FunctionData; +use super::UnaryState; use crate::aggregates::aggregator_common::assert_unary_arguments; /// Use Hyperloglog to estimate distinct of values -pub struct AggregateApproxCountDistinctState { - hll: HyperLogLog, -} - -/// S: ScalarType -#[derive(Clone)] -pub struct AggregateApproxCountDistinctFunction { - display_name: String, - _t: PhantomData, -} - -impl AggregateApproxCountDistinctFunction -where for<'a> T::ScalarRef<'a>: Hash +struct AggregateApproxCountDistinctState +where T: ValueType { - pub fn try_create( - display_name: &str, - _arguments: Vec, - ) -> Result { - Ok(Arc::new(Self { - display_name: display_name.to_string(), - _t: PhantomData, - })) - } + hll: HyperLogLog, } -impl AggregateFunction for AggregateApproxCountDistinctFunction -where for<'a> T::ScalarRef<'a>: Hash +impl Default for AggregateApproxCountDistinctState +where + T: ValueType + Send + Sync, + T::Scalar: Hash, { - fn name(&self) -> &str { - "AggregateApproxCountDistinctFunction" - } - - fn return_type(&self) -> Result { - Ok(DataType::Number(NumberDataType::UInt64)) - } - - fn init_state(&self, place: StateAddr) { - place.write(|| AggregateApproxCountDistinctState { - hll: HyperLogLog::>::new(0.04), - }); - } - - fn state_layout(&self) -> Layout { - Layout::new::>>() - } - - fn accumulate( - &self, - place: StateAddr, - columns: &[Column], - validity: Option<&Bitmap>, - _input_rows: usize, - ) -> Result<()> { - let state = place.get::>>(); - let column = T::try_downcast_column(&columns[0]).unwrap(); - - if let Some(validity) = validity { - T::iter_column(&column) - .zip(validity.iter()) - .for_each(|(t, b)| { - if b { - state.hll.push(&t); - } - }); - } else { - T::iter_column(&column).for_each(|t| { - state.hll.push(&t); - }); + fn default() -> Self { + Self { + hll: HyperLogLog::::new(0.04), } - Ok(()) - } - - fn accumulate_row(&self, place: StateAddr, columns: &[Column], row: usize) -> Result<()> { - let state = place.get::>>(); - let column = T::try_downcast_column(&columns[0]).unwrap(); - state.hll.push(&T::index_column(&column, row).unwrap()); - Ok(()) } +} - fn serialize(&self, place: StateAddr, writer: &mut Vec) -> Result<()> { - let state = place.get::>>(); - serialize_state(writer, &state.hll) - } - - fn merge(&self, place: StateAddr, reader: &mut &[u8]) -> Result<()> { - let state = place.get::>>(); - let hll: HyperLogLog> = deserialize_state(reader)?; - state.hll.union(&hll); - +impl UnaryState for AggregateApproxCountDistinctState +where + T: ValueType + Send + Sync, + T::Scalar: Hash, +{ + fn add(&mut self, other: T::ScalarRef<'_>) -> Result<()> { + self.hll.push(&T::to_owned_scalar(other)); Ok(()) } - fn merge_states(&self, place: StateAddr, rhs: StateAddr) -> Result<()> { - let state = place.get::>>(); - let other = rhs.get::>>(); - state.hll.union(&other.hll); + fn merge(&mut self, rhs: &Self) -> Result<()> { + self.hll.union(&rhs.hll); Ok(()) } - fn merge_result(&self, place: StateAddr, builder: &mut ColumnBuilder) -> Result<()> { - let state = place.get::>>(); - let builder = NumberType::::try_downcast_builder(builder).unwrap(); - builder.push(state.hll.len() as u64); + fn merge_result( + &mut self, + builder: &mut Vec, + _function_data: Option<&dyn FunctionData>, + ) -> Result<()> { + builder.push(self.hll.len() as u64); Ok(()) } - fn need_manual_drop_state(&self) -> bool { - true + fn serialize(&self, writer: &mut Vec) -> Result<()> { + serialize_state(writer, &self.hll) } - unsafe fn drop_state(&self, place: StateAddr) { - let state = place.get::>>(); - std::ptr::drop_in_place(state); + fn deserialize(reader: &mut &[u8]) -> Result + where Self: Sized { + let hll = deserialize_state(reader)?; + Ok(Self { hll }) } } pub fn try_create_aggregate_approx_count_distinct_function( display_name: &str, - _params: Vec, + params: Vec, arguments: Vec, ) -> Result> { assert_unary_arguments(display_name, arguments.len())?; + let return_type = DataType::Number(NumberDataType::UInt64); + with_number_mapped_type!(|NUM_TYPE| match &arguments[0] { DataType::Number(NumberDataType::NUM_TYPE) => { - AggregateApproxCountDistinctFunction::>::try_create( - display_name, - arguments, + let func = AggregateUnaryFunction::< + AggregateApproxCountDistinctState>, + NumberType, + UInt64Type, + >::try_create( + display_name, return_type, params, arguments[0].clone() + ) + .with_need_drop(true); + + Ok(Arc::new(func)) + } + DataType::String => { + let func = AggregateUnaryFunction::< + AggregateApproxCountDistinctState, + StringType, + UInt64Type, + >::try_create( + display_name, return_type, params, arguments[0].clone() + ) + .with_need_drop(true); + + Ok(Arc::new(func)) + } + DataType::Date => { + let func = AggregateUnaryFunction::< + AggregateApproxCountDistinctState, + DateType, + UInt64Type, + >::try_create( + display_name, return_type, params, arguments[0].clone() ) + .with_need_drop(true); + + Ok(Arc::new(func)) + } + DataType::Timestamp => { + let func = AggregateUnaryFunction::< + AggregateApproxCountDistinctState, + TimestampType, + UInt64Type, + >::try_create( + display_name, return_type, params, arguments[0].clone() + ) + .with_need_drop(true); + + Ok(Arc::new(func)) + } + _ => { + let func = AggregateUnaryFunction::< + AggregateApproxCountDistinctState, + AnyType, + UInt64Type, + >::try_create( + display_name, return_type, params, arguments[0].clone() + ) + .with_need_drop(true); + + Ok(Arc::new(func)) } - DataType::String => - AggregateApproxCountDistinctFunction::::try_create(display_name, arguments,), - DataType::Date => - AggregateApproxCountDistinctFunction::::try_create(display_name, arguments,), - DataType::Timestamp => AggregateApproxCountDistinctFunction::::try_create( - display_name, - arguments, - ), - _ => AggregateApproxCountDistinctFunction::::try_create(display_name, arguments,), }) } @@ -196,9 +176,3 @@ pub fn aggregate_approx_count_distinct_function_desc() -> AggregateFunctionDescr features, ) } - -impl fmt::Display for AggregateApproxCountDistinctFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.display_name) - } -}