Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 148 additions & 17 deletions datafusion/functions-aggregate/src/string_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@ use crate::array_agg::ArrayAgg;

use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
use datafusion_common::cast::{
as_generic_string_array, as_string_array, as_string_view_array,
};
use datafusion_common::{
internal_datafusion_err, internal_err, not_impl_err, Result, ScalarValue,
};
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
};
Expand Down Expand Up @@ -120,6 +125,8 @@ impl Default for StringAgg {
}
}

/// If there is no `distinct` and `order by` required by the `string_agg` call, a
/// more efficient accumulator `SimpleStringAggAccumulator` will be used.
impl AggregateUDFImpl for StringAgg {
fn as_any(&self) -> &dyn Any {
self
Expand All @@ -138,7 +145,21 @@ impl AggregateUDFImpl for StringAgg {
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
self.array_agg.state_fields(args)
// See comments in `impl AggregateUDFImpl ...` for more detail
let no_order_no_distinct =
(args.ordering_fields.is_empty()) && (!args.is_distinct);
if no_order_no_distinct {
// Case `SimpleStringAggAccumulator`
Ok(vec![Field::new(
format_state_name(args.name, "string_agg"),
DataType::LargeUtf8,
true,
)
.into()])
Comment on lines +152 to +158
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to put this as part of SimpleStringAggAccumulator, something like

  SimpleStringAggAccumulator::state_fields(args)

} else {
// Case `StringAggAccumulator`
self.array_agg.state_fields(args)
}
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Expand All @@ -161,21 +182,31 @@ impl AggregateUDFImpl for StringAgg {
);
};

let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
return_field: Field::new(
"f",
DataType::new_list(acc_args.return_field.data_type().clone(), true),
true,
)
.into(),
exprs: &filter_index(acc_args.exprs, 1),
..acc_args
})?;
// See comments in `impl AggregateUDFImpl ...` for more detail
let no_order_no_distinct =
acc_args.order_bys.is_empty() && (!acc_args.is_distinct);

Ok(Box::new(StringAggAccumulator::new(
array_agg_acc,
delimiter,
)))
if no_order_no_distinct {
// simple case (more efficient)
Ok(Box::new(SimpleStringAggAccumulator::new(delimiter)))
} else {
// general case
let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto here for encapsulating this

return_field: Field::new(
"f",
DataType::new_list(acc_args.return_field.data_type().clone(), true),
true,
)
.into(),
exprs: &filter_index(acc_args.exprs, 1),
..acc_args
})?;

Ok(Box::new(StringAggAccumulator::new(
array_agg_acc,
delimiter,
)))
}
}

fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
Expand All @@ -187,6 +218,7 @@ impl AggregateUDFImpl for StringAgg {
}
}

/// StringAgg accumulator for the general case (with order or distinct specified)
#[derive(Debug)]
pub(crate) struct StringAggAccumulator {
array_agg_acc: Box<dyn Accumulator>,
Expand Down Expand Up @@ -269,6 +301,105 @@ fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
.collect::<Vec<_>>()
}

/// StringAgg accumulator for the simple case (no order or distinct specified)
/// This accumulator is more efficient than `StringAggAccumulator`
/// because it accumulates the string directly,
/// whereas `StringAggAccumulator` uses `ArrayAggAccumulator`.
#[derive(Debug)]
pub(crate) struct SimpleStringAggAccumulator {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is likely much better than what we have. We can probably do better still with a GroupsAccumulator as well

delimiter: String,
/// Updated during `update_batch()`. e.g. "foo,bar"
accumulated_string: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than `has_value`, perhaps using an `Option` would be better: more idiomatic Rust and harder to misuse.

    accumulated_string: Option<String>,

has_value: bool,
}

impl SimpleStringAggAccumulator {
/// Creates an empty accumulator that joins appended values with `delimiter`.
///
/// The buffer starts out empty and `has_value` is `false`, so a fresh
/// accumulator reports "no value" until something is appended.
pub fn new(delimiter: &str) -> Self {
    // Own the delimiter so the accumulator does not borrow from the caller.
    let delimiter = delimiter.to_owned();
    Self {
        delimiter,
        accumulated_string: String::new(),
        has_value: false,
    }
}

#[inline]
fn append_strings<'a, I>(&mut self, iter: I)
where
I: Iterator<Item = Option<&'a str>>,
{
for value in iter.flatten() {
if self.has_value {
self.accumulated_string.push_str(&self.delimiter);
}

self.accumulated_string.push_str(value);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

push_str is 💪

self.has_value = true;
Comment on lines +331 to +336
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you used an option, this could be like

Suggested change
if self.has_value {
self.accumulated_string.push_str(&self.delimiter);
}
self.accumulated_string.push_str(value);
self.has_value = true;
if let Some(accumulated_string) = self.accumulated_string.as_mut() {
accumulated_string.push_str(&self.delimiter);
accumulated_string.push_str(value);
} else {
self.accumulated_string = Some(String::from(value));
}

}
}
}

impl Accumulator for SimpleStringAggAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let string_arr = values.first().ok_or_else(|| {
Comment on lines +342 to +343
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there only one element in values? That was surprising

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's an array holding arg1. The argument is validated at planning time, so we can also assume it has the right type here (values[0] is a string array).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

internal_datafusion_err!(
"Planner should ensure its first arg is Utf8/Utf8View"
)
})?;

match string_arr.data_type() {
DataType::Utf8 => {
let array = as_string_array(string_arr)?;
self.append_strings(array.iter());
}
DataType::LargeUtf8 => {
let array = as_generic_string_array::<i64>(string_arr)?;
self.append_strings(array.iter());
}
DataType::Utf8View => {
let array = as_string_view_array(string_arr)?;
self.append_strings(array.iter());
}
other => {
return internal_err!(
"Planner should ensure string_agg first argument is Utf8-like, found {other}"
);
}
}

Ok(())
}

/// Produces the aggregated string as a `LargeUtf8` scalar.
///
/// Returns `LargeUtf8(Some(..))` holding the accumulated text when
/// `has_value` is set, and `LargeUtf8(None)` otherwise. The buffer is moved
/// out rather than cloned, so afterwards `accumulated_string` is empty and
/// `has_value` is `false`.
fn evaluate(&mut self) -> Result<ScalarValue> {
    // Read the flag and clear it in a single step.
    let had_value = std::mem::replace(&mut self.has_value, false);

    // NOTE(review): this drains the buffer, so a second `evaluate` call with
    // no new input in between yields NULL — confirm no caller depends on
    // calling `evaluate` repeatedly on the same accumulator.
    let joined = had_value.then(|| std::mem::take(&mut self.accumulated_string));

    Ok(ScalarValue::LargeUtf8(joined))
}

/// Reports the accumulator's memory footprint in bytes: the struct itself
/// plus the capacity of the two `String` buffers it owns.
fn size(&self) -> usize {
    // Capacity (not length) is used, so reserved-but-unused bytes are counted.
    let string_bytes = self.delimiter.capacity() + self.accumulated_string.capacity();
    size_of_val(self) + string_bytes
}

fn state(&mut self) -> Result<Vec<ScalarValue>> {
Copy link
Contributor

@vegarsti vegarsti Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just asking to understand the Accumulator trait: I see that this and evaluate are the same except for what they return - what is the difference between the two and when they are used, do you know?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

state() produces the per-partition intermediate result, and evaluate() produces the final result.
For example, suppose group key1 is processed in 2 partitions:
partition 1:
-- [INPUT] (foo, bar) --state()--> "foo, bar"
partition 2:
-- [input] (baz) --state--> "baz"

and evaluate() is called after merge_batch to combine the above intermediates from all partitions, and get the final result "foo, bar, baz"

I think there is a detailed doc in the Accumulator interface

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation! And oh yeah, I should read the doc comments on the trait!

let result = if self.has_value {
ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
} else {
ScalarValue::LargeUtf8(None)
};
self.has_value = false;

Ok(vec![result])
}

fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
self.update_batch(values)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading