-
Notifications
You must be signed in to change notification settings - Fork 1.7k
perf: Faster string_agg()
aggregate function (1000x speed for no DISTINCT and ORDER case)
#17837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d7c7c84
2ee945b
574ef18
a041219
2bc95c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -25,9 +25,14 @@ use crate::array_agg::ArrayAgg; | |||||||||||||||||||||||
|
||||||||||||||||||||||||
use arrow::array::ArrayRef; | ||||||||||||||||||||||||
use arrow::datatypes::{DataType, Field, FieldRef}; | ||||||||||||||||||||||||
use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; | ||||||||||||||||||||||||
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; | ||||||||||||||||||||||||
use datafusion_common::cast::{ | ||||||||||||||||||||||||
as_generic_string_array, as_string_array, as_string_view_array, | ||||||||||||||||||||||||
}; | ||||||||||||||||||||||||
use datafusion_common::{ | ||||||||||||||||||||||||
internal_datafusion_err, internal_err, not_impl_err, Result, ScalarValue, | ||||||||||||||||||||||||
}; | ||||||||||||||||||||||||
use datafusion_expr::function::AccumulatorArgs; | ||||||||||||||||||||||||
use datafusion_expr::utils::format_state_name; | ||||||||||||||||||||||||
use datafusion_expr::{ | ||||||||||||||||||||||||
Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, | ||||||||||||||||||||||||
}; | ||||||||||||||||||||||||
|
@@ -120,6 +125,8 @@ impl Default for StringAgg { | |||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
/// If there is no `distinct` and `order by` required by the `string_agg` call, a | ||||||||||||||||||||||||
/// more efficient accumulator `SimpleStringAggAccumulator` will be used. | ||||||||||||||||||||||||
impl AggregateUDFImpl for StringAgg { | ||||||||||||||||||||||||
fn as_any(&self) -> &dyn Any { | ||||||||||||||||||||||||
self | ||||||||||||||||||||||||
|
@@ -138,7 +145,21 @@ impl AggregateUDFImpl for StringAgg { | |||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> { | ||||||||||||||||||||||||
self.array_agg.state_fields(args) | ||||||||||||||||||||||||
// See comments in `impl AggregateUDFImpl ...` for more detail | ||||||||||||||||||||||||
let no_order_no_distinct = | ||||||||||||||||||||||||
(args.ordering_fields.is_empty()) && (!args.is_distinct); | ||||||||||||||||||||||||
if no_order_no_distinct { | ||||||||||||||||||||||||
// Case `SimpleStringAggAccumulator` | ||||||||||||||||||||||||
Ok(vec![Field::new( | ||||||||||||||||||||||||
format_state_name(args.name, "string_agg"), | ||||||||||||||||||||||||
DataType::LargeUtf8, | ||||||||||||||||||||||||
true, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
.into()]) | ||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
// Case `StringAggAccumulator` | ||||||||||||||||||||||||
self.array_agg.state_fields(args) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { | ||||||||||||||||||||||||
|
@@ -161,21 +182,31 @@ impl AggregateUDFImpl for StringAgg { | |||||||||||||||||||||||
); | ||||||||||||||||||||||||
}; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { | ||||||||||||||||||||||||
return_field: Field::new( | ||||||||||||||||||||||||
"f", | ||||||||||||||||||||||||
DataType::new_list(acc_args.return_field.data_type().clone(), true), | ||||||||||||||||||||||||
true, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
.into(), | ||||||||||||||||||||||||
exprs: &filter_index(acc_args.exprs, 1), | ||||||||||||||||||||||||
..acc_args | ||||||||||||||||||||||||
})?; | ||||||||||||||||||||||||
// See comments in `impl AggregateUDFImpl ...` for more detail | ||||||||||||||||||||||||
let no_order_no_distinct = | ||||||||||||||||||||||||
acc_args.order_bys.is_empty() && (!acc_args.is_distinct); | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Ok(Box::new(StringAggAccumulator::new( | ||||||||||||||||||||||||
array_agg_acc, | ||||||||||||||||||||||||
delimiter, | ||||||||||||||||||||||||
))) | ||||||||||||||||||||||||
if no_order_no_distinct { | ||||||||||||||||||||||||
// simple case (more efficient) | ||||||||||||||||||||||||
Ok(Box::new(SimpleStringAggAccumulator::new(delimiter))) | ||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
// general case | ||||||||||||||||||||||||
let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs { | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto here for encapsulating this |
||||||||||||||||||||||||
return_field: Field::new( | ||||||||||||||||||||||||
"f", | ||||||||||||||||||||||||
DataType::new_list(acc_args.return_field.data_type().clone(), true), | ||||||||||||||||||||||||
true, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
.into(), | ||||||||||||||||||||||||
exprs: &filter_index(acc_args.exprs, 1), | ||||||||||||||||||||||||
..acc_args | ||||||||||||||||||||||||
})?; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Ok(Box::new(StringAggAccumulator::new( | ||||||||||||||||||||||||
array_agg_acc, | ||||||||||||||||||||||||
delimiter, | ||||||||||||||||||||||||
))) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { | ||||||||||||||||||||||||
|
@@ -187,6 +218,7 @@ impl AggregateUDFImpl for StringAgg { | |||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
/// StringAgg accumulator for the general case (with order or distinct specified) | ||||||||||||||||||||||||
#[derive(Debug)] | ||||||||||||||||||||||||
pub(crate) struct StringAggAccumulator { | ||||||||||||||||||||||||
array_agg_acc: Box<dyn Accumulator>, | ||||||||||||||||||||||||
|
@@ -269,6 +301,105 @@ fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> { | |||||||||||||||||||||||
.collect::<Vec<_>>() | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
/// StringAgg accumulator for the simple case (no order or distinct specified) | ||||||||||||||||||||||||
/// This accumulator is more efficient than `StringAggAccumulator` | ||||||||||||||||||||||||
2010YOUY01 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||
/// because it accumulates the string directly, | ||||||||||||||||||||||||
/// whereas `StringAggAccumulator` uses `ArrayAggAccumulator`. | ||||||||||||||||||||||||
#[derive(Debug)] | ||||||||||||||||||||||||
pub(crate) struct SimpleStringAggAccumulator { | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this is likely much better than what we have. We can probably do better still with a GroupsAccumulator as well |
||||||||||||||||||||||||
delimiter: String, | ||||||||||||||||||||||||
/// Updated during `update_batch()`. e.g. "foo,bar" | ||||||||||||||||||||||||
accumulated_string: String, | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rater than accumulated_string: Option<String>, |
||||||||||||||||||||||||
has_value: bool, | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
impl SimpleStringAggAccumulator { | ||||||||||||||||||||||||
pub fn new(delimiter: &str) -> Self { | ||||||||||||||||||||||||
Self { | ||||||||||||||||||||||||
delimiter: delimiter.to_string(), | ||||||||||||||||||||||||
accumulated_string: "".to_string(), | ||||||||||||||||||||||||
has_value: false, | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
#[inline] | ||||||||||||||||||||||||
fn append_strings<'a, I>(&mut self, iter: I) | ||||||||||||||||||||||||
where | ||||||||||||||||||||||||
I: Iterator<Item = Option<&'a str>>, | ||||||||||||||||||||||||
{ | ||||||||||||||||||||||||
for value in iter.flatten() { | ||||||||||||||||||||||||
if self.has_value { | ||||||||||||||||||||||||
self.accumulated_string.push_str(&self.delimiter); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self.accumulated_string.push_str(value); | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||||||||||||||||||||
self.has_value = true; | ||||||||||||||||||||||||
Comment on lines
+331
to
+336
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you used an option, this could be like
Suggested change
|
||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
impl Accumulator for SimpleStringAggAccumulator { | ||||||||||||||||||||||||
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { | ||||||||||||||||||||||||
let string_arr = values.first().ok_or_else(|| { | ||||||||||||||||||||||||
Comment on lines
+342
to
+343
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there only one element in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it's an array of arg1. The arg is validated during the planning time, and we can also assume it's the right type here ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! |
||||||||||||||||||||||||
internal_datafusion_err!( | ||||||||||||||||||||||||
"Planner should ensure its first arg is Utf8/Utf8View" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
})?; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
match string_arr.data_type() { | ||||||||||||||||||||||||
DataType::Utf8 => { | ||||||||||||||||||||||||
let array = as_string_array(string_arr)?; | ||||||||||||||||||||||||
self.append_strings(array.iter()); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
DataType::LargeUtf8 => { | ||||||||||||||||||||||||
let array = as_generic_string_array::<i64>(string_arr)?; | ||||||||||||||||||||||||
self.append_strings(array.iter()); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
DataType::Utf8View => { | ||||||||||||||||||||||||
let array = as_string_view_array(string_arr)?; | ||||||||||||||||||||||||
self.append_strings(array.iter()); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
other => { | ||||||||||||||||||||||||
return internal_err!( | ||||||||||||||||||||||||
"Planner should ensure string_agg first argument is Utf8-like, found {other}" | ||||||||||||||||||||||||
); | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Ok(()) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
fn evaluate(&mut self) -> Result<ScalarValue> { | ||||||||||||||||||||||||
let result = if self.has_value { | ||||||||||||||||||||||||
ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string))) | ||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
ScalarValue::LargeUtf8(None) | ||||||||||||||||||||||||
}; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self.has_value = false; | ||||||||||||||||||||||||
Ok(result) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
fn size(&self) -> usize { | ||||||||||||||||||||||||
size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity() | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
fn state(&mut self) -> Result<Vec<ScalarValue>> { | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just asking to understand the Accumulator trait: I see that this and evaluate are the same except for what they return - what is the difference between the two and when they are used, do you know? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. state is for per-partition intermediate result, and and I think there is a detailed doc in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the explanation! And oh yeah, I should read the doc comments on the trait! |
||||||||||||||||||||||||
let result = if self.has_value { | ||||||||||||||||||||||||
ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string))) | ||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
ScalarValue::LargeUtf8(None) | ||||||||||||||||||||||||
}; | ||||||||||||||||||||||||
self.has_value = false; | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
Ok(vec![result]) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { | ||||||||||||||||||||||||
self.update_batch(values) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
#[cfg(test)] | ||||||||||||||||||||||||
mod tests { | ||||||||||||||||||||||||
use super::*; | ||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to put this as part of SimpleStringAggAccumulator, something like