Skip to content

Commit 471a2a5

Browse files
nuno-fariaLiaCastaneda
authored and committed
fix: string_agg not respecting ORDER BY
1 parent 025ddde commit 471a2a5

File tree

2 files changed

+130
-2
lines changed

2 files changed

+130
-2
lines changed

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use std::cmp::Ordering;
2121
use std::collections::{HashSet, VecDeque};
22-
use std::mem::{size_of, size_of_val};
22+
use std::mem::{size_of, size_of_val, take};
2323
use std::sync::Arc;
2424

2525
use arrow::array::{
@@ -31,7 +31,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Fields};
3131

3232
use datafusion_common::cast::as_list_array;
3333
use datafusion_common::scalar::copy_array_data;
34-
use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
34+
use datafusion_common::utils::{compare_rows, get_row_at_idx, SingleRowListArrayBuilder};
3535
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
3636
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
3737
use datafusion_expr::utils::format_state_name;
@@ -78,12 +78,14 @@ This aggregation function can only mix DISTINCT and ORDER BY if the ordering exp
7878
/// ARRAY_AGG aggregate expression
7979
pub struct ArrayAgg {
8080
/// Accepted argument signature of the function.
signature: Signature,
81+
/// Flag forwarded to the order-sensitive accumulator: when `true` the
/// accumulator treats its input as already ordered and skips its own sort.
is_input_pre_ordered: bool,
8182
}
8283

8384
impl Default for ArrayAgg {
8485
/// Builds an `ArrayAgg` that takes exactly one argument of any type and
/// does not assume its input is pre-ordered.
fn default() -> Self {
8586
Self {
8687
signature: Signature::any(1, Volatility::Immutable),
88+
// Conservative default: the flag is only set to another value through
// `with_beneficial_ordering`.
is_input_pre_ordered: false,
8789
}
8890
}
8991
}
@@ -144,6 +146,16 @@ impl AggregateUDFImpl for ArrayAgg {
144146
Ok(fields)
145147
}
146148

149+
/// Returns a new `ArrayAgg` with the same signature whose
/// `is_input_pre_ordered` flag is set to `beneficial_ordering`.
///
/// Always yields `Ok(Some(..))`; the receiver itself is left untouched.
fn with_beneficial_ordering(
150+
self: Arc<Self>,
151+
beneficial_ordering: bool,
152+
) -> Result<Option<Arc<dyn AggregateUDFImpl>>> {
153+
Ok(Some(Arc::new(Self {
154+
signature: self.signature.clone(),
155+
is_input_pre_ordered: beneficial_ordering,
156+
})))
157+
}
158+
147159
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
148160
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
149161
let ignore_nulls =
@@ -196,6 +208,7 @@ impl AggregateUDFImpl for ArrayAgg {
196208
&data_type,
197209
&ordering_dtypes,
198210
ordering,
211+
self.is_input_pre_ordered,
199212
acc_args.is_reversed,
200213
ignore_nulls,
201214
)
@@ -518,6 +531,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
518531
datatypes: Vec<DataType>,
519532
/// Stores the ordering requirement of the `Accumulator`.
520533
ordering_req: LexOrdering,
534+
/// Whether the input is known to be pre-ordered
535+
is_input_pre_ordered: bool,
521536
/// Whether the aggregation is running in reverse.
522537
reverse: bool,
523538
/// Whether the aggregation should ignore null values.
@@ -531,6 +546,7 @@ impl OrderSensitiveArrayAggAccumulator {
531546
datatype: &DataType,
532547
ordering_dtypes: &[DataType],
533548
ordering_req: LexOrdering,
549+
is_input_pre_ordered: bool,
534550
reverse: bool,
535551
ignore_nulls: bool,
536552
) -> Result<Self> {
@@ -541,11 +557,34 @@ impl OrderSensitiveArrayAggAccumulator {
541557
ordering_values: vec![],
542558
datatypes,
543559
ordering_req,
560+
is_input_pre_ordered,
544561
reverse,
545562
ignore_nulls,
546563
})
547564
}
548565

566+
/// Reorders `self.values` and `self.ordering_values` in lockstep so that the
/// rows follow the sort options of `self.ordering_req`.
///
/// Both vectors are moved out of `self`, zipped into pairs, sorted by
/// comparing the ordering rows with `compare_rows`, and then unzipped back
/// into their fields.
///
/// NOTE(review): a comparison error is recorded in `delayed_cmp_err` but that
/// variable is never read afterwards, so the error is silently dropped and the
/// offending pair is treated as equal. Consider returning `Result<()>` and
/// propagating the stored error at the call sites.
fn sort(&mut self) {
567+
// One `SortOptions` per ordering expression, in requirement order.
let sort_options = self
568+
.ordering_req
569+
.iter()
570+
.map(|sort_expr| sort_expr.options)
571+
.collect::<Vec<_>>();
572+
// `take` leaves empty vectors behind, so no element is cloned while pairing
// each value with its ordering row.
let mut values = take(&mut self.values)
573+
.into_iter()
574+
.zip(take(&mut self.ordering_values))
575+
.collect::<Vec<_>>();
576+
// The comparator must return an `Ordering`, so a failure is stashed here
// instead of being returned directly.
let mut delayed_cmp_err = Ok(());
577+
values.sort_by(|(_, left_ordering), (_, right_ordering)| {
578+
compare_rows(left_ordering, right_ordering, &sort_options).unwrap_or_else(
579+
|err| {
580+
delayed_cmp_err = Err(err);
581+
// Fall back to "equal" so the sort can finish.
Ordering::Equal
582+
},
583+
)
584+
});
585+
// Split the sorted pairs back into the two parallel vectors.
(self.values, self.ordering_values) = values.into_iter().unzip();
586+
}
587+
549588
fn evaluate_orderings(&self) -> Result<ScalarValue> {
550589
let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]);
551590

@@ -629,6 +668,9 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
629668
let mut partition_ordering_values = vec![];
630669

631670
// Existing values should be merged also.
671+
if !self.is_input_pre_ordered {
672+
self.sort();
673+
}
632674
partition_values.push(self.values.clone().into());
633675
partition_ordering_values.push(self.ordering_values.clone().into());
634676

@@ -679,13 +721,21 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
679721
}
680722

681723
fn state(&mut self) -> Result<Vec<ScalarValue>> {
724+
if !self.is_input_pre_ordered {
725+
self.sort();
726+
}
727+
682728
let mut result = vec![self.evaluate()?];
683729
result.push(self.evaluate_orderings()?);
684730

685731
Ok(result)
686732
}
687733

688734
fn evaluate(&mut self) -> Result<ScalarValue> {
735+
if !self.is_input_pre_ordered {
736+
self.sort();
737+
}
738+
689739
if self.values.is_empty() {
690740
return Ok(ScalarValue::new_null_list(
691741
self.datatypes[0].clone(),

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6028,6 +6028,84 @@ GROUP BY dummy
60286028
----
60296029
text1
60306030

6031+
6032+
# Test string_agg with ORDER BY clauses (issue #17011)
6033+
statement ok
6034+
create table t (k varchar, v int);
6035+
6036+
statement ok
6037+
insert into t values ('a', 2), ('b', 3), ('c', 1), ('d', null);
6038+
6039+
query T
6040+
select string_agg(k, ',' order by k) from t;
6041+
----
6042+
a,b,c,d
6043+
6044+
query T
6045+
select string_agg(k, ',' order by k desc) from t;
6046+
----
6047+
d,c,b,a
6048+
6049+
query T
6050+
select string_agg(k, ',' order by v) from t;
6051+
----
6052+
c,a,b,d
6053+
6054+
query T
6055+
select string_agg(k, ',' order by v nulls first) from t;
6056+
----
6057+
d,c,a,b
6058+
6059+
query T
6060+
select string_agg(k, ',' order by v desc) from t;
6061+
----
6062+
d,b,a,c
6063+
6064+
query T
6065+
select string_agg(k, ',' order by v desc nulls last) from t;
6066+
----
6067+
b,a,c,d
6068+
6069+
query T
6070+
-- rows with odd v should appear first, ties broken by v
6071+
select string_agg(k, ',' order by v % 2 == 0, v) from t;
6072+
----
6073+
c,b,a,d
6074+
6075+
query T
6076+
-- rows with odd v should appear first, ties broken by v desc
6077+
select string_agg(k, ',' order by v % 2 == 0, v desc) from t;
6078+
----
6079+
b,c,a,d
6080+
6081+
query T
6082+
select string_agg(k, ',' order by
6083+
case
6084+
when k = 'a' then 3
6085+
when k = 'b' then 0
6086+
when k = 'c' then 2
6087+
when k = 'd' then 1
6088+
end)
6089+
from t;
6090+
----
6091+
b,d,c,a
6092+
6093+
query T
6094+
select string_agg(k, ',' order by
6095+
case
6096+
when k = 'a' then 3
6097+
when k = 'b' then 0
6098+
when k = 'c' then 2
6099+
when k = 'd' then 1
6100+
end desc)
6101+
from t;
6102+
----
6103+
a,c,d,b
6104+
6105+
statement ok
6106+
drop table t;
6107+
6108+
60316109
# Tests for aggregating with NaN values
60326110
statement ok
60336111
CREATE TABLE float_table (

0 commit comments

Comments
 (0)