From 2e4b5b08f65b6de170a7031fcf751763e84fbd6c Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 22 Jun 2024 09:33:34 +0200 Subject: [PATCH 1/6] Add input_nullable to UDAF args StateField/AccumulatorArgs This follows how it is done for input_type and only provides a single value, but it might need to be changed into a Vec in the future. This is needed when we are moving `array_agg` to a UDAF, where the nullability of one of the states will depend on the nullability of the input. --- datafusion/expr/src/function.rs | 6 ++++++ datafusion/functions-aggregate/src/first_last.rs | 1 + datafusion/physical-expr-common/src/aggregate/mod.rs | 7 +++++++ 3 files changed, 14 insertions(+) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 169436145aae..53dfa2c460d0 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -83,6 +83,9 @@ pub struct AccumulatorArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, + /// If the input type is nullable. + pub input_nullable: bool, + /// The logical expression of arguments the aggregate function takes. pub input_exprs: &'a [Expr], } @@ -98,6 +101,9 @@ pub struct StateFieldsArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, + /// If the input type is nullable. + pub input_nullable: bool, + /// The return type of the aggregate function. 
pub return_type: &'a DataType, diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index dd38e3487264..066ab77eadcf 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -440,6 +440,7 @@ impl AggregateUDFImpl for LastValue { let StateFieldsArgs { name, input_type, + input_nullable: _, return_type: _, ordering_fields, is_distinct: _, diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 432267e045b2..e82601efa58e 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -87,6 +87,7 @@ pub fn create_aggregate_expr( ordering_fields, is_distinct, input_type: input_exprs_types[0].clone(), + input_nullable: input_phy_exprs[0].nullable(&schema)?, })) } @@ -248,6 +249,7 @@ pub struct AggregateFunctionExpr { ordering_fields: Vec, is_distinct: bool, input_type: DataType, + input_nullable: bool, } impl AggregateFunctionExpr { @@ -276,6 +278,7 @@ impl AggregateExpr for AggregateFunctionExpr { let args = StateFieldsArgs { name: &self.name, input_type: &self.input_type, + input_nullable: self.input_nullable, return_type: &self.data_type, ordering_fields: &self.ordering_fields, is_distinct: self.is_distinct, @@ -296,6 +299,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, + input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -311,6 +315,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, + input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -381,6 +386,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: 
self.is_distinct, input_type: &self.input_type, + input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -395,6 +401,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, + input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; From ecdf2744da6e498ef776d81fb37ef2cd18628b94 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 21 Jun 2024 09:50:36 +0200 Subject: [PATCH 2/6] Make ArragAgg (not ordered or distinct) into a UDAF --- datafusion/core/src/dataframe/mod.rs | 6 +- datafusion/core/src/physical_planner.rs | 28 ++ datafusion/core/tests/dataframe/mod.rs | 8 +- datafusion/core/tests/sql/aggregates.rs | 2 +- datafusion/expr/src/expr_fn.rs | 12 - datafusion/functions-aggregate/Cargo.toml | 1 + .../functions-aggregate/src/array_agg.rs | 271 ++++++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 6 +- .../physical-expr/src/aggregate/array_agg.rs | 188 ------------ .../physical-expr/src/aggregate/build_in.rs | 29 +- datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - .../proto/src/physical_plan/to_proto.rs | 10 +- datafusion/sql/src/expr/mod.rs | 11 +- 14 files changed, 331 insertions(+), 243 deletions(-) create mode 100644 datafusion/functions-aggregate/src/array_agg.rs delete mode 100644 datafusion/physical-expr/src/aggregate/array_agg.rs diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 398f59e35d10..818c36c23a57 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1658,10 +1658,10 @@ mod tests { use datafusion_common::{Constraint, Constraints}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, + 
cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, + Volatility, WindowFrame, WindowFunctionDefinition, }; - use datafusion_functions_aggregate::expr_fn::count_distinct; + use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index b539544d8372..a9bca1830dfa 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -92,6 +92,7 @@ use datafusion_expr::{ DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -1854,6 +1855,33 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn( + datafusion_expr::AggregateFunction::ArrayAgg, + ) if !distinct && order_by.is_none() => { + let sort_exprs = order_by.clone().unwrap_or(vec![]); + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; + let ordering_reqs: Vec = + physical_sort_exprs.clone().unwrap_or(vec![]); + let agg_expr = udaf::create_aggregate_expr( + &array_agg_udaf(), + &physical_args, + args, + &sort_exprs, + &ordering_reqs, + physical_input_schema, + name, + ignore_nulls, + *distinct, + )?; + (agg_expr, filter, physical_sort_exprs) + } AggregateFunctionDefinition::BuiltIn(fun) => { let physical_sort_exprs = match order_by { Some(exprs) => Some(create_physical_sort_exprs( 
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index fa364c5f2a65..2d8736b9c47d 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{count, sum}; +use datafusion_functions_aggregate::expr_fn::{array_agg, count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 84b791a3de05..3acf5f814984 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, true), - false + true ),]) ); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a87412ee6356..9feff05dcb32 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -169,18 +169,6 @@ pub fn max(expr: Expr) -> Expr { )) } -/// Create an expression to represent the array_agg() aggregate function -pub fn array_agg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ArrayAgg, - vec![expr], - 
false, - None, - None, - None, - )) -} - /// Create an expression to represent the avg() aggregate function pub fn avg(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 26630a0352d5..3331701844b4 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -40,6 +40,7 @@ path = "src/lib.rs" [dependencies] ahash = { workspace = true } arrow = { workspace = true } +arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs new file mode 100644 index 000000000000..27e3c11049f2 --- /dev/null +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -0,0 +1,271 @@ +// Licensed to the Apache Software Foundation (ASF) under on +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines physical expressions that can evaluated at runtime during query execution + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use arrow_array::Array; +use arrow_schema::Field; + +use datafusion_common::cast::as_list_array; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr::AggregateFunctionDefinition; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::AggregateUDFImpl; +use datafusion_expr::Expr; +use datafusion_expr::{Accumulator, Signature, Volatility}; +use std::sync::Arc; + +make_udaf_expr_and_func!( + ArrayAgg, + array_agg, + expression, + "Computes the nth value", + array_agg_udaf +); + +#[derive(Debug)] +/// ARRAY_AGG aggregate expression +pub struct ArrayAgg { + signature: Signature, + alias: Vec, +} + +impl Default for ArrayAgg { + fn default() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + alias: vec!["array_agg".to_string()], + } + } +} + +impl AggregateUDFImpl for ArrayAgg { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "ARRAY_AGG" + } + + fn aliases(&self) -> &[String] { + &self.alias + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn state_fields( + &self, + args: datafusion_expr::function::StateFieldsArgs, + ) -> Result> { + Ok(vec![Field::new_list( + format_state_name(args.name, "array_agg"), + Field::new("item", args.input_type.clone(), true), + true, + )]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?)) + } + + fn reverse_expr(&self) -> 
datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Identical + } + + fn simplify( + &self, + ) -> Option { + let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { + if aggregate_function.order_by.is_some() || aggregate_function.distinct { + Ok(Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::BuiltIn( + datafusion_expr::aggregate_function::AggregateFunction::ArrayAgg, + ), + args: aggregate_function.args, + distinct: aggregate_function.distinct, + filter: aggregate_function.filter, + order_by: aggregate_function.order_by, + null_treatment: aggregate_function.null_treatment, + })) + } else { + Ok(Expr::AggregateFunction(aggregate_function)) + } + }; + + Some(Box::new(simplify)) + } +} + +#[derive(Debug)] +pub struct ArrayAggAccumulator { + values: Vec, + datatype: DataType, +} + +impl ArrayAggAccumulator { + /// new array_agg accumulator based on given item data type + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: vec![], + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for ArrayAggAccumulator { + // Append value like Int64Array(1,2,3) + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + assert!(values.len() == 1, "array_agg can only take 1 param!"); + let val = values[0].clone(); + self.values.push(val); + Ok(()) + } + + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert!(states.len() == 1, "array_agg states must be singleton!"); + + let list_arr = as_list_array(&states[0])?; + for arr in list_arr.iter().flatten() { + self.values.push(arr); + } + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + // Transform Vec to ListArr + + let element_arrays: Vec<&dyn Array> = + 
self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + let arr = ScalarValue::new_list(&[], &self.datatype); + return Ok(ScalarValue::List(arr)); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + let list_array = array_into_list_array(concated_array); + + Ok(ScalarValue::List(Arc::new(list_array))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + (std::mem::size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + + self.datatype.size() + - std::mem::size_of_val(&self.datatype) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; + use datafusion_physical_expr_common::expressions::column::Column; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + #[test] + fn test_array_agg_expr() -> Result<()> { + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal128(10, 2), + DataType::Utf8, + ]; + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &array_agg_udaf(), + &input_phy_exprs[0..1], + &[], + &[], + &[], + &input_schema, + "c1", + false, + false, + )?; + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), + result_agg_phy_exprs.field().unwrap() + ); + + let result_distinct = create_aggregate_expr( + &array_agg_udaf(), + &input_phy_exprs[0..1], + &[], + &[], + &[], + &input_schema, + "c1", + false, + true, + )?; + assert_eq!("c1", result_distinct.name()); + 
assert_eq!( + Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), + result_agg_phy_exprs.field().unwrap() + ); + } + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 260d6dab31b9..e0556d64f768 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -56,6 +56,7 @@ pub mod macros; pub mod approx_distinct; +pub mod array_agg; pub mod count; pub mod covariance; pub mod first_last; @@ -86,6 +87,7 @@ pub mod expr_fn { pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; + pub use super::array_agg::array_agg; pub use super::bit_and_or_xor::bit_and; pub use super::bit_and_or_xor::bit_or; pub use super::bit_and_or_xor::bit_xor; @@ -117,6 +119,7 @@ pub mod expr_fn { /// Returns all default aggregate functions pub fn all_default_aggregate_functions() -> Vec> { vec![ + array_agg::array_agg_udaf(), first_last::first_value_udaf(), first_last::last_value_udaf(), covariance::covar_samp_udaf(), @@ -177,7 +180,8 @@ mod tests { for func in all_default_aggregate_functions() { // TODO: remove this // These functions are in intermidiate migration state, skip them - if func.name().to_lowercase() == "count" { + let name_lower_case = func.name().to_lowercase(); + if name_lower_case == "count" || name_lower_case == "array_agg" { continue; } assert!( diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs deleted file mode 100644 index a23ba07de44a..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ /dev/null @@ -1,188 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; -use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// ARRAY_AGG aggregate expression -#[derive(Debug)] -pub struct ArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// If the input expression can have NULLs - nullable: bool, -} - -impl ArrayAgg { - /// Create a new ArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - nullable: bool, - ) -> Self { - Self { - name: name.into(), - input_data_type: data_type, - expr, - nullable, - } - } -} - -impl AggregateExpr for ArrayAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same 
as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(ArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct ArrayAggAccumulator { - values: Vec, - datatype: DataType, -} - -impl ArrayAggAccumulator { - /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: vec![], - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for ArrayAggAccumulator { - // Append value like Int64Array(1,2,3) - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - assert!(values.len() == 1, "array_agg can only take 1 param!"); - let val = values[0].clone(); - self.values.push(val); - Ok(()) - } - - // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert!(states.len() == 1, "array_agg states must be singleton!"); - - let list_arr = as_list_array(&states[0])?; - for arr in list_arr.iter().flatten() { - self.values.push(arr); - } - Ok(()) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - // Transform Vec to 
ListArr - - let element_arrays: Vec<&dyn Array> = - self.values.iter().map(|a| a.as_ref()).collect(); - - if element_arrays.is_empty() { - let arr = ScalarValue::new_list(&[], &self.datatype); - return Ok(ScalarValue::List(arr)); - } - - let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array(concated_array); - - Ok(ScalarValue::List(Arc::new(list_array))) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|arr| arr.get_array_memory_size()) - .sum::() - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 53cfcfb033a1..dffddedbf5fd 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use datafusion_common::{exec_err, not_impl_err, Result}; +use datafusion_common::{exec_err, internal_err, not_impl_err, Result}; use datafusion_expr::AggregateFunction; use crate::aggregate::average::Avg; @@ -71,7 +71,9 @@ pub fn create_aggregate_expr( let nullable = expr.nullable(input_schema)?; if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) + return internal_err!( + "ArrayAgg without ordering should be handled as UDAF" + ); } else { Arc::new(expressions::OrderSensitiveArrayAgg::new( expr, @@ -155,7 +157,7 @@ mod tests { use datafusion_common::plan_err; use datafusion_expr::{type_coercion, Signature}; - use crate::expressions::{try_cast, ArrayAgg, Avg, DistinctArrayAgg, Max, Min}; + use crate::expressions::{try_cast, Avg, DistinctArrayAgg, Max, Min}; use super::*; #[test] @@ -176,25 +178,6 @@ mod tests { let input_phy_exprs: Vec> = vec![Arc::new( expressions::Column::new_with_schema("c1", 
&input_schema).unwrap(), )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::ArrayAgg { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } let result_distinct = create_physical_agg_expr_for_test( &fun, @@ -212,7 +195,7 @@ mod tests { Field::new("item", data_type.clone(), true), true, ), - result_agg_phy_exprs.field().unwrap() + result_distinct.field().unwrap() ); } } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index f64c5b1fb260..a4aaa7d03951 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,7 +17,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; pub(crate) mod average; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 0020aa5f55b2..bc5056e3d7ff 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -35,7 +35,6 @@ mod try_cast; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::average::Avg; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index a9d3736dee08..b963113a82dc 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,10 +23,10 @@ 
use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, - DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, - NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, - RankType, RowNumber, TryCastExpr, WindowShift, + Avg, BinaryExpr, CaseExpr, CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, + Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, + NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, + RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -240,8 +240,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Grouping - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg } else if aggr_expr.downcast_ref::().is_some() { distinct = true; protobuf::AggregateFunction::ArrayAgg diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 8b64ccfb52cb..4db19b8381b0 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -596,10 +596,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// TODO: this should likely be done in ArrayAgg::simplify when it is moved to a UDAF fn simplify_array_index_expr(expr: &Expr, index: &Expr) -> Option { fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func_def - == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( + match agg_func.func_def { + 
datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn( AggregateFunction::ArrayAgg, - ) + ) => true, + datafusion_expr::expr::AggregateFunctionDefinition::UDF(ref udf) => { + udf.name() == "ARRAY_AGG" + } + _ => false, + } } match expr { Expr::AggregateFunction(agg_func) if is_array_agg(agg_func) => { From 8ec61003e10b8b4b717742bed1ca8a59043966b8 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 21 Jun 2024 10:00:52 +0200 Subject: [PATCH 3/6] Add roundtrip_expr_api test case --- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b3966c3f0204..cf842453822d 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,7 +60,7 @@ use datafusion_expr::{ WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_functions_aggregate::expr_fn::{ - bit_and, bit_or, bit_xor, bool_and, bool_or, + array_agg, bit_and, bit_or, bit_xor, bool_and, bool_or }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -675,6 +675,7 @@ async fn roundtrip_expr_api() -> Result<()> { string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), bool_and(lit(true)), bool_or(lit(true)), + array_agg(lit(1)) ]; // ensure expressions created with the expr api can be round tripped From 47ab2535e85766a3c5ba025b0b569f1de9ea5bd9 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 22 Jun 2024 10:32:47 +0200 Subject: [PATCH 4/6] Address PR comments --- datafusion/core/tests/sql/aggregates.rs | 2 +- datafusion/functions-aggregate/Cargo.toml | 1 - .../functions-aggregate/src/array_agg.rs | 75 +------------------ .../tests/cases/roundtrip_logical_plan.rs | 4 +- 4 files changed, 7 insertions(+), 75 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates.rs 
b/datafusion/core/tests/sql/aggregates.rs index 3acf5f814984..84b791a3de05 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -37,7 +37,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, true), - true + false ),]) ); diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 3331701844b4..26630a0352d5 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -40,7 +40,6 @@ path = "src/lib.rs" [dependencies] ahash = { workspace = true } arrow = { workspace = true } -arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 27e3c11049f2..a0cedf5817ff 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -1,4 +1,4 @@ -// Licensed to the Apache Software Foundation (ASF) under on +// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file @@ -17,9 +17,8 @@ //! 
Defines physical expressions that can evaluated at runtime during query execution -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef}; use arrow::datatypes::DataType; -use arrow_array::Array; use arrow_schema::Field; use datafusion_common::cast::as_list_array; @@ -40,7 +39,7 @@ make_udaf_expr_and_func!( ArrayAgg, array_agg, expression, - "Computes the nth value", + "input values, including nulls, concatenated into an array", array_agg_udaf ); @@ -92,7 +91,7 @@ impl AggregateUDFImpl for ArrayAgg { Ok(vec![Field::new_list( format_state_name(args.name, "array_agg"), Field::new("item", args.input_type.clone(), true), - true, + args.input_nullable, )]) } @@ -203,69 +202,3 @@ impl Accumulator for ArrayAggAccumulator { - std::mem::size_of_val(&self.datatype) } } - -#[cfg(test)] -mod tests { - use super::*; - - use std::sync::Arc; - - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::Result; - use datafusion_physical_expr_common::aggregate::create_aggregate_expr; - use datafusion_physical_expr_common::expressions::column::Column; - use datafusion_physical_expr_common::physical_expr::PhysicalExpr; - - #[test] - fn test_array_agg_expr() -> Result<()> { - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_aggregate_expr( - &array_agg_udaf(), - &input_phy_exprs[0..1], - &[], - &[], - &[], - &input_schema, - "c1", - false, - false, - )?; - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), - result_agg_phy_exprs.field().unwrap() - ); - - let result_distinct = create_aggregate_expr( - 
&array_agg_udaf(), - &input_phy_exprs[0..1], - &[], - &[], - &[], - &input_schema, - "c1", - false, - true, - )?; - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list("c1", Field::new("item", data_type.clone(), true), true,), - result_agg_phy_exprs.field().unwrap() - ); - } - Ok(()) - } -} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index cf842453822d..95e75f825cfd 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,7 +60,7 @@ use datafusion_expr::{ WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_functions_aggregate::expr_fn::{ - array_agg, bit_and, bit_or, bit_xor, bool_and, bool_or + array_agg, bit_and, bit_or, bit_xor, bool_and, bool_or, }; use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ @@ -675,7 +675,7 @@ async fn roundtrip_expr_api() -> Result<()> { string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), bool_and(lit(true)), bool_or(lit(true)), - array_agg(lit(1)) + array_agg(lit(1)), ]; // ensure expressions created with the expr api can be round tripped From 1a690459426980c6b341b624beb5116f7398f65c Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 22 Jun 2024 17:39:59 +0200 Subject: [PATCH 5/6] Propegate input nullability for aggregates --- datafusion/physical-expr-common/src/aggregate/mod.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index e82601efa58e..e37ef7bd71ab 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -288,7 +288,11 @@ impl AggregateExpr for AggregateFunctionExpr { } fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + 
Ok(Field::new( + &self.name, + self.data_type.clone(), + self.input_nullable, + )) } fn create_accumulator(&self) -> Result> { From a6ea4045f22fa2a7d327fed1698556f5f0a41215 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Sat, 22 Jun 2024 17:42:20 +0200 Subject: [PATCH 6/6] Remove from accumulator args --- datafusion/expr/src/function.rs | 3 --- datafusion/physical-expr-common/src/aggregate/mod.rs | 6 +----- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 53dfa2c460d0..7198f17a9df9 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -83,9 +83,6 @@ pub struct AccumulatorArgs<'a> { /// The input type of the aggregate function. pub input_type: &'a DataType, - /// If the input type is nullable. - pub input_nullable: bool, - /// The logical expression of arguments the aggregate function takes. pub input_exprs: &'a [Expr], } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index e37ef7bd71ab..1d1f2c44d5c9 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -87,7 +87,7 @@ pub fn create_aggregate_expr( ordering_fields, is_distinct, input_type: input_exprs_types[0].clone(), - input_nullable: input_phy_exprs[0].nullable(&schema)?, + input_nullable: input_phy_exprs[0].nullable(schema)?, })) } @@ -303,7 +303,6 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -319,7 +318,6 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -390,7 +388,6 @@ impl 
AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, }; @@ -405,7 +402,6 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - input_nullable: self.input_nullable, input_exprs: &self.logical_args, name: &self.name, };