diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 5e609f125bc86..b3dbef7dfdf53 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -58,10 +58,9 @@ pub fn create_aggregate_expr( return_type, )), (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( - input_phy_types, - input_phy_exprs, + input_phy_types[0].clone(), + input_phy_exprs[0].clone(), name, - return_type, )), (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 0f3c4c5b4d005..8fe6758ef5241 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -31,35 +31,30 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; #[derive(Debug, PartialEq, Eq, Hash, Clone)] -struct DistinctScalarValues(Vec); +struct DistinctScalarValues(ScalarValue); /// Expression for a COUNT(DISTINCT) aggregation. #[derive(Debug)] pub struct DistinctCount { /// Column name name: String, - /// The DataType for the final count - data_type: DataType, /// The DataType used to hold the state for each input - state_data_types: Vec, + state_data_types: DataType, /// The input arguments - exprs: Vec>, + exprs: Arc, } impl DistinctCount { /// Create a new COUNT(DISTINCT) aggregate function. pub fn new( - input_data_types: Vec, - exprs: Vec>, + input_data_types: DataType, + exprs: Arc, name: String, - data_type: DataType, ) -> Self { - let state_data_types = input_data_types; - + // let state_data_types = input_data_types[0].clone(); Self { name, - data_type, - state_data_types, + state_data_types: input_data_types, exprs, } } @@ -72,36 +67,30 @@ impl AggregateExpr for DistinctCount { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + Ok(Field::new(&self.name, DataType::Int64, true)) } fn state_fields(&self) -> Result> { - Ok(self - .state_data_types - .iter() - .map(|state_data_type| { - Field::new( - format_state_name(&self.name, "count distinct"), - DataType::List(Box::new(Field::new( - "item", - state_data_type.clone(), - true, - ))), - false, - ) - }) - .collect::>()) + Ok(vec![Field::new( + format_state_name(&self.name, "count distinct"), + DataType::List(Box::new(Field::new( + "item", + self.state_data_types.clone(), + true, + ))), + false, + )]) } fn expressions(&self) -> Vec> { - self.exprs.clone() + vec![self.exprs.clone()] } fn create_accumulator(&self) -> Result> { Ok(Box::new(DistinctCountAccumulator { values: HashSet::default(), state_data_types: self.state_data_types.clone(), - count_data_type: self.data_type.clone(), + count_data_type: DataType::Int64, })) } @@ -113,106 +102,61 @@ impl AggregateExpr for DistinctCount { #[derive(Debug)] struct DistinctCountAccumulator { values: HashSet, - state_data_types: Vec, + state_data_types: DataType, count_data_type: DataType, } -impl DistinctCountAccumulator { - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - // If a row has a NULL, it is not included in the final count. - if !values.iter().any(|v| v.is_null()) { - self.values.insert(DistinctScalarValues(values.to_vec())); - } - - Ok(()) - } - - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - let col_values = states +impl Accumulator for DistinctCountAccumulator { + fn state(&self) -> Result> { + let mut cols_out = + ScalarValue::new_list(Some(Vec::new()), self.state_data_types.clone()); + self.values .iter() - .map(|state| match state { - ScalarValue::List(Some(values), _) => Ok(values), - _ => Err(DataFusionError::Internal(format!( - "Unexpected accumulator state {state:?}" - ))), - }) - .collect::>>()?; - - (0..col_values[0].len()).try_for_each(|row_index| { - let row_values = col_values - .iter() - .map(|col| col[row_index].clone()) - .collect::>(); - self.update(&row_values) - }) + .enumerate() + .for_each(|(_, distinct_values)| { + if let ScalarValue::List(Some(ref mut v), _) = cols_out { + v.push(distinct_values.0.clone()); + } + }); + Ok(vec![cols_out]) } -} - -impl Accumulator for DistinctCountAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { return Ok(()); } - (0..values[0].len()).try_for_each(|index| { - let v = values - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::>>()?; - self.update(&v) + let arr = &values[0]; + (0..arr.len()).try_for_each(|index| { + if !arr.is_null(index) { + let scalar = ScalarValue::try_from_array(arr, index)?; + self.values.insert(DistinctScalarValues(scalar)); + } + Ok(()) }) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); } - (0..states[0].len()).try_for_each(|index| { - let v = states - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::>>()?; - self.merge(&v) + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { + let scalar = ScalarValue::try_from_array(arr, index)?; + + if let ScalarValue::List(Some(scalar), _) = scalar { + scalar.iter().for_each(|scalar| { + if !ScalarValue::is_null(scalar) { + self.values.insert(DistinctScalarValues(scalar.clone())); + } + }); + } else { + return Err(DataFusionError::Internal( + "Unexpected accumulator state".into(), + )); + } + Ok(()) }) } - fn state(&self) -> Result> { - let mut cols_out = self - .state_data_types - .iter() - .map(|state_data_type| { - ScalarValue::new_list(Some(Vec::new()), state_data_type.clone()) - }) - .collect::>(); - - let mut cols_vec = cols_out - .iter_mut() - .map(|c| match c { - ScalarValue::List(Some(ref mut v), _) => Ok(v), - t => Err(DataFusionError::Internal(format!( - "cols_out should only consist of ScalarValue::List. {t:?} is found" - ))), - }) - .collect::>>()?; - - self.values.iter().for_each(|distinct_values| { - distinct_values.0.iter().enumerate().for_each( - |(col_index, distinct_value)| { - cols_vec[col_index].push(distinct_value.clone()); - }, - ) - }); - - Ok(cols_out.into_iter().collect()) - } fn evaluate(&self) -> Result { - match &self.count_data_type { - DataType::Int64 => Ok(ScalarValue::Int64(Some(self.values.len() as i64))), - t => Err(DataFusionError::Internal(format!( - "Invalid data type {t:?} for count distinct aggregation" - ))), - } + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) } fn size(&self) -> usize { @@ -221,16 +165,11 @@ impl Accumulator for DistinctCountAccumulator { + self .values .iter() - .map(|vals| { - ScalarValue::size_of_vec(&vals.0) - std::mem::size_of_val(&vals.0) - }) - .sum::() - + (std::mem::size_of::() * self.state_data_types.capacity()) - + self - .state_data_types - .iter() - .map(|dt| dt.size() - std::mem::size_of_val(dt)) + .map(|vals| (vals.0.size()) - std::mem::size_of_val(&vals.0)) .sum::() + + std::mem::size_of::() + + self.state_data_types.size() + - std::mem::size_of_val(&self.state_data_types) + self.count_data_type.size() - std::mem::size_of_val(&self.count_data_type) } @@ -238,14 +177,14 @@ impl Accumulator for DistinctCountAccumulator { #[cfg(test)] mod tests { + use crate::expressions::NoOp; + use super::*; use arrow::array::{ ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; - use arrow::array::{Int32Builder, ListBuilder, UInt64Builder}; use arrow::datatypes::DataType; - use datafusion_common::cast::as_list_array; macro_rules! state_to_vec { ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ @@ -275,31 +214,6 @@ mod tests { }}; } - macro_rules! build_list { - ($LISTS:expr, $BUILDER_TYPE:ident) => {{ - let mut builder = ListBuilder::new($BUILDER_TYPE::with_capacity(0)); - for list in $LISTS.iter() { - match list { - Some(values) => { - for value in values.iter() { - match value { - Some(v) => builder.values().append_value((*v).into()), - None => builder.values().append_null(), - } - } - - builder.append(true); - } - None => { - builder.append(false); - } - } - } - - Arc::new(builder.finish()) as ArrayRef - }}; - } - macro_rules! test_count_distinct_update_batch_numeric { ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ let values: Vec> = vec![ @@ -330,28 +244,11 @@ mod tests { }}; } - fn collect_states( - state1: &[Option], - state2: &[Option], - ) -> Vec<(Option, Option)> { - let mut states = state1 - .iter() - .zip(state2.iter()) - .map(|(l, r)| (l.clone(), r.clone())) - .collect::, Option)>>(); - states.sort(); - states - } - fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { let agg = DistinctCount::new( - arrays - .iter() - .map(|a| a.data_type().clone()) - .collect::>(), - vec![], + arrays[0].data_type().clone(), + Arc::new(NoOp::new()), String::from("__col_name__"), - DataType::Int64, ); let mut accum = agg.create_accumulator()?; @@ -365,10 +262,9 @@ mod tests { rows: &[Vec], ) -> Result<(Vec, ScalarValue)> { let agg = DistinctCount::new( - data_types.to_vec(), - vec![], + data_types[0].clone(), + Arc::new(NoOp::new()), String::from("__col_name__"), - DataType::Int64, ); let mut accum = agg.create_accumulator()?; @@ -391,24 +287,6 @@ mod tests { Ok((accum.state()?, accum.evaluate()?)) } - fn run_merge_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - arrays - .iter() - .map(|a| as_list_array(a).unwrap()) - .map(|a| a.values().data_type().clone()) - .collect::>(), - vec![], - String::from("__col_name__"), - DataType::Int64, - ); - - let mut accum = agg.create_accumulator()?; - accum.merge_batch(arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - // Used trait to create associated constant for f32 and f64 trait SubNormal: 'static { const SUBNORMAL: Self; @@ -610,133 +488,75 @@ mod tests { Ok(()) } - #[test] - fn count_distinct_update_batch_multiple_columns() -> Result<()> { - let array_int8: ArrayRef = Arc::new(Int8Array::from(vec![1, 1, 2])); - let array_int16: ArrayRef = Arc::new(Int16Array::from(vec![3, 3, 4])); - let arrays = vec![array_int8, array_int16]; - - let (states, result) = run_update_batch(&arrays)?; - - let state_vec1 = state_to_vec!(&states[0], Int8, i8).unwrap(); - let state_vec2 = state_to_vec!(&states[1], Int16, i16).unwrap(); - let state_pairs = collect_states::(&state_vec1, &state_vec2); - - assert_eq!(states.len(), 2); - assert_eq!( - state_pairs, - vec![(Some(1_i8), Some(3_i16)), (Some(2_i8), Some(4_i16))] - ); - - assert_eq!(result, ScalarValue::Int64(Some(2))); - - Ok(()) - } - #[test] fn count_distinct_update() -> Result<()> { let (states, result) = run_update( - &[DataType::Int32, DataType::UInt64], + &[DataType::Int32], &[ - vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))], - vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))], - vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))], - vec![ScalarValue::Int32(Some(5)), ScalarValue::UInt64(Some(1))], - vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(6))], - vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(7))], - vec![ScalarValue::Int32(Some(2)), ScalarValue::UInt64(Some(7))], + vec![ScalarValue::Int32(Some(-1))], + vec![ScalarValue::Int32(Some(5))], + vec![ScalarValue::Int32(Some(-1))], + vec![ScalarValue::Int32(Some(5))], + vec![ScalarValue::Int32(Some(-1))], + vec![ScalarValue::Int32(Some(-1))], + vec![ScalarValue::Int32(Some(2))], ], )?; + assert_eq!(states.len(), 1); + assert_eq!(result, ScalarValue::Int64(Some(3))); - let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap(); - let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap(); - let state_pairs = collect_states::(&state_vec1, &state_vec2); - - assert_eq!(states.len(), 2); - assert_eq!( - state_pairs, - vec![ - (Some(-1_i32), Some(5_u64)), - (Some(-1_i32), Some(6_u64)), - (Some(-1_i32), Some(7_u64)), - (Some(2_i32), Some(7_u64)), - (Some(5_i32), Some(1_u64)), - ] - ); - assert_eq!(result, ScalarValue::Int64(Some(5))); - + let (states, result) = run_update( + &[DataType::UInt64], + &[ + vec![ScalarValue::UInt64(Some(1))], + vec![ScalarValue::UInt64(Some(5))], + vec![ScalarValue::UInt64(Some(1))], + vec![ScalarValue::UInt64(Some(5))], + vec![ScalarValue::UInt64(Some(1))], + vec![ScalarValue::UInt64(Some(1))], + vec![ScalarValue::UInt64(Some(2))], + ], + )?; + assert_eq!(states.len(), 1); + assert_eq!(result, ScalarValue::Int64(Some(3))); Ok(()) } #[test] fn count_distinct_update_with_nulls() -> Result<()> { let (states, result) = run_update( - &[DataType::Int32, DataType::UInt64], + &[DataType::Int32], &[ // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))], - vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(Some(5))], - vec![ScalarValue::Int32(Some(-2)), ScalarValue::UInt64(Some(5))], + vec![ScalarValue::Int32(Some(-1))], + vec![ScalarValue::Int32(Some(-1))], + vec![ScalarValue::Int32(Some(-2))], // Each of these updates contains at least one None, so these // won't be accumulated. - vec![ScalarValue::Int32(Some(-1)), ScalarValue::UInt64(None)], - vec![ScalarValue::Int32(None), ScalarValue::UInt64(Some(5))], - vec![ScalarValue::Int32(None), ScalarValue::UInt64(None)], + vec![ScalarValue::Int32(Some(-1))], + vec![ScalarValue::Int32(None)], + vec![ScalarValue::Int32(None)], ], )?; - - let state_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap(); - let state_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap(); - let state_pairs = collect_states::(&state_vec1, &state_vec2); - - assert_eq!(states.len(), 2); - assert_eq!( - state_pairs, - vec![(Some(-2_i32), Some(5_u64)), (Some(-1_i32), Some(5_u64))] - ); - + assert_eq!(states.len(), 1); assert_eq!(result, ScalarValue::Int64(Some(2))); - Ok(()) - } - - #[test] - fn count_distinct_merge_batch() -> Result<()> { - let state_in1 = build_list!( - vec![ - Some(vec![Some(-1_i32), Some(-1_i32), Some(-2_i32), Some(-2_i32)]), - Some(vec![Some(-2_i32), Some(-3_i32)]), - ], - Int32Builder - ); - - let state_in2 = build_list!( - vec![ - Some(vec![Some(5_u64), Some(6_u64), Some(5_u64), Some(7_u64)]), - Some(vec![Some(5_u64), Some(7_u64)]), + let (states, result) = run_update( + &[DataType::UInt64], + &[ + // None of these updates contains a None, so these are accumulated. + vec![ScalarValue::UInt64(Some(1))], + vec![ScalarValue::UInt64(Some(1))], + vec![ScalarValue::UInt64(Some(2))], + // Each of these updates contains at least one None, so these + // won't be accumulated. + vec![ScalarValue::UInt64(Some(1))], + vec![ScalarValue::UInt64(None)], + vec![ScalarValue::UInt64(None)], ], - UInt64Builder - ); - - let (states, result) = run_merge_batch(&[state_in1, state_in2])?; - - let state_out_vec1 = state_to_vec!(&states[0], Int32, i32).unwrap(); - let state_out_vec2 = state_to_vec!(&states[1], UInt64, u64).unwrap(); - let state_pairs = collect_states::(&state_out_vec1, &state_out_vec2); - - assert_eq!( - state_pairs, - vec![ - (Some(-3_i32), Some(7_u64)), - (Some(-2_i32), Some(5_u64)), - (Some(-2_i32), Some(7_u64)), - (Some(-1_i32), Some(5_u64)), - (Some(-1_i32), Some(6_u64)), - ] - ); - - assert_eq!(result, ScalarValue::Int64(Some(5))); - + )?; + assert_eq!(states.len(), 1); + assert_eq!(result, ScalarValue::Int64(Some(2))); Ok(()) } }