From 5f2b0c26107ba0691f7ae3fe976d9e3f81ab2d00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 17 Jun 2024 10:38:12 +0800 Subject: [PATCH 01/10] Convert StringAgg to UDAF --- datafusion/expr/src/aggregate_function.rs | 8 - .../expr/src/type_coercion/aggregates.rs | 26 -- datafusion/functions-aggregate/src/lib.rs | 1 + .../functions-aggregate/src/string_agg.rs | 141 ++++++++++ .../physical-expr/src/aggregate/build_in.rs | 16 -- datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/aggregate/string_agg.rs | 246 ------------------ .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/proto/datafusion.proto | 2 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 4 - .../proto/src/physical_plan/to_proto.rs | 4 +- .../tests/cases/roundtrip_physical_plan.rs | 9 +- 13 files changed, 145 insertions(+), 315 deletions(-) create mode 100644 datafusion/functions-aggregate/src/string_agg.rs delete mode 100644 datafusion/physical-expr/src/aggregate/string_agg.rs diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 441e8953dffc..a7e9c30de11f 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -57,8 +57,6 @@ pub enum AggregateFunction { BoolAnd, /// Bool Or BoolOr, - /// String aggregation - StringAgg, } impl AggregateFunction { @@ -77,7 +75,6 @@ impl AggregateFunction { BitXor => "BIT_XOR", BoolAnd => "BOOL_AND", BoolOr => "BOOL_OR", - StringAgg => "STRING_AGG", } } } @@ -104,7 +101,6 @@ impl FromStr for AggregateFunction { "min" => AggregateFunction::Min, "array_agg" => AggregateFunction::ArrayAgg, "nth_value" => AggregateFunction::NthValue, - "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, // other @@ -161,7 +157,6 @@ impl AggregateFunction { )))), AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), - AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } } @@ -215,9 +210,6 @@ impl AggregateFunction { AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } - AggregateFunction::StringAgg => { - Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) - } } } } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 98324ed6120b..cde96b93d48a 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -159,23 +159,6 @@ pub fn coerce_types( } AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), - AggregateFunction::StringAgg => { - if !is_string_agg_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}", - agg_fun, - input_types[0] - ); - } - if !is_string_agg_supported_arg_type(&input_types[1]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}", - agg_fun, - input_types[1] - ); - } - Ok(vec![LargeUtf8, input_types[1].clone()]) - } } } @@ -409,15 +392,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool { arg_type.is_integer() } -/// Return `true` if `arg_type` is of a [`DataType`] that the -/// [`AggregateFunction::StringAgg`] aggregation can operate on. -pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Null - ) -} - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index daddb9d93f78..7925106b4616 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -69,6 +69,7 @@ pub mod variance; pub mod approx_median; pub mod approx_percentile_cont; pub mod approx_percentile_cont_with_weight; +pub mod string_agg; use crate::approx_percentile_cont::approx_percentile_cont_udaf; use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs new file mode 100644 index 000000000000..0e800837ac65 --- /dev/null +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -0,0 +1,141 @@ +use arrow::array::ArrayRef; +use arrow_schema::DataType; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::Result; +use datafusion_common::{not_impl_err, ScalarValue}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Expr, Signature, TypeSignature, Volatility, +}; +use std::any::Any; + +/// STRING_AGG aggregate expression +#[derive(Debug)] +pub struct StringAgg { + signature: Signature, + aliases: Vec, +} + +impl StringAgg { + /// Create a new StringAgg aggregate function + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Uniform(2, vec![DataType::LargeUtf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]), + ], + Volatility::Immutable, + ), + aliases: vec!["string_agg".to_string()], + } + } +} + +impl Default for StringAgg { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for StringAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "STRING_AGG" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::LargeUtf8) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + return match &acc_args.input_exprs[1] { + Expr::Literal(ScalarValue::Utf8(Some(delimiter))) + | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => { + Ok(Box::new(StringAggAccumulator::new(delimiter))) + } + Expr::Literal(ScalarValue::Null) => { + Ok(Box::new(StringAggAccumulator::new(""))) + } + _ => not_impl_err!( + "StringAgg not supported for delimiter {}", + &acc_args.input_exprs[1] + ), + }; + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +#[derive(Debug)] +pub(crate) struct StringAggAccumulator { + values: Option, + delimiter: String, +} + +impl StringAggAccumulator { + pub fn new(delimiter: &str) -> Self { + Self { + values: None, + delimiter: delimiter.to_string(), + } + } +} + +impl Accumulator for StringAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_array: Vec<_> = as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(); + if !string_array.is_empty() { + let s = string_array.join(self.delimiter.as_str()); + let v = self.values.get_or_insert("".to_string()); + if !v.is_empty() { + v.push_str(self.delimiter.as_str()); + } + v.push_str(s.as_str()); + } + Ok(()) + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_batch(values)?; + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::LargeUtf8(self.values.clone())) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + + self.delimiter.capacity() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::*; + use arrow::record_batch::RecordBatch; + use datafusion_expr::type_coercion::aggregates::coerce_types; + use datafusion_expr::AggregateFunction; +} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index a1f5f153a9ff..ef1ac8e0cc10 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -175,22 +175,6 @@ pub fn create_aggregate_expr( ordering_req.to_vec(), )) } - (AggregateFunction::StringAgg, false) => { - if !ordering_req.is_empty() { - return not_impl_err!( - "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available" - ); - } - Arc::new(expressions::StringAgg::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )) - } - (AggregateFunction::StringAgg, true) => { - return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available"); - } }) } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index c20902c11b86..b760a3903fcb 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -27,7 +27,6 @@ pub(crate) mod correlation; pub(crate) mod covariance; pub(crate) mod grouping; pub(crate) mod nth_value; -pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs deleted file mode 100644 index dc0ffc557968..000000000000 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ /dev/null @@ -1,246 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::{format_state_name, Literal}; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::cast::as_generic_string_array; -use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// STRING_AGG aggregate expression -#[derive(Debug)] -pub struct StringAgg { - name: String, - data_type: DataType, - expr: Arc, - delimiter: Arc, - nullable: bool, -} - -impl StringAgg { - /// Create a new StringAgg aggregate function - pub fn new( - expr: Arc, - delimiter: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - data_type, - delimiter, - expr, - nullable: true, - } - } -} - -impl AggregateExpr for StringAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - if let Some(delimiter) = self.delimiter.as_any().downcast_ref::() { - match delimiter.value() { - ScalarValue::Utf8(Some(delimiter)) - | ScalarValue::LargeUtf8(Some(delimiter)) => { - return Ok(Box::new(StringAggAccumulator::new(delimiter))); - } - ScalarValue::Null => { - return Ok(Box::new(StringAggAccumulator::new(""))); - } - _ => return not_impl_err!("StringAgg not supported for {}", self.name), - } - } - not_impl_err!("StringAgg not supported for {}", self.name) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "string_agg"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone(), self.delimiter.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for StringAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - && self.delimiter.eq(&x.delimiter) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct StringAggAccumulator { - values: Option, - delimiter: String, -} - -impl StringAggAccumulator { - pub fn new(delimiter: &str) -> Self { - Self { - values: None, - delimiter: delimiter.to_string(), - } - } -} - -impl Accumulator for StringAggAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let string_array: Vec<_> = as_generic_string_array::(&values[0])? - .iter() - .filter_map(|v| v.as_ref().map(ToString::to_string)) - .collect(); - if !string_array.is_empty() { - let s = string_array.join(self.delimiter.as_str()); - let v = self.values.get_or_insert("".to_string()); - if !v.is_empty() { - v.push_str(self.delimiter.as_str()); - } - v.push_str(s.as_str()); - } - Ok(()) - } - - fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.update_batch(values)?; - Ok(()) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::LargeUtf8(self.values.clone())) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) - + self.delimiter.capacity() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::tests::aggregate; - use crate::expressions::{col, create_aggregate_expr, try_cast}; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use arrow_array::LargeStringArray; - use arrow_array::StringArray; - use datafusion_expr::type_coercion::aggregates::coerce_types; - use datafusion_expr::AggregateFunction; - - fn assert_string_aggregate( - array: ArrayRef, - function: AggregateFunction, - distinct: bool, - expected: ScalarValue, - delimiter: String, - ) { - let data_type = array.data_type(); - let sig = function.signature(); - let coerced = - coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap(); - - let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); - let batch = - RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); - - let input = try_cast( - col("a", &input_schema).unwrap(), - &input_schema, - coerced[0].clone(), - ) - .unwrap(); - - let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter))); - let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); - let agg = create_aggregate_expr( - &function, - distinct, - &[input, delimiter], - &[], - &schema, - "agg", - false, - ) - .unwrap(); - - let result = aggregate(&batch, agg).unwrap(); - assert_eq!(expected, result); - } - - #[test] - fn string_agg_utf8() { - let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"])); - assert_string_aggregate( - a, - AggregateFunction::StringAgg, - false, - ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())), - ",".to_owned(), - ); - } - - #[test] - fn string_agg_largeutf8() { - let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"])); - assert_string_aggregate( - a, - AggregateFunction::StringAgg, - false, - ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())), - "|".to_owned(), - ); - } -} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index b9a159b21e3d..1b59c4b97e71 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -48,7 +48,6 @@ pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; pub use crate::aggregate::stats::StatsType; -pub use crate::aggregate::string_agg::StringAgg; pub use crate::window::cume_dist::{cume_dist, CumeDist}; pub use crate::window::lead_lag::{lag, lead, WindowShift}; pub use crate::window::nth_value::NthValue; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e5578ae62f3e..74418677b9dd 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -505,7 +505,7 @@ enum AggregateFunction { // REGR_SXX = 32; // REGR_SYY = 33; // REGR_SXY = 34; - STRING_AGG = 35; + // STRING_AGG = 35; NTH_VALUE_AGG = 36; } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 25b7413a984a..1a1c0067707f 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -149,7 +149,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, - protobuf::AggregateFunction::StringAgg => Self::StringAgg, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d9548325dac3..674f07d0551f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -120,7 +120,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Correlation => Self::Correlation, AggregateFunction::Grouping => Self::Grouping, AggregateFunction::NthValue => Self::NthValueAgg, - AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -393,9 +392,6 @@ pub fn serialize_expr( AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg } - AggregateFunction::StringAgg => { - protobuf::AggregateFunction::StringAgg - } }; let aggregate_expr = protobuf::AggregateExprNode { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 3a4c35a93e16..1c9bd45acfb5 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -27,7 +27,7 @@ use datafusion::physical_plan::expressions::{ CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, - StringAgg, TryCastExpr, WindowShift, + TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -269,8 +269,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Correlation - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::StringAgg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::NthValueAgg } else { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7f66cdbf7663..5c1bd7777b67 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -48,7 +48,7 @@ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr, NthValue, - PhysicalSortExpr, StringAgg, + PhysicalSortExpr, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::DataSinkExec; @@ -356,13 +356,6 @@ fn rountrip_aggregate() -> Result<()> { Vec::new(), Vec::new(), ))], - // STRING_AGG - vec![Arc::new(StringAgg::new( - cast(col("b", &schema)?, &schema, DataType::Utf8)?, - lit(ScalarValue::Utf8(Some(",".to_string()))), - "STRING_AGG(name, ',')".to_string(), - DataType::Utf8, - ))], ]; for aggregates in test_cases { From e86c860c34ff88418e6285daf3248cc518487249 Mon Sep 17 00:00:00 2001 From: zhanglw Date: Mon, 17 Jun 2024 11:35:17 +0800 Subject: [PATCH 02/10] generate proto code --- datafusion/functions-aggregate/src/lib.rs | 1 + .../functions-aggregate/src/string_agg.rs | 21 +++++++++---------- datafusion/proto/src/generated/pbjson.rs | 3 --- datafusion/proto/src/generated/prost.rs | 4 +--- .../sqllogictest/test_files/aggregate.slt | 16 ++++++++++++++ 5 files changed, 28 insertions(+), 17 deletions(-) diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 7925106b4616..dd9e4c3156f7 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -135,6 +135,7 @@ pub fn all_default_aggregate_functions() -> Vec> { approx_distinct::approx_distinct_udaf(), approx_percentile_cont_udaf(), approx_percentile_cont_with_weight_udaf(), + string_agg::string_agg_udaf(), ] } diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 0e800837ac65..0361e8322a53 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -9,6 +9,14 @@ use datafusion_expr::{ }; use std::any::Any; +make_udaf_expr_and_func!( + StringAgg, + string_agg, + expression, + "Concatenates the values of string expressions and places separator values between them", + string_agg_udaf +); + /// STRING_AGG aggregate expression #[derive(Debug)] pub struct StringAgg { @@ -30,7 +38,7 @@ impl StringAgg { ], Volatility::Immutable, ), - aliases: vec!["string_agg".to_string()], + aliases: vec![], } } } @@ -47,7 +55,7 @@ impl AggregateUDFImpl for StringAgg { } fn name(&self) -> &str { - "STRING_AGG" + "string_agg" } fn signature(&self) -> &Signature { @@ -130,12 +138,3 @@ impl Accumulator for StringAggAccumulator { + self.delimiter.capacity() } } - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use datafusion_expr::type_coercion::aggregates::coerce_types; - use datafusion_expr::AggregateFunction; -} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 4a7b9610e5bc..394635b5e302 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -543,7 +543,6 @@ impl serde::Serialize for AggregateFunction { Self::BitXor => "BIT_XOR", Self::BoolAnd => "BOOL_AND", Self::BoolOr => "BOOL_OR", - Self::StringAgg => "STRING_AGG", Self::NthValueAgg => "NTH_VALUE_AGG", }; serializer.serialize_str(variant) @@ -567,7 +566,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR", "BOOL_AND", "BOOL_OR", - "STRING_AGG", "NTH_VALUE_AGG", ]; @@ -620,7 +618,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR" => Ok(AggregateFunction::BitXor), "BOOL_AND" => Ok(AggregateFunction::BoolAnd), "BOOL_OR" => Ok(AggregateFunction::BoolOr), - "STRING_AGG" => Ok(AggregateFunction::StringAgg), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ffaef445d668..776d50340746 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1959,7 +1959,7 @@ pub enum AggregateFunction { /// REGR_SXX = 32; /// REGR_SYY = 33; /// REGR_SXY = 34; - StringAgg = 35, + /// STRING_AGG = 35; NthValueAgg = 36, } impl AggregateFunction { @@ -1980,7 +1980,6 @@ impl AggregateFunction { AggregateFunction::BitXor => "BIT_XOR", AggregateFunction::BoolAnd => "BOOL_AND", AggregateFunction::BoolOr => "BOOL_OR", - AggregateFunction::StringAgg => "STRING_AGG", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } } @@ -1998,7 +1997,6 @@ impl AggregateFunction { "BIT_XOR" => Some(Self::BitXor), "BOOL_AND" => Some(Self::BoolAnd), "BOOL_OR" => Some(Self::BoolOr), - "STRING_AGG" => Some(Self::StringAgg), "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 0a6def3d6f27..378cab206240 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -4972,6 +4972,22 @@ CREATE TABLE float_table ( ( 32768.3, arrow_cast('NAN','Float32'), 32768.3, 32768.3 ), ( 27.3, 27.3, 27.3, arrow_cast('NAN','Float64') ); +# Test string_agg with largeutf8 +statement ok +create table string_agg_large_utf8 (c string) as values + (arrow_cast('a', 'LargeUtf8')), + (arrow_cast('b', 'LargeUtf8')), + (arrow_cast('c', 'LargeUtf8')) +; + +query T +SELECT STRING_AGG(c, ',') FROM string_agg_large_utf8; +---- +a,b,c + +statement ok +drop table string_agg_large_utf8; + query RRRRI select min(col_f32), max(col_f32), avg(col_f32), sum(col_f32), count(col_f32) from float_table; ---- From 42ee1eb3671569857aed4347cae702d827588967 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 17 Jun 2024 13:59:34 +0800 Subject: [PATCH 03/10] Fix bug --- .../functions-aggregate/src/string_agg.rs | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 0361e8322a53..1e150296d790 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -69,12 +69,13 @@ impl AggregateUDFImpl for StringAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { return match &acc_args.input_exprs[1] { Expr::Literal(ScalarValue::Utf8(Some(delimiter))) - | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => { - Ok(Box::new(StringAggAccumulator::new(delimiter))) - } - Expr::Literal(ScalarValue::Null) => { - Ok(Box::new(StringAggAccumulator::new(""))) - } + | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => Ok(Box::new( + StringAggAccumulator::new(delimiter, acc_args.input_type.clone()), + )), + Expr::Literal(ScalarValue::Null) => Ok(Box::new(StringAggAccumulator::new( + "", + acc_args.input_type.clone(), + ))), _ => not_impl_err!( "StringAgg not supported for delimiter {}", &acc_args.input_exprs[1] @@ -91,23 +92,32 @@ impl AggregateUDFImpl for StringAgg { pub(crate) struct StringAggAccumulator { values: Option, delimiter: String, + input_type: DataType, } impl StringAggAccumulator { - pub fn new(delimiter: &str) -> Self { + pub fn new(delimiter: &str, input_type: DataType) -> Self { Self { values: None, delimiter: delimiter.to_string(), + input_type, } } } impl Accumulator for StringAggAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let string_array: Vec<_> = as_generic_string_array::(&values[0])? - .iter() - .filter_map(|v| v.as_ref().map(ToString::to_string)) - .collect(); + let string_array: Vec<_> = match self.input_type { + DataType::Utf8 => as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(), + DataType::LargeUtf8 => as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(), + _ => unreachable!(), + }; if !string_array.is_empty() { let s = string_array.join(self.delimiter.as_str()); let v = self.values.get_or_insert("".to_string()); From 7a93f6749ce7f58e426bea37f19aca30447784e4 Mon Sep 17 00:00:00 2001 From: zhanglw Date: Mon, 17 Jun 2024 14:27:17 +0800 Subject: [PATCH 04/10] Fix --- .../functions-aggregate/src/string_agg.rs | 38 +++++++------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 1e150296d790..7b1ef7fbd6fa 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -30,11 +30,9 @@ impl StringAgg { Self { signature: Signature::one_of( vec![ - TypeSignature::Uniform(2, vec![DataType::LargeUtf8, DataType::Utf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::LargeUtf8]), - TypeSignature::Exact(vec![DataType::Utf8, DataType::Null]), ], Volatility::Immutable, ), @@ -69,13 +67,14 @@ impl AggregateUDFImpl for StringAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { return match &acc_args.input_exprs[1] { Expr::Literal(ScalarValue::Utf8(Some(delimiter))) - | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => Ok(Box::new( - StringAggAccumulator::new(delimiter, acc_args.input_type.clone()), - )), - Expr::Literal(ScalarValue::Null) => Ok(Box::new(StringAggAccumulator::new( - "", - acc_args.input_type.clone(), - ))), + | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => { + Ok(Box::new(StringAggAccumulator::new(delimiter))) + } + Expr::Literal(ScalarValue::Utf8(None)) + | Expr::Literal(ScalarValue::LargeUtf8(None)) + | Expr::Literal(ScalarValue::Null) => { + Ok(Box::new(StringAggAccumulator::new(""))) + } _ => not_impl_err!( "StringAgg not supported for delimiter {}", &acc_args.input_exprs[1] @@ -92,32 +91,23 @@ impl AggregateUDFImpl for StringAgg { pub(crate) struct StringAggAccumulator { values: Option, delimiter: String, - input_type: DataType, } impl StringAggAccumulator { - pub fn new(delimiter: &str, input_type: DataType) -> Self { + pub fn new(delimiter: &str) -> Self { Self { values: None, delimiter: delimiter.to_string(), - input_type, } } } impl Accumulator for StringAggAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let string_array: Vec<_> = match self.input_type { - DataType::Utf8 => as_generic_string_array::(&values[0])? - .iter() - .filter_map(|v| v.as_ref().map(ToString::to_string)) - .collect(), - DataType::LargeUtf8 => as_generic_string_array::(&values[0])? - .iter() - .filter_map(|v| v.as_ref().map(ToString::to_string)) - .collect(), - _ => unreachable!(), - }; + let string_array: Vec<_> = as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(); if !string_array.is_empty() { let s = string_array.join(self.delimiter.as_str()); let v = self.values.get_or_insert("".to_string()); From 5ffef4b30723dcb533b0fac1fcfdcd1048fec8ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 17 Jun 2024 14:31:39 +0800 Subject: [PATCH 05/10] Add license --- .../functions-aggregate/src/string_agg.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 7b1ef7fbd6fa..48e04b7bc839 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use arrow::array::ArrayRef; use arrow_schema::DataType; use datafusion_common::cast::as_generic_string_array; From d297f9bab824391eb3b2992cb5c4de6639fb4704 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Mon, 17 Jun 2024 14:33:45 +0800 Subject: [PATCH 06/10] Add doc --- datafusion/functions-aggregate/src/string_agg.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 48e04b7bc839..4132d918d967 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function + use arrow::array::ArrayRef; use arrow_schema::DataType; use datafusion_common::cast::as_generic_string_array; From 97f63fe750f3c457addd519f81e1d265276e3f79 Mon Sep 17 00:00:00 2001 From: zhanglw Date: Mon, 17 Jun 2024 14:58:58 +0800 Subject: [PATCH 07/10] Fix clippy --- datafusion/functions-aggregate/src/string_agg.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 4132d918d967..882dec21f913 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -84,7 +84,7 @@ impl AggregateUDFImpl for StringAgg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - return match &acc_args.input_exprs[1] { + match &acc_args.input_exprs[1] { Expr::Literal(ScalarValue::Utf8(Some(delimiter))) | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => { Ok(Box::new(StringAggAccumulator::new(delimiter))) @@ -98,7 +98,7 @@ impl AggregateUDFImpl for StringAgg { "StringAgg not supported for delimiter {}", &acc_args.input_exprs[1] ), - }; + } } fn aliases(&self) -> &[String] { From 073be8c12ddbee344f6b563546a8502f35f0325c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Tue, 18 Jun 2024 11:10:49 +0800 Subject: [PATCH 08/10] Remove aliases field --- datafusion/functions-aggregate/src/string_agg.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 882dec21f913..5a4a12390b02 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -40,7 +40,6 @@ make_udaf_expr_and_func!( #[derive(Debug)] pub struct StringAgg { signature: Signature, - aliases: Vec, } impl StringAgg { @@ -55,7 +54,6 @@ impl StringAgg { ], Volatility::Immutable, ), - aliases: vec![], } } } @@ -100,10 +98,6 @@ impl AggregateUDFImpl for StringAgg { ), } } - - fn aliases(&self) -> &[String] { - &self.aliases - } } #[derive(Debug)] From 7879e82a25620900e65500c444ca382df4760fcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Tue, 18 Jun 2024 11:25:18 +0800 Subject: [PATCH 09/10] Add StringAgg proto test --- .../proto/tests/cases/roundtrip_physical_plan.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 5c1bd7777b67..eb3313239544 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -79,6 +79,7 @@ use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; +use datafusion_functions_aggregate::string_agg::StringAgg; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; @@ -356,6 +357,21 @@ fn rountrip_aggregate() -> Result<()> { Vec::new(), Vec::new(), ))], + // STRING_AGG + vec![udaf::create_aggregate_expr( + &AggregateUDF::new_from_impl(StringAgg::new()), + &[ + cast(col("b", &schema)?, &schema, DataType::Utf8)?, + lit(ScalarValue::Utf8(Some(",".to_string()))), + ], + &[], + &[], + &[], + &schema, + "STRING_AGG(name, ',')", + false, + false, + )?], ]; for aggregates in test_cases { From fd5ba17f6c51cbd58eb681ab419524941141b718 Mon Sep 17 00:00:00 2001 From: zhanglw Date: Tue, 18 Jun 2024 13:54:02 +0800 Subject: [PATCH 10/10] Add roundtrip_expr_api test --- datafusion/functions-aggregate/src/string_agg.rs | 2 +- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 5a4a12390b02..371cc8fb9739 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -31,7 +31,7 @@ use std::any::Any; make_udaf_expr_and_func!( StringAgg, string_agg, - expression, + expr delimiter, "Concatenates the values of string expressions and places separator values between them", string_agg_udaf ); diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 52696a106183..61764394ee74 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,6 +60,7 @@ use datafusion_expr::{ WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_functions_aggregate::expr_fn::{bit_and, bit_or, bit_xor}; +use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, @@ -669,6 +670,7 @@ async fn roundtrip_expr_api() -> Result<()> { bit_and(lit(2)), bit_or(lit(2)), bit_xor(lit(2)), + string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), ]; // ensure expressions created with the expr api can be round tripped