From 608bce37da13265360951822ac8d7a8547b06416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gediminas=20Aleknavi=C4=8Dius?= Date: Mon, 15 Apr 2024 13:59:11 +0300 Subject: [PATCH] Groups accumulator for array_agg (#233) * Groups accumulator for array_agg * small fix * fmt * clippy * clippy --- .../physical-expr/src/aggregate/array_agg.rs | 461 +++++++++++++++++- .../groups_accumulator/accumulate.rs | 164 ++++++- 2 files changed, 614 insertions(+), 11 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index 5dc29f834feb..96e8b6b899ba 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -17,17 +17,30 @@ //! Defines physical expressions that can evaluated at runtime during query execution +use crate::aggregate::groups_accumulator::accumulate::NullState; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; +use arrow_array::builder::{ListBuilder, PrimitiveBuilder, StringBuilder}; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + Date32Type, Date64Type, Decimal128Type, Decimal256Type, DurationMicrosecondType, + DurationMillisecondType, DurationNanosecondType, DurationSecondType, Float32Type, + Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, + IntervalMonthDayNanoType, IntervalYearMonthType, Time32MillisecondType, + Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::{Array, ArrowPrimitiveType, BooleanArray}; +use arrow_schema::{IntervalUnit, TimeUnit}; use datafusion_common::cast::as_list_array; use 
datafusion_common::utils::array_into_list_array; -use datafusion_common::Result; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; use std::any::Any; use std::sync::Arc; @@ -96,6 +109,139 @@ impl AggregateExpr for ArrayAgg { fn name(&self) -> &str { &self.name } + + fn groups_accumulator_supported(&self) -> bool { + self.input_data_type.is_primitive() || self.input_data_type == DataType::Utf8 + } + + fn create_groups_accumulator(&self) -> Result> { + match self.input_data_type { + DataType::Int8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ))), + DataType::Int16 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ))), + DataType::Int32 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ))), + DataType::Int64 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ))), + DataType::UInt8 => Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ))), + DataType::UInt16 => Ok(Box::new( + ArrayAggGroupsAccumulator::::new(&self.input_data_type), + )), + DataType::UInt32 => Ok(Box::new( + ArrayAggGroupsAccumulator::::new(&self.input_data_type), + )), + DataType::UInt64 => Ok(Box::new( + ArrayAggGroupsAccumulator::::new(&self.input_data_type), + )), + DataType::Float32 => Ok(Box::new( + ArrayAggGroupsAccumulator::::new(&self.input_data_type), + )), + DataType::Float64 => Ok(Box::new( + ArrayAggGroupsAccumulator::::new(&self.input_data_type), + )), + DataType::Decimal128(_, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ))) + } + DataType::Decimal256(_, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ))) + } + DataType::Date32 => Ok(Box::new( + ArrayAggGroupsAccumulator::::new(&self.input_data_type), + )), + DataType::Date64 => Ok(Box::new( + 
ArrayAggGroupsAccumulator::::new(&self.input_data_type), + )), + DataType::Timestamp(TimeUnit::Second, _) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Timestamp(TimeUnit::Millisecond, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::< + TimestampMillisecondType, + >::new(&self.input_data_type))) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + Ok(Box::new(ArrayAggGroupsAccumulator::< + TimestampMicrosecondType, + >::new(&self.input_data_type))) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Time32(TimeUnit::Second) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new(&self.input_data_type), + )), + DataType::Time32(TimeUnit::Millisecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Time64(TimeUnit::Microsecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Time64(TimeUnit::Nanosecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Duration(TimeUnit::Second) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Duration(TimeUnit::Millisecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Duration(TimeUnit::Microsecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Duration(TimeUnit::Nanosecond) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Interval(IntervalUnit::YearMonth) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Interval(IntervalUnit::DayTime) => Ok(Box::new( + ArrayAggGroupsAccumulator::::new( + &self.input_data_type, + ), + )), + DataType::Interval(IntervalUnit::MonthDayNano) => 
{ + Ok(Box::new(ArrayAggGroupsAccumulator::< + IntervalMonthDayNanoType, + >::new(&self.input_data_type))) + } + DataType::Utf8 => Ok(Box::new(StringArrayAggGroupsAccumulator::new())), + _ => Err(DataFusionError::Internal(format!( + "ArrayAggGroupsAccumulator not supported for data type {:?}", + self.input_data_type + ))), + } + } } impl PartialEq for ArrayAgg { @@ -187,19 +333,258 @@ impl Accumulator for ArrayAggAccumulator { } } +struct ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send, +{ + values: Vec::Native>>>, + data_type: DataType, + null_state: NullState, +} + +impl ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send, +{ + pub fn new(data_type: &DataType) -> Self { + Self { + values: vec![], + data_type: data_type.clone(), + null_state: NullState::new(), + } + } +} + +impl ArrayAggGroupsAccumulator { + fn build_list(&mut self, emit_to: EmitTo) -> Result { + let array = emit_to.take_needed(&mut self.values); + let nulls = self.null_state.build(emit_to); + + let len = nulls.len(); + assert_eq!(array.len(), len); + + let mut builder = ListBuilder::with_capacity( + PrimitiveBuilder::::new().with_data_type(self.data_type.clone()), + len, + ); + + for (is_valid, arr) in nulls.iter().zip(array.into_iter()) { + if is_valid { + builder.append_value(arr); + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) + } +} + +impl GroupsAccumulator for ArrayAggGroupsAccumulator +where + T: ArrowPrimitiveType + Send + Sync, +{ + fn update_batch( + &mut self, + new_values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(new_values.len(), 1, "single argument to update_batch"); + let new_values = new_values[0].as_primitive::(); + + self.values.resize(total_num_groups, vec![]); + + self.null_state.accumulate( + group_indices, + new_values, + opt_filter, + total_num_groups, + |group_index, new_value| { + 
self.values[group_index].push(Some(new_value)); + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to merge_batch"); + let values = values[0].as_list(); + + self.values.resize(total_num_groups, vec![]); + + self.null_state.accumulate_array( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value: ArrayRef| { + let new_value = new_value.as_primitive::(); + self.values[group_index].append( + new_value + .into_iter() + .collect::>>() + .as_mut(), + ); + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + self.build_list(emit_to) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + Ok(vec![self.build_list(emit_to)?]) + } + + fn size(&self) -> usize { + self.values.capacity() + + self.values.iter().map(|arr| arr.capacity()).sum::() + * std::mem::size_of::<::Native>() + + self.null_state.size() + } +} + +struct StringArrayAggGroupsAccumulator { + values: Vec>>, + null_state: NullState, +} + +impl StringArrayAggGroupsAccumulator { + pub fn new() -> Self { + Self { + values: vec![], + null_state: NullState::new(), + } + } +} + +impl StringArrayAggGroupsAccumulator { + fn build_list(&mut self, emit_to: EmitTo) -> Result { + let array = emit_to.take_needed(&mut self.values); + let nulls = self.null_state.build(emit_to); + + assert_eq!(array.len(), nulls.len()); + + let mut builder = ListBuilder::with_capacity(StringBuilder::new(), nulls.len()); + for (is_valid, arr) in nulls.iter().zip(array.into_iter()) { + if is_valid { + builder.append_value(arr); + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) + } +} + +impl GroupsAccumulator for StringArrayAggGroupsAccumulator { + fn update_batch( + &mut self, + new_values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + 
total_num_groups: usize, + ) -> Result<()> { + assert_eq!(new_values.len(), 1, "single argument to update_batch"); + let new_values = new_values[0].as_string(); + + self.values.resize(total_num_groups, vec![]); + + self.null_state.accumulate_string( + group_indices, + new_values, + opt_filter, + total_num_groups, + |group_index, new_value| { + self.values[group_index].push(Some(new_value.to_string())); + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to merge_batch"); + let values = values[0].as_list(); + + self.values + .resize(total_num_groups, Vec::>::new()); + + self.null_state.accumulate_array( + group_indices, + values, + opt_filter, + total_num_groups, + |group_index, new_value: ArrayRef| { + let new_value = new_value.as_string::(); + + self.values[group_index].append( + new_value + .into_iter() + .map(|s| s.map(|s| s.to_string())) + .collect::>>() + .as_mut(), + ); + }, + ); + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + self.build_list(emit_to) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + Ok(vec![self.build_list(emit_to)?]) + } + + fn size(&self) -> usize { + self.values.capacity() + + self + .values + .iter() + .map(|arr| { + arr.iter() + .map(|e| e.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum::() + }) + .sum::() + + self.null_state.size() + } +} + #[cfg(test)] mod tests { use super::*; use crate::expressions::col; - use crate::expressions::tests::aggregate; + use crate::expressions::tests::{aggregate, aggregate_new}; use arrow::array::ArrayRef; use arrow::array::Int32Array; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; - use arrow_array::Array; use arrow_array::ListArray; + use arrow_array::{Array, StringArray}; use arrow_buffer::OffsetBuffer; - use datafusion_common::DataFusionError; use 
datafusion_common::Result; macro_rules! test_op { @@ -221,8 +606,32 @@ mod tests { let expected = ScalarValue::from($EXPECTED); assert_eq!(expected, actual); + }}; + } - Ok(()) as Result<(), DataFusionError> + macro_rules! test_op_new { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + generic_test_op_new!( + $ARRAY, + $DATATYPE, + $OP, + $EXPECTED, + $EXPECTED.data_type().clone() + ) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + true, + )); + let actual = aggregate_new(&batch, agg)?; + assert_eq!($EXPECTED, &actual); }}; } @@ -237,9 +646,39 @@ mod tests { Some(4), Some(5), ])]); - let list = ScalarValue::List(Arc::new(list)); + let expected = ScalarValue::List(Arc::new(list.clone())); + + test_op!( + a.clone(), + DataType::Int32, + ArrayAgg, + expected, + DataType::Int32 + ); + + let expected: ArrayRef = Arc::new(list); + test_op_new!(a, DataType::Int32, ArrayAgg, &expected, DataType::Int32); + + Ok(()) + } + + #[test] + fn array_agg_str() -> Result<()> { + let a: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3", "4", "5"])); + + let mut list_builder = ListBuilder::with_capacity(StringBuilder::new(), 5); + list_builder.values().append_value("1"); + list_builder.values().append_value("2"); + list_builder.values().append_value("3"); + list_builder.values().append_value("4"); + list_builder.values().append_value("5"); + list_builder.append(true); - test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) + let list = list_builder.finish(); + let expected: ArrayRef = Arc::new(list); + test_op_new!(a, DataType::Utf8, ArrayAgg, &expected, DataType::Utf8); + + Ok(()) } #[test] @@ -305,6 +744,8 @@ mod tests { ArrayAgg, list, 
DataType::List(Arc::new(Field::new("item", DataType::Int32, true,))) - ) + ); + + Ok(()) } } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs index 7080ea40039d..0f75e5a81e53 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs @@ -20,7 +20,9 @@ //! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator use arrow::datatypes::ArrowPrimitiveType; -use arrow_array::{Array, BooleanArray, PrimitiveArray}; +use arrow_array::{ + Array, ArrayRef, BooleanArray, ListArray, PrimitiveArray, StringArray, +}; use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; use datafusion_expr::EmitTo; @@ -324,6 +326,166 @@ impl NullState { } } + /// Invokes `value_fn(group_index, value)` for each non null, non + /// filtered value in `values`, while tracking which groups have + /// seen null inputs and which groups have seen any inputs, for + /// [`ListArray`]s. + /// + /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for + /// more details on other arguments. 
+ pub fn accumulate_array( + &mut self, + group_indices: &[usize], + values: &ListArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + F: FnMut(usize, ArrayRef) + Send, + { + assert_eq!(values.len(), group_indices.len()); + + // ensure the seen_values is big enough (start everything at + // "not seen" valid) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + let iter = group_indices.iter().zip(values.iter()); + for (&group_index, new_value) in iter { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + group_indices + .iter() + .zip(values.iter()) + .zip(nulls.iter()) + .for_each(|((&group_index, new_value), is_valid)| { + if is_valid { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + }) + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + }); + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value); + } + } + }); + } + } + } + + /// Invokes `value_fn(group_index, value)` for each non-null, + /// non-filtered value in `values`, while tracking which groups have + /// seen
null inputs and which groups have seen any inputs, for + /// [`StringArray`]s. + /// + /// See [`Self::accumulate`], which handles `PrimitiveArray`s, for + /// more details on other arguments. + pub fn accumulate_string( + &mut self, + group_indices: &[usize], + values: &StringArray, + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + mut value_fn: F, + ) where + F: FnMut(usize, &str) + Send, + { + assert_eq!(values.len(), group_indices.len()); + + // ensure the seen_values is big enough (start everything at + // "not seen" valid) + let seen_values = + initialize_builder(&mut self.seen_values, total_num_groups, false); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + let iter = group_indices.iter().zip(values.iter()); + for (&group_index, new_value) in iter { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + group_indices + .iter() + .zip(values.iter()) + .zip(nulls.iter()) + .for_each(|((&group_index, new_value), is_valid)| { + if is_valid { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + }) + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + group_indices + .iter() + .zip(values.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, new_value), filter_value)| { + if let Some(true) = filter_value { + seen_values.set_bit(group_index, true); + value_fn(group_index, new_value.unwrap()); + } + }); + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + seen_values.set_bit(group_index, true); + value_fn(group_index,
new_value); + } + } + }); + } + } + } + /// Creates the a [`NullBuffer`] representing which group_indices /// should have null values (because they never saw any values) /// for the `emit_to` rows.