Remove ArrayAgg Builtin in favor of UDF (#11611)
* rm def

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rewrite test

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
jayzhan211 committed Jul 23, 2024
1 parent 6d8bd2c commit fc8e7b9
Showing 18 changed files with 100 additions and 130 deletions.
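For orientation, here is a minimal usage sketch of the state after this commit: array_agg is resolved as a registered user-defined aggregate (UDAF) rather than through the removed AggregateFunction::ArrayAgg built-in variant. The crate paths and the VALUES-based input are assumptions for illustration, not part of the diff.

    use datafusion::error::Result;
    use datafusion::prelude::*;
    // Assumed re-export path; `array_agg` is the expression helper generated by
    // the make_udaf_expr_and_func! invocation shown in the array_agg.rs diff below.
    use datafusion::functions_aggregate::array_agg::array_agg;

    #[tokio::main]
    async fn main() -> Result<()> {
        let ctx = SessionContext::new();

        // SQL path: the planner now resolves array_agg from the registered UDAFs,
        // not from the built-in AggregateFunction enum.
        ctx.sql("SELECT array_agg(v) FROM (VALUES (1), (2), (3)) AS t(v)")
            .await?
            .show()
            .await?;

        // DataFrame path: the same UDAF, reached through the generated expression fn.
        let df = ctx
            .sql("SELECT * FROM (VALUES (1), (2), (3)) AS t(v)")
            .await?
            .aggregate(vec![], vec![array_agg(col("v"))])?;
        df.show().await?;

        Ok(())
    }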
4 changes: 2 additions & 2 deletions datafusion/core/tests/dataframe/mod.rs
@@ -1389,7 +1389,7 @@ async fn unnest_with_redundant_columns() -> Result<()> {
let expected = vec![
"Projection: shapes.shape_id [shape_id:UInt32]",
" Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
" TableScan: shapes projection=[shape_id] [shape_id:UInt32]",
];

@@ -1973,7 +1973,7 @@ async fn test_array_agg() -> Result<()> {

let expected = [
"+-------------------------------------+",
"| ARRAY_AGG(test.a) |",
"| array_agg(test.a) |",
"+-------------------------------------+",
"| [abcDEF, abc123, CBAdef, 123AbcDef] |",
"+-------------------------------------+",
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
@@ -35,7 +35,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
assert_eq!(
*actual[0].schema(),
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
"array_agg(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, true),
true
),])
16 changes: 2 additions & 14 deletions datafusion/expr/src/aggregate_function.rs
@@ -17,13 +17,12 @@

//! Aggregate function module contains all built-in aggregate functions definitions

use std::sync::Arc;
use std::{fmt, str::FromStr};

use crate::utils;
use crate::{type_coercion::aggregates::*, Signature, Volatility};

use arrow::datatypes::{DataType, Field};
use arrow::datatypes::DataType;
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};

use strum_macros::EnumIter;
@@ -37,8 +36,6 @@ pub enum AggregateFunction {
Min,
/// Maximum
Max,
/// Aggregation into an array
ArrayAgg,
}

impl AggregateFunction {
@@ -47,7 +44,6 @@ impl AggregateFunction {
match self {
Min => "MIN",
Max => "MAX",
ArrayAgg => "ARRAY_AGG",
}
}
}
@@ -65,7 +61,6 @@ impl FromStr for AggregateFunction {
// general
"max" => AggregateFunction::Max,
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
_ => {
return plan_err!("There is no built-in function named {name}");
}
@@ -80,7 +75,7 @@ impl AggregateFunction {
pub fn return_type(
&self,
input_expr_types: &[DataType],
input_expr_nullable: &[bool],
_input_expr_nullable: &[bool],
) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
@@ -105,11 +100,6 @@ impl AggregateFunction {
// The coerced_data_types is same with input_types.
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
input_expr_nullable[0],
)))),
}
}

@@ -118,7 +108,6 @@ impl AggregateFunction {
pub fn nullable(&self) -> Result<bool> {
match self {
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
AggregateFunction::ArrayAgg => Ok(true),
}
}
}
@@ -128,7 +117,6 @@ impl AggregateFunction {
pub fn signature(&self) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match self {
AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
AggregateFunction::Min | AggregateFunction::Max => {
let valid = STRINGS
.iter()
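With the ArrayAgg variant and its "array_agg" match arm gone, the built-in parser shown above only recognizes min and max; "array_agg" now falls through to the plan_err! branch and must be resolved from the UDAF registry instead. A small illustrative check of that behavior, assuming only the datafusion-expr crate at this commit:

    use std::str::FromStr;

    use datafusion_expr::AggregateFunction;

    fn main() {
        // Still parsed as built-in aggregate functions.
        assert!(AggregateFunction::from_str("min").is_ok());
        assert!(AggregateFunction::from_str("max").is_ok());

        // After this commit the built-in parser rejects array_agg
        // ("There is no built-in function named array_agg"); resolution
        // is expected to go through the UDAF registry instead.
        assert!(AggregateFunction::from_str("array_agg").is_err());
    }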
7 changes: 1 addition & 6 deletions datafusion/expr/src/type_coercion/aggregates.rs
@@ -95,7 +95,6 @@ pub fn coerce_types(
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Min | AggregateFunction::Max => {
// min and max support the dictionary data type
// unpack the dictionary to get the value
@@ -360,11 +359,7 @@ mod tests {

// test count, array_agg, approx_distinct, min, max.
// the coerced types is same with input types
let funs = vec![
AggregateFunction::ArrayAgg,
AggregateFunction::Min,
AggregateFunction::Max,
];
let funs = vec![AggregateFunction::Min, AggregateFunction::Max];
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Decimal128(10, 2)],
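The remaining Min/Max coercion path above still unpacks dictionary inputs to their value type. A hedged sketch of that behavior follows; the coerce_types name and module path come from this file, but the exact coerced result is an assumption based on the comment in the hunk above:

    use arrow::datatypes::DataType;
    use datafusion_common::Result;
    use datafusion_expr::type_coercion::aggregates::coerce_types;
    use datafusion_expr::AggregateFunction;

    fn main() -> Result<()> {
        // MIN over a dictionary-encoded string column: coercion unpacks the
        // dictionary and works on the value type (assumed to be Utf8 here).
        let fun = AggregateFunction::Min;
        let input = vec![DataType::Dictionary(
            Box::new(DataType::Int32),
            Box::new(DataType::Utf8),
        )];
        let coerced = coerce_types(&fun, &input, &fun.signature())?;
        assert_eq!(coerced, vec![DataType::Utf8]);
        Ok(())
    }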
2 changes: 1 addition & 1 deletion datafusion/expr/src/udaf.rs
@@ -542,7 +542,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
pub enum ReversedUDAF {
/// The expression is the same as the original expression, like SUM, COUNT
Identical,
/// The expression does not support reverse calculation, like ArrayAgg
/// The expression does not support reverse calculation
NotSupported,
/// The expression is different from the original expression
Reversed(Arc<AggregateUDF>),
7 changes: 2 additions & 5 deletions datafusion/functions-aggregate/src/array_agg.rs
@@ -50,14 +50,12 @@ make_udaf_expr_and_func!(
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
signature: Signature,
alias: Vec<String>,
}

impl Default for ArrayAgg {
fn default() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
alias: vec!["array_agg".to_string()],
}
}
}
@@ -67,13 +65,12 @@ impl AggregateUDFImpl for ArrayAgg {
self
}

// TODO: change name to lowercase
fn name(&self) -> &str {
"ARRAY_AGG"
"array_agg"
}

fn aliases(&self) -> &[String] {
&self.alias
&[]
}

fn signature(&self) -> &Signature {
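Because the canonical name is now lowercase and the alias list is empty, any name-based lookup has to match "array_agg" exactly (the planner check in the next file is one such place). A hedged sketch of that invariant; the crate/module paths and the AggregateUDF::from conversion are assumptions about the surrounding API, not part of this diff:

    use datafusion_expr::AggregateUDF;
    use datafusion_functions_aggregate::array_agg::ArrayAgg;

    fn main() {
        // Wrap the UDAF impl the same way it would be registered.
        let udaf = AggregateUDF::from(ArrayAgg::default());

        // Lowercase canonical name, no "ARRAY_AGG" alias left behind.
        assert_eq!(udaf.name(), "array_agg");
        assert!(udaf.aliases().is_empty());
    }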
2 changes: 1 addition & 1 deletion datafusion/functions-array/src/planner.rs
@@ -172,7 +172,7 @@ impl ExprPlanner for FieldAccessPlanner {

fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool {
if let AggregateFunctionDefinition::UDF(udf) = &agg_func.func_def {
return udf.name() == "ARRAY_AGG";
return udf.name() == "array_agg";
}

false
5 changes: 3 additions & 2 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
@@ -573,8 +573,9 @@ impl AggregateExpr for AggregateFunctionExpr {
})
.collect::<Vec<_>>();
let mut name = self.name().to_string();
// TODO: Generalize order-by clause rewrite
if reverse_udf.name() == "ARRAY_AGG" {
// If the function is changed, we need to reverse order_by clause as well
// i.e. First(a order by b asc null first) -> Last(a order by b desc null last)
if self.fun().name() == reverse_udf.name() {
} else {
replace_order_by_clause(&mut name);
}
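When the reversed UDAF is a different function, replace_order_by_clause flips the ordering keywords in the display name, as the new comment above describes. The snippet below is a standalone illustration of that kind of string rewrite, not the actual DataFusion helper:

    /// Flip ordering directions in an aggregate's display name:
    /// ASC <-> DESC and NULLS FIRST <-> NULLS LAST.
    /// Placeholders keep already-swapped tokens from being replaced twice.
    fn reverse_order_by(name: &str) -> String {
        name.replace(" ASC ", "\u{1} ")
            .replace(" DESC ", " ASC ")
            .replace("\u{1} ", " DESC ")
            .replace("NULLS FIRST", "\u{2}")
            .replace("NULLS LAST", "NULLS FIRST")
            .replace("\u{2}", "NULLS LAST")
    }

    fn main() {
        let name = "ORDER BY [b ASC NULLS FIRST, c DESC NULLS LAST]";
        assert_eq!(
            reverse_order_by(name),
            "ORDER BY [b DESC NULLS LAST, c ASC NULLS FIRST]"
        );
    }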
4 changes: 1 addition & 3 deletions datafusion/physical-expr/src/aggregate/build_in.rs
@@ -30,7 +30,7 @@ use std::sync::Arc;

use arrow::datatypes::Schema;

use datafusion_common::{internal_err, Result};
use datafusion_common::Result;
use datafusion_expr::AggregateFunction;

use crate::expressions::{self};
@@ -56,7 +56,6 @@ pub fn create_aggregate_expr(
let data_type = input_phy_types[0].clone();
let input_phy_exprs = input_phy_exprs.to_vec();
Ok(match (fun, distinct) {
(AggregateFunction::ArrayAgg, _) => return internal_err!("not reachable"),
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
Arc::clone(&input_phy_exprs[0]),
name,
@@ -123,7 +122,6 @@ mod tests {
result_agg_phy_exprs.field().unwrap()
);
}
_ => {}
};
}
}
2 changes: 1 addition & 1 deletion datafusion/proto/proto/datafusion.proto
@@ -472,7 +472,7 @@ enum AggregateFunction {
// AVG = 3;
// COUNT = 4;
// APPROX_DISTINCT = 5;
ARRAY_AGG = 6;
// ARRAY_AGG = 6;
// VARIANCE = 7;
// VARIANCE_POP = 8;
// COVARIANCE = 9;
3 changes: 0 additions & 3 deletions datafusion/proto/src/generated/pbjson.rs

(Generated file; diff not rendered by default.)

7 changes: 2 additions & 5 deletions datafusion/proto/src/generated/prost.rs

(Generated file; diff not rendered by default.)

1 change: 0 additions & 1 deletion datafusion/proto/src/logical_plan/from_proto.rs
@@ -144,7 +144,6 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
match agg_fun {
protobuf::AggregateFunction::Min => Self::Min,
protobuf::AggregateFunction::Max => Self::Max,
protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg,
}
}
}
2 changes: 0 additions & 2 deletions datafusion/proto/src/logical_plan/to_proto.rs
@@ -116,7 +116,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
match value {
AggregateFunction::Min => Self::Min,
AggregateFunction::Max => Self::Max,
AggregateFunction::ArrayAgg => Self::ArrayAgg,
}
}
}
@@ -386,7 +385,6 @@ pub fn serialize_expr(
}) => match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
let aggr_function = match fun {
AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
};
16 changes: 8 additions & 8 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -129,12 +129,12 @@ query TT
explain select array_agg(c1 order by c2 desc, c3) from agg_order;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]]
01)Aggregate: groupBy=[[]], aggr=[[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]]
02)--TableScan: agg_order projection=[c1, c2, c3]
physical_plan
01)AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]
01)AggregateExec: mode=Final, gby=[], aggr=[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]
02)--CoalescePartitionsExec
03)----AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]
03)----AggregateExec: mode=Partial, gby=[], aggr=[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]
04)------SortExec: expr=[c2@1 DESC,c3@2 ASC NULLS LAST], preserve_partitioning=[true]
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2, c3], has_header=true
@@ -231,8 +231,8 @@ explain with A as (
) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id;
----
logical_plan
01)Projection: array_length(ARRAY_AGG(DISTINCT a.foo)), sum(DISTINCT Int64(1))
02)--Aggregate: groupBy=[[a.id]], aggr=[[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))]]
01)Projection: array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1))
02)--Aggregate: groupBy=[[a.id]], aggr=[[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))]]
03)----SubqueryAlias: a
04)------SubqueryAlias: a
05)--------Union
@@ -247,11 +247,11 @@ logical_plan
14)----------Projection: Int64(1) AS id, Int64(2) AS foo
15)------------EmptyRelation
physical_plan
01)ProjectionExec: expr=[array_length(ARRAY_AGG(DISTINCT a.foo)@1) as array_length(ARRAY_AGG(DISTINCT a.foo)), sum(DISTINCT Int64(1))@2 as sum(DISTINCT Int64(1))]
02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))]
01)ProjectionExec: expr=[array_length(array_agg(DISTINCT a.foo)@1) as array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1))@2 as sum(DISTINCT Int64(1))]
02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))]
03)----CoalesceBatchesExec: target_batch_size=8192
04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5
05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[ARRAY_AGG(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted
05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted
06)----------UnionExec
07)------------ProjectionExec: expr=[1 as id, 2 as foo]
08)--------------PlaceholderRowExec
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/binary_view.slt
@@ -199,4 +199,4 @@ Raphael R false false true true
NULL R NULL NULL NULL NULL

statement ok
drop table test;
\ No newline at end of file
drop table test;