diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs new file mode 100644 index 000000000000..dfe19e0eb4c2 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -0,0 +1,308 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Implementations for DISTINCT expressions, e.g. `ARRAY_AGG(DISTINCT c)` + +use super::*; +use arrow::datatypes::{DataType, Field}; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef}; +use std::collections::HashSet; + +use crate::{AggregateExpr, PhysicalExpr}; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::Accumulator; + +/// Expression for an ARRAY_AGG(DISTINCT) aggregation. 
+#[derive(Debug)] +pub struct DistinctArrayAgg { + /// Column name + name: String, + /// The DataType for the input expression + input_data_type: DataType, + /// The input expression + expr: Arc, +} + +impl DistinctArrayAgg { + /// Create a new DistinctArrayAgg aggregate function + pub fn new( + expr: Arc, + name: impl Into, + input_data_type: DataType, + ) -> Self { + let name = name.into(); + Self { + name, + expr, + input_data_type, + } + } +} + +impl AggregateExpr for DistinctArrayAgg { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new( + &self.name, + DataType::List(Box::new(Field::new( + "item", + self.input_data_type.clone(), + true, + ))), + false, + )) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(DistinctArrayAggAccumulator::try_new( + &self.input_data_type, + )?)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new( + &format_state_name(&self.name, "distinct_array_agg"), + DataType::List(Box::new(Field::new( + "item", + self.input_data_type.clone(), + true, + ))), + false, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +#[derive(Debug)] +struct DistinctArrayAggAccumulator { + values: HashSet, + datatype: DataType, +} + +impl DistinctArrayAggAccumulator { + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: HashSet::new(), + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for DistinctArrayAggAccumulator { + fn state(&self) -> Result> { + Ok(vec![ScalarValue::List( + Some(Box::new(self.values.clone().into_iter().collect())), + Box::new(self.datatype.clone()), + )]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + assert_eq!(values.len(), 1, "batch input should only include 1 column!"); + + let arr = &values[0]; + for i in 0..arr.len() { + 
self.values.insert(ScalarValue::try_from_array(arr, i)?); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + }; + + for array in states { + for j in 0..array.len() { + self.values.insert(ScalarValue::try_from_array(array, j)?); + } + } + + Ok(()) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::List( + Some(Box::new(self.values.clone().into_iter().collect())), + Box::new(self.datatype.clone()), + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::col; + use crate::expressions::tests::aggregate; + use arrow::array::{ArrayRef, Int32Array}; + use arrow::datatypes::{DataType, Schema}; + use arrow::record_batch::RecordBatch; + + fn check_distinct_array_agg( + input: ArrayRef, + expected: ScalarValue, + datatype: DataType, + ) -> Result<()> { + let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?; + + let agg = Arc::new(DistinctArrayAgg::new( + col("a", &schema)?, + "bla".to_string(), + datatype, + )); + let actual = aggregate(&batch, agg)?; + + match (expected, actual) { + (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => { + // workaround lack of Ord of ScalarValue + let cmp = |a: &ScalarValue, b: &ScalarValue| { + a.partial_cmp(b).expect("Can compare ScalarValues") + }; + + e.sort_by(cmp); + a.sort_by(cmp); + // Check that the inputs are the same + assert_eq!(e, a); + } + _ => { + unreachable!() + } + } + + Ok(()) + } + + #[test] + fn distinct_array_agg_i32() -> Result<()> { + let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); + + let out = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(2)), + ScalarValue::Int32(Some(7)), + ScalarValue::Int32(Some(4)), + ScalarValue::Int32(Some(5)), + ])), + Box::new(DataType::Int32), + ); + + check_distinct_array_agg(col, out, 
DataType::Int32) + } + + #[test] + fn distinct_array_agg_nested() -> Result<()> { + // [[1, 2, 3], [4, 5]] + let l1 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(1i32), + ScalarValue::from(2i32), + ScalarValue::from(3i32), + ])), + Box::new(DataType::Int32), + ), + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(4i32), + ScalarValue::from(5i32), + ])), + Box::new(DataType::Int32), + ), + ])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + // [[6], [7, 8]] + let l2 = ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(6i32)])), + Box::new(DataType::Int32), + ), + ScalarValue::List( + Some(Box::new(vec![ + ScalarValue::from(7i32), + ScalarValue::from(8i32), + ])), + Box::new(DataType::Int32), + ), + ])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + // [[9]] + let l3 = ScalarValue::List( + Some(Box::new(vec![ScalarValue::List( + Some(Box::new(vec![ScalarValue::from(9i32)])), + Box::new(DataType::Int32), + )])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + let list = ScalarValue::List( + Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])), + Box::new(DataType::List(Box::new(Field::new( + "item", + DataType::Int32, + true, + )))), + ); + + // Duplicate l1 in the input array and check that it is deduped in the output. 
+ let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap(); + + check_distinct_array_agg( + array, + list, + DataType::List(Box::new(Field::new( + "item", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ))), + ) + } +} diff --git a/datafusion/physical-expr/src/aggregate/distinct_expressions.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs similarity index 74% rename from datafusion/physical-expr/src/aggregate/distinct_expressions.rs rename to datafusion/physical-expr/src/aggregate/count_distinct.rs index e4f1e01e35c1..cb32dcd4969b 100644 --- a/datafusion/physical-expr/src/aggregate/distinct_expressions.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -15,12 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` - +use super::*; use arrow::datatypes::{DataType, Field}; use std::any::Any; use std::fmt::Debug; -use std::hash::Hash; use std::sync::Arc; use ahash::RandomState; @@ -32,13 +30,6 @@ use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; -#[derive(Debug, PartialEq, Eq, Hash, Clone)] -struct DistinctScalarValues(Vec); - -fn format_state_name(name: &str, state_name: &str) -> String { - format!("{}[{}]", name, state_name) -} - /// Expression for a COUNT(DISTINCT) aggregation. #[derive(Debug)] pub struct DistinctCount { @@ -233,146 +224,44 @@ impl Accumulator for DistinctCountAccumulator { } } -/// Expression for a ARRAY_AGG(DISTINCT) aggregation. 
-#[derive(Debug)] -pub struct DistinctArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, -} - -impl DistinctArrayAgg { - /// Create a new DistinctArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ) -> Self { - let name = name.into(); - Self { - name, - expr, - input_data_type, - } - } -} - -impl AggregateExpr for DistinctArrayAgg { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - DataType::List(Box::new(Field::new( - "item", - self.input_data_type.clone(), - true, - ))), - false, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - &format_state_name(&self.name, "distinct_array_agg"), - DataType::List(Box::new(Field::new( - "item", - self.input_data_type.clone(), - true, - ))), - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -#[derive(Debug)] -struct DistinctArrayAggAccumulator { - values: HashSet, - datatype: DataType, -} - -impl DistinctArrayAggAccumulator { - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: HashSet::new(), - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for DistinctArrayAggAccumulator { - fn state(&self) -> Result> { - Ok(vec![ScalarValue::List( - Some(Box::new(self.values.clone().into_iter().collect())), - Box::new(self.datatype.clone()), - )]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - assert_eq!(values.len(), 1, "batch input should only include 1 column!"); - - let arr = &values[0]; - for i in 0..arr.len() { - 
self.values.insert(ScalarValue::try_from_array(arr, i)?); - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - }; - - for array in states { - for j in 0..array.len() { - self.values.insert(ScalarValue::try_from_array(array, j)?); - } - } - - Ok(()) - } - - fn evaluate(&self) -> Result { - Ok(ScalarValue::List( - Some(Box::new(self.values.clone().into_iter().collect())), - Box::new(self.datatype.clone()), - )) - } -} - #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; use arrow::array::{ ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::array::{Int32Builder, ListBuilder, UInt64Builder}; - use arrow::datatypes::{DataType, Schema}; - use arrow::record_batch::RecordBatch; + use arrow::datatypes::DataType; + + macro_rules! state_to_vec { + ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ + match $LIST { + ScalarValue::List(_, data_type) => match data_type.as_ref() { + &DataType::$DATA_TYPE => (), + _ => panic!("Unexpected DataType for list"), + }, + _ => panic!("Expected a ScalarValue::List"), + } + + match $LIST { + ScalarValue::List(None, _) => None, + ScalarValue::List(Some(scalar_values), _) => { + let vec = scalar_values + .iter() + .map(|scalar_value| match scalar_value { + ScalarValue::$DATA_TYPE(value) => *value, + _ => panic!("Unexpected ScalarValue variant"), + }) + .collect::>>(); + + Some(vec) + } + _ => unreachable!(), + } + }}; + } macro_rules! build_list { ($LISTS:expr, $BUILDER_TYPE:ident) => {{ @@ -401,31 +290,33 @@ mod tests { }}; } - macro_rules! 
state_to_vec { - ($LIST:expr, $DATA_TYPE:ident, $PRIM_TY:ty) => {{ - match $LIST { - ScalarValue::List(_, data_type) => match data_type.as_ref() { - &DataType::$DATA_TYPE => (), - _ => panic!("Unexpected DataType for list"), - }, - _ => panic!("Expected a ScalarValue::List"), - } + macro_rules! test_count_distinct_update_batch_numeric { + ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ + let values: Vec> = vec![ + Some(1), + Some(1), + None, + Some(3), + Some(2), + None, + Some(2), + Some(3), + Some(1), + ]; - match $LIST { - ScalarValue::List(None, _) => None, - ScalarValue::List(Some(scalar_values), _) => { - let vec = scalar_values - .iter() - .map(|scalar_value| match scalar_value { - ScalarValue::$DATA_TYPE(value) => *value, - _ => panic!("Unexpected ScalarValue variant"), - }) - .collect::>>(); + let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - Some(vec) - } - _ => unreachable!(), - } + let (states, result) = run_update_batch(&arrays)?; + + let mut state_vec = + state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); + state_vec.sort(); + + assert_eq!(states.len(), 1); + assert_eq!(state_vec, vec![Some(1), Some(2), Some(3)]); + assert_eq!(result, ScalarValue::UInt64(Some(3))); + + Ok(()) }}; } @@ -508,36 +399,6 @@ mod tests { Ok((accum.state()?, accum.evaluate()?)) } - macro_rules! 
test_count_distinct_update_batch_numeric { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(1), - Some(1), - None, - Some(3), - Some(2), - None, - Some(2), - Some(3), - Some(1), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = - state_to_vec!(&states[0], $DATA_TYPE, $PRIM_TYPE).unwrap(); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![Some(1), Some(2), Some(3)]); - assert_eq!(result, ScalarValue::UInt64(Some(3))); - - Ok(()) - }}; - } - // Used trait to create associated constant for f32 and f64 trait SubNormal: 'static { const SUBNORMAL: Self; @@ -870,143 +731,4 @@ mod tests { Ok(()) } - - fn check_distinct_array_agg( - input: ArrayRef, - expected: ScalarValue, - datatype: DataType, - ) -> Result<()> { - let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?; - - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, - )); - let actual = aggregate(&batch, agg)?; - - match (expected, actual) { - (ScalarValue::List(Some(mut e), _), ScalarValue::List(Some(mut a), _)) => { - // workaround lack of Ord of ScalarValue - let cmp = |a: &ScalarValue, b: &ScalarValue| { - a.partial_cmp(b).expect("Can compare ScalarValues") - }; - - e.sort_by(cmp); - a.sort_by(cmp); - // Check that the inputs are the same - assert_eq!(e, a); - } - _ => { - unreachable!() - } - } - - Ok(()) - } - - #[test] - fn distinct_array_agg_i32() -> Result<()> { - let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - - let out = ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ])), - Box::new(DataType::Int32), - ); - - 
check_distinct_array_agg(col, out, DataType::Int32) - } - - #[test] - fn distinct_array_agg_nested() -> Result<()> { - // [[1, 2, 3], [4, 5]] - let l1 = ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(1i32), - ScalarValue::from(2i32), - ScalarValue::from(3i32), - ])), - Box::new(DataType::Int32), - ), - ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(4i32), - ScalarValue::from(5i32), - ])), - Box::new(DataType::Int32), - ), - ])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), - ); - - // [[6], [7, 8]] - let l2 = ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::List( - Some(Box::new(vec![ScalarValue::from(6i32)])), - Box::new(DataType::Int32), - ), - ScalarValue::List( - Some(Box::new(vec![ - ScalarValue::from(7i32), - ScalarValue::from(8i32), - ])), - Box::new(DataType::Int32), - ), - ])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), - ); - - // [[9]] - let l3 = ScalarValue::List( - Some(Box::new(vec![ScalarValue::List( - Some(Box::new(vec![ScalarValue::from(9i32)])), - Box::new(DataType::Int32), - )])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), - ); - - let list = ScalarValue::List( - Some(Box::new(vec![l1.clone(), l2.clone(), l3.clone()])), - Box::new(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - true, - )))), - ); - - // Duplicate l1 in the input array and check that it is deduped in the output. 
- let array = ScalarValue::iter_to_array(vec![l1.clone(), l2, l3, l1]).unwrap(); - - check_distinct_array_agg( - array, - list, - DataType::List(Box::new(Field::new( - "item", - DataType::List(Box::new(Field::new("item", DataType::Int32, true))), - true, - ))), - ) - } } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 019a60cd5760..a9f3167c48d1 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,7 +17,7 @@ use crate::PhysicalExpr; use arrow::datatypes::Field; -use datafusion_common::Result; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; use std::any::Any; use std::fmt::Debug; @@ -28,12 +28,13 @@ pub(crate) mod approx_median; pub(crate) mod approx_percentile_cont; pub(crate) mod approx_percentile_cont_with_weight; pub(crate) mod array_agg; +pub(crate) mod array_agg_distinct; pub(crate) mod average; pub(crate) mod coercion_rule; pub(crate) mod correlation; pub(crate) mod count; +pub(crate) mod count_distinct; pub(crate) mod covariance; -pub(crate) mod distinct_expressions; #[macro_use] pub(crate) mod min_max; pub mod build_in; @@ -76,3 +77,10 @@ pub trait AggregateExpr: Send + Sync + Debug { "AggregateExpr: default name" } } + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +struct DistinctScalarValues(Vec); + +fn format_state_name(name: &str, state_name: &str) -> String { + format!("{}[{}]", name, state_name) +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 813e87c8f8df..3190f680fa8d 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -42,12 +42,13 @@ pub use crate::aggregate::approx_median::ApproxMedian; pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont; pub use crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight; pub 
use crate::aggregate::array_agg::ArrayAgg; +pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::average::{Avg, AvgAccumulator}; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; +pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::distinct_expressions::{DistinctArrayAgg, DistinctCount}; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; pub use crate::aggregate::stats::StatsType;