Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ArrayAgg (not ordered or distinct) into a UDAF #11045

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1658,10 +1658,10 @@ mod tests {
use datafusion_common::{Constraint, Constraints};
use datafusion_common_runtime::SpawnedTask;
use datafusion_expr::{
array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction,
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition,
cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation,
Volatility, WindowFrame, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::count_distinct;
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};

Expand Down
28 changes: 28 additions & 0 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ use datafusion_expr::{
DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, StringifiedPlan,
WindowFrame, WindowFrameBound, WriteOp,
};
use datafusion_functions_aggregate::array_agg::array_agg_udaf;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::LexOrdering;
use datafusion_physical_plan::placeholder_row::PlaceholderRowExec;
Expand Down Expand Up @@ -1854,6 +1855,33 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
== NullTreatment::IgnoreNulls;

let (agg_expr, filter, order_by) = match func_def {
AggregateFunctionDefinition::BuiltIn(
datafusion_expr::AggregateFunction::ArrayAgg,
) if !distinct && order_by.is_none() => {
let sort_exprs = order_by.clone().unwrap_or(vec![]);
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
exprs,
logical_input_schema,
execution_props,
)?),
None => None,
};
let ordering_reqs: Vec<PhysicalSortExpr> =
physical_sort_exprs.clone().unwrap_or(vec![]);
let agg_expr = udaf::create_aggregate_expr(
&array_agg_udaf(),
&physical_args,
args,
&sort_exprs,
&ordering_reqs,
physical_input_schema,
name,
ignore_nulls,
*distinct,
)?;
(agg_expr, filter, physical_sort_exprs)
}
AggregateFunctionDefinition::BuiltIn(fun) => {
let physical_sort_exprs = match order_by {
Some(exprs) => Some(create_physical_sort_exprs(
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col,
placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{count, sum};
use datafusion_functions_aggregate::expr_fn::{array_agg, count, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,18 +169,6 @@ pub fn max(expr: Expr) -> Expr {
))
}

/// Create an expression to represent the array_agg() aggregate function
///
/// Wraps `expr` as the single argument of the built-in
/// `AggregateFunction::ArrayAgg` and returns the result as an
/// `Expr::AggregateFunction`.
pub fn array_agg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::ArrayAgg,
vec![expr],
// NOTE(review): the remaining arguments presumably mean distinct = false and
// no filter / order_by / null_treatment -- confirm against `AggregateFunction::new`.
false,
None,
None,
None,
))
}

/// Create an expression to represent the avg() aggregate function
pub fn avg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ pub struct StateFieldsArgs<'a> {
/// The input type of the aggregate function.
pub input_type: &'a DataType,

/// If the input type is nullable.
pub input_nullable: bool,

/// The return type of the aggregate function.
pub return_type: &'a DataType,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,102 +17,118 @@

//! Defines physical expressions that can be evaluated at runtime during query execution

use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use arrow_array::Array;
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::DataType;
use arrow_schema::Field;

use datafusion_common::cast::as_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::Accumulator;
use std::any::Any;
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::expr::AggregateFunctionDefinition;
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::AggregateUDFImpl;
use datafusion_expr::Expr;
use datafusion_expr::{Accumulator, Signature, Volatility};
use std::sync::Arc;

/// ARRAY_AGG aggregate expression
make_udaf_expr_and_func!(
ArrayAgg,
array_agg,
expression,
"input values, including nulls, concatenated into an array",
array_agg_udaf
);

#[derive(Debug)]
/// ARRAY_AGG aggregate expression
pub struct ArrayAgg {
/// Column name
name: String,
/// The DataType for the input expression
input_data_type: DataType,
/// The input expression
expr: Arc<dyn PhysicalExpr>,
/// If the input expression can have NULLs
nullable: bool,
signature: Signature,
alias: Vec<String>,
}

impl ArrayAgg {
/// Create a new ArrayAgg aggregate function
pub fn new(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
nullable: bool,
) -> Self {
impl Default for ArrayAgg {
fn default() -> Self {
Self {
name: name.into(),
input_data_type: data_type,
expr,
nullable,
signature: Signature::any(1, Volatility::Immutable),
alias: vec!["array_agg".to_string()],
}
}
}

impl AggregateExpr for ArrayAgg {
fn as_any(&self) -> &dyn Any {
impl AggregateUDFImpl for ArrayAgg {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn field(&self) -> Result<Field> {
Ok(Field::new_list(
&self.name,
// This should be the same as return type of AggregateFunction::ArrayAgg
Field::new("item", self.input_data_type.clone(), true),
self.nullable,
))
fn name(&self) -> &str {
"ARRAY_AGG"
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(ArrayAggAccumulator::try_new(
&self.input_data_type,
)?))
fn aliases(&self) -> &[String] {
&self.alias
}

fn state_fields(&self) -> Result<Vec<Field>> {
fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::List(Arc::new(Field::new(
"item",
arg_types[0].clone(),
true,
))))
}

fn state_fields(
&self,
args: datafusion_expr::function::StateFieldsArgs,
) -> Result<Vec<Field>> {
Ok(vec![Field::new_list(
format_state_name(&self.name, "array_agg"),
Field::new("item", self.input_data_type.clone(), true),
self.nullable,
format_state_name(args.name, "array_agg"),
Field::new("item", args.input_type.clone(), true),
args.input_nullable,
)])
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(ArrayAggAccumulator::try_new(acc_args.input_type)?))
}

fn name(&self) -> &str {
&self.name
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
datafusion_expr::ReversedUDAF::Identical
}
}

impl PartialEq<dyn Any> for ArrayAgg {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.input_data_type == x.input_data_type
&& self.expr.eq(&x.expr)
})
.unwrap_or(false)
fn simplify(
&self,
) -> Option<datafusion_expr::function::AggregateFunctionSimplification> {
let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| {
if aggregate_function.order_by.is_some() || aggregate_function.distinct {
Ok(Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This's cool 😎

datafusion_expr::aggregate_function::AggregateFunction::ArrayAgg,
),
args: aggregate_function.args,
distinct: aggregate_function.distinct,
filter: aggregate_function.filter,
order_by: aggregate_function.order_by,
null_treatment: aggregate_function.null_treatment,
}))
} else {
Ok(Expr::AggregateFunction(aggregate_function))
}
};

Some(Box::new(simplify))
}
}

#[derive(Debug)]
pub(crate) struct ArrayAggAccumulator {
pub struct ArrayAggAccumulator {
values: Vec<ArrayRef>,
datatype: DataType,
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ impl AggregateUDFImpl for LastValue {
let StateFieldsArgs {
name,
input_type,
input_nullable: _,
return_type: _,
ordering_fields,
is_distinct: _,
Expand Down
6 changes: 5 additions & 1 deletion datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
pub mod macros;

pub mod approx_distinct;
pub mod array_agg;
pub mod count;
pub mod covariance;
pub mod first_last;
Expand Down Expand Up @@ -86,6 +87,7 @@ pub mod expr_fn {
pub use super::approx_median::approx_median;
pub use super::approx_percentile_cont::approx_percentile_cont;
pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight;
pub use super::array_agg::array_agg;
pub use super::bit_and_or_xor::bit_and;
pub use super::bit_and_or_xor::bit_or;
pub use super::bit_and_or_xor::bit_xor;
Expand Down Expand Up @@ -117,6 +119,7 @@ pub mod expr_fn {
/// Returns all default aggregate functions
pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
vec![
array_agg::array_agg_udaf(),
first_last::first_value_udaf(),
first_last::last_value_udaf(),
covariance::covar_samp_udaf(),
Expand Down Expand Up @@ -177,7 +180,8 @@ mod tests {
for func in all_default_aggregate_functions() {
// TODO: remove this
// These functions are in an intermediate migration state, skip them
if func.name().to_lowercase() == "count" {
let name_lower_case = func.name().to_lowercase();
if name_lower_case == "count" || name_lower_case == "array_agg" {
continue;
}
assert!(
Expand Down
9 changes: 8 additions & 1 deletion datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ pub fn create_aggregate_expr(
ordering_fields,
is_distinct,
input_type: input_exprs_types[0].clone(),
input_nullable: input_phy_exprs[0].nullable(schema)?,
}))
}

Expand Down Expand Up @@ -248,6 +249,7 @@ pub struct AggregateFunctionExpr {
ordering_fields: Vec<Field>,
is_distinct: bool,
input_type: DataType,
input_nullable: bool,
}

impl AggregateFunctionExpr {
Expand Down Expand Up @@ -276,6 +278,7 @@ impl AggregateExpr for AggregateFunctionExpr {
let args = StateFieldsArgs {
name: &self.name,
input_type: &self.input_type,
input_nullable: self.input_nullable,
return_type: &self.data_type,
ordering_fields: &self.ordering_fields,
is_distinct: self.is_distinct,
Expand All @@ -285,7 +288,11 @@ impl AggregateExpr for AggregateFunctionExpr {
}

fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, self.data_type.clone(), true))
Ok(Field::new(
&self.name,
self.data_type.clone(),
self.input_nullable,
))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Expand Down
Loading
Loading