Skip to content

Commit c5a885b

Browse files
UBarney authored and cipherstakes committed
Improve performance of last_value by implementing special GroupsAccumulator (apache#15542)
* Improve performance of `last_value` by implementing special `GroupsAccumulator`
* less diff
1 parent abce1b9 commit c5a885b

File tree

4 files changed

+270
-18
lines changed

4 files changed

+270
-18
lines changed

datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,32 @@ async fn test_first_val() {
120120
.await;
121121
}
122122

123+
#[tokio::test(flavor = "multi_thread")]
124+
async fn test_last_val() {
125+
let mut data_gen_config = baseline_config();
126+
127+
for i in 0..data_gen_config.columns.len() {
128+
if data_gen_config.columns[i].get_max_num_distinct().is_none() {
129+
data_gen_config.columns[i] = data_gen_config.columns[i]
130+
.clone()
131+
// Minimize the chance of identical values in the order by columns to make the test more stable
132+
.with_max_num_distinct(usize::MAX);
133+
}
134+
}
135+
136+
let query_builder = QueryBuilder::new()
137+
.with_table_name("fuzz_table")
138+
.with_aggregate_function("last_value")
139+
.with_aggregate_arguments(data_gen_config.all_columns())
140+
.set_group_by_columns(data_gen_config.all_columns());
141+
142+
AggregationFuzzerBuilder::from(data_gen_config)
143+
.add_query_builder(query_builder)
144+
.build()
145+
.run()
146+
.await;
147+
}
148+
123149
#[tokio::test(flavor = "multi_thread")]
124150
async fn test_max() {
125151
let data_gen_config = baseline_config();

datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,9 @@ impl QueryBuilder {
503503
let distinct = if *is_distinct { "DISTINCT " } else { "" };
504504
alias_gen += 1;
505505

506-
let (order_by, null_opt) = if function_name.eq("first_value") {
506+
let (order_by, null_opt) = if function_name.eq("first_value")
507+
|| function_name.eq("last_value")
508+
{
507509
(
508510
self.order_by(&order_by_black_list), /* Among the order by columns, at most one group by column can be included to avoid all order by column values being identical */
509511
self.null_opt(),

datafusion/functions-aggregate/src/first_last.rs

Lines changed: 216 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ impl AggregateUDFImpl for FirstValue {
166166
}
167167

168168
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
169+
// TODO: extract to function
169170
use DataType::*;
170171
matches!(
171172
args.return_type,
@@ -193,6 +194,7 @@ impl AggregateUDFImpl for FirstValue {
193194
&self,
194195
args: AccumulatorArgs,
195196
) -> Result<Box<dyn GroupsAccumulator>> {
197+
// TODO: extract to function
196198
fn create_accumulator<T>(
197199
args: AccumulatorArgs,
198200
) -> Result<Box<dyn GroupsAccumulator>>
@@ -210,6 +212,7 @@ impl AggregateUDFImpl for FirstValue {
210212
args.ignore_nulls,
211213
args.return_type,
212214
&ordering_dtypes,
215+
true,
213216
)?))
214217
}
215218

@@ -258,10 +261,12 @@ impl AggregateUDFImpl for FirstValue {
258261
create_accumulator::<Time64NanosecondType>(args)
259262
}
260263

261-
_ => internal_err!(
262-
"GroupsAccumulator not supported for first({})",
263-
args.return_type
264-
),
264+
_ => {
265+
internal_err!(
266+
"GroupsAccumulator not supported for first_value({})",
267+
args.return_type
268+
)
269+
}
265270
}
266271
}
267272

@@ -291,6 +296,7 @@ impl AggregateUDFImpl for FirstValue {
291296
}
292297
}
293298

299+
// TODO: rename to PrimitiveGroupsAccumulator
294300
struct FirstPrimitiveGroupsAccumulator<T>
295301
where
296302
T: ArrowPrimitiveType + Send,
@@ -316,12 +322,16 @@ where
316322
// buffer for `get_filtered_min_of_each_group`
317323
// min_of_each_group_buf.0[group_idx] -> idx_in_val
318324
// only valid if min_of_each_group_buf.1[group_idx] == true
325+
// TODO: rename to extreme_of_each_group_buf
319326
min_of_each_group_buf: (Vec<usize>, BooleanBufferBuilder),
320327

321328
// =========== option ============
322329

323330
// Stores the applicable ordering requirement.
324331
ordering_req: LexOrdering,
332+
// true: take first element in an aggregation group according to the requested ordering.
333+
// false: take last element in an aggregation group according to the requested ordering.
334+
pick_first_in_group: bool,
325335
// derived from `ordering_req`.
326336
sort_options: Vec<SortOptions>,
327337
// Stores whether incoming data already satisfies the ordering requirement.
@@ -342,6 +352,7 @@ where
342352
ignore_nulls: bool,
343353
data_type: &DataType,
344354
ordering_dtypes: &[DataType],
355+
pick_first_in_group: bool,
345356
) -> Result<Self> {
346357
let requirement_satisfied = ordering_req.is_empty();
347358

@@ -365,6 +376,7 @@ where
365376
is_sets: BooleanBufferBuilder::new(0),
366377
size_of_orderings: 0,
367378
min_of_each_group_buf: (Vec::new(), BooleanBufferBuilder::new(0)),
379+
pick_first_in_group,
368380
})
369381
}
370382

@@ -391,8 +403,13 @@ where
391403

392404
assert!(new_ordering_values.len() == self.ordering_req.len());
393405
let current_ordering = &self.orderings[group_idx];
394-
compare_rows(current_ordering, new_ordering_values, &self.sort_options)
395-
.map(|x| x.is_gt())
406+
compare_rows(current_ordering, new_ordering_values, &self.sort_options).map(|x| {
407+
if self.pick_first_in_group {
408+
x.is_gt()
409+
} else {
410+
x.is_lt()
411+
}
412+
})
396413
}
397414

398415
fn take_orderings(&mut self, emit_to: EmitTo) -> Vec<Vec<ScalarValue>> {
@@ -501,10 +518,10 @@ where
501518
.map(ScalarValue::size_of_vec)
502519
.sum::<usize>()
503520
}
504-
505521
/// Returns a vector of tuples `(group_idx, idx_in_val)` representing the index of the
506522
/// minimum value in `orderings` for each group (the maximum when `pick_first_in_group` is false), using lexicographical comparison.
507523
/// Values are filtered using `opt_filter` and `is_set_arr` if provided.
524+
/// TODO: rename to get_filtered_extreme_of_each_group
508525
fn get_filtered_min_of_each_group(
509526
&mut self,
510527
orderings: &[ArrayRef],
@@ -556,15 +573,19 @@ where
556573
}
557574

558575
let is_valid = self.min_of_each_group_buf.1.get_bit(group_idx);
559-
if is_valid
560-
&& comparator
561-
.compare(self.min_of_each_group_buf.0[group_idx], idx_in_val)
562-
.is_gt()
563-
{
564-
self.min_of_each_group_buf.0[group_idx] = idx_in_val;
565-
} else if !is_valid {
576+
577+
if !is_valid {
566578
self.min_of_each_group_buf.1.set_bit(group_idx, true);
567579
self.min_of_each_group_buf.0[group_idx] = idx_in_val;
580+
} else {
581+
let ordering = comparator
582+
.compare(self.min_of_each_group_buf.0[group_idx], idx_in_val);
583+
584+
if (ordering.is_gt() && self.pick_first_in_group)
585+
|| (ordering.is_lt() && !self.pick_first_in_group)
586+
{
587+
self.min_of_each_group_buf.0[group_idx] = idx_in_val;
588+
}
568589
}
569590
}
570591

@@ -1052,6 +1073,109 @@ impl AggregateUDFImpl for LastValue {
10521073
fn documentation(&self) -> Option<&Documentation> {
10531074
self.doc()
10541075
}
1076+
1077+
fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
1078+
use DataType::*;
1079+
matches!(
1080+
args.return_type,
1081+
Int8 | Int16
1082+
| Int32
1083+
| Int64
1084+
| UInt8
1085+
| UInt16
1086+
| UInt32
1087+
| UInt64
1088+
| Float16
1089+
| Float32
1090+
| Float64
1091+
| Decimal128(_, _)
1092+
| Decimal256(_, _)
1093+
| Date32
1094+
| Date64
1095+
| Time32(_)
1096+
| Time64(_)
1097+
| Timestamp(_, _)
1098+
)
1099+
}
1100+
1101+
fn create_groups_accumulator(
1102+
&self,
1103+
args: AccumulatorArgs,
1104+
) -> Result<Box<dyn GroupsAccumulator>> {
1105+
fn create_accumulator<T>(
1106+
args: AccumulatorArgs,
1107+
) -> Result<Box<dyn GroupsAccumulator>>
1108+
where
1109+
T: ArrowPrimitiveType + Send,
1110+
{
1111+
let ordering_dtypes = args
1112+
.ordering_req
1113+
.iter()
1114+
.map(|e| e.expr.data_type(args.schema))
1115+
.collect::<Result<Vec<_>>>()?;
1116+
1117+
Ok(Box::new(FirstPrimitiveGroupsAccumulator::<T>::try_new(
1118+
args.ordering_req.clone(),
1119+
args.ignore_nulls,
1120+
args.return_type,
1121+
&ordering_dtypes,
1122+
false,
1123+
)?))
1124+
}
1125+
1126+
match args.return_type {
1127+
DataType::Int8 => create_accumulator::<Int8Type>(args),
1128+
DataType::Int16 => create_accumulator::<Int16Type>(args),
1129+
DataType::Int32 => create_accumulator::<Int32Type>(args),
1130+
DataType::Int64 => create_accumulator::<Int64Type>(args),
1131+
DataType::UInt8 => create_accumulator::<UInt8Type>(args),
1132+
DataType::UInt16 => create_accumulator::<UInt16Type>(args),
1133+
DataType::UInt32 => create_accumulator::<UInt32Type>(args),
1134+
DataType::UInt64 => create_accumulator::<UInt64Type>(args),
1135+
DataType::Float16 => create_accumulator::<Float16Type>(args),
1136+
DataType::Float32 => create_accumulator::<Float32Type>(args),
1137+
DataType::Float64 => create_accumulator::<Float64Type>(args),
1138+
1139+
DataType::Decimal128(_, _) => create_accumulator::<Decimal128Type>(args),
1140+
DataType::Decimal256(_, _) => create_accumulator::<Decimal256Type>(args),
1141+
1142+
DataType::Timestamp(TimeUnit::Second, _) => {
1143+
create_accumulator::<TimestampSecondType>(args)
1144+
}
1145+
DataType::Timestamp(TimeUnit::Millisecond, _) => {
1146+
create_accumulator::<TimestampMillisecondType>(args)
1147+
}
1148+
DataType::Timestamp(TimeUnit::Microsecond, _) => {
1149+
create_accumulator::<TimestampMicrosecondType>(args)
1150+
}
1151+
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
1152+
create_accumulator::<TimestampNanosecondType>(args)
1153+
}
1154+
1155+
DataType::Date32 => create_accumulator::<Date32Type>(args),
1156+
DataType::Date64 => create_accumulator::<Date64Type>(args),
1157+
DataType::Time32(TimeUnit::Second) => {
1158+
create_accumulator::<Time32SecondType>(args)
1159+
}
1160+
DataType::Time32(TimeUnit::Millisecond) => {
1161+
create_accumulator::<Time32MillisecondType>(args)
1162+
}
1163+
1164+
DataType::Time64(TimeUnit::Microsecond) => {
1165+
create_accumulator::<Time64MicrosecondType>(args)
1166+
}
1167+
DataType::Time64(TimeUnit::Nanosecond) => {
1168+
create_accumulator::<Time64NanosecondType>(args)
1169+
}
1170+
1171+
_ => {
1172+
internal_err!(
1173+
"GroupsAccumulator not supported for last_value({})",
1174+
args.return_type
1175+
)
1176+
}
1177+
}
1178+
}
10551179
}
10561180

10571181
#[derive(Debug)]
@@ -1411,6 +1535,7 @@ mod tests {
14111535
true,
14121536
&DataType::Int64,
14131537
&[DataType::Int64],
1538+
true,
14141539
)?;
14151540

14161541
let mut val_with_orderings = {
@@ -1485,7 +1610,7 @@ mod tests {
14851610
}
14861611

14871612
#[test]
1488-
fn test_frist_group_acc_size_of_ordering() -> Result<()> {
1613+
fn test_group_acc_size_of_ordering() -> Result<()> {
14891614
let schema = Arc::new(Schema::new(vec![
14901615
Field::new("a", DataType::Int64, true),
14911616
Field::new("b", DataType::Int64, true),
@@ -1504,6 +1629,7 @@ mod tests {
15041629
true,
15051630
&DataType::Int64,
15061631
&[DataType::Int64],
1632+
true,
15071633
)?;
15081634

15091635
let val_with_orderings = {
@@ -1563,4 +1689,79 @@ mod tests {
15631689

15641690
Ok(())
15651691
}
1692+
1693+
#[test]
1694+
fn test_last_group_acc() -> Result<()> {
1695+
let schema = Arc::new(Schema::new(vec![
1696+
Field::new("a", DataType::Int64, true),
1697+
Field::new("b", DataType::Int64, true),
1698+
Field::new("c", DataType::Int64, true),
1699+
Field::new("d", DataType::Int32, true),
1700+
Field::new("e", DataType::Boolean, true),
1701+
]));
1702+
1703+
let sort_key = LexOrdering::new(vec![PhysicalSortExpr {
1704+
expr: col("c", &schema).unwrap(),
1705+
options: SortOptions::default(),
1706+
}]);
1707+
1708+
let mut group_acc = FirstPrimitiveGroupsAccumulator::<Int64Type>::try_new(
1709+
sort_key,
1710+
true,
1711+
&DataType::Int64,
1712+
&[DataType::Int64],
1713+
false,
1714+
)?;
1715+
1716+
let mut val_with_orderings = {
1717+
let mut val_with_orderings = Vec::<ArrayRef>::new();
1718+
1719+
let vals = Arc::new(Int64Array::from(vec![Some(1), None, Some(3), Some(-6)]));
1720+
let orderings = Arc::new(Int64Array::from(vec![1, -9, 3, -6]));
1721+
1722+
val_with_orderings.push(vals);
1723+
val_with_orderings.push(orderings);
1724+
1725+
val_with_orderings
1726+
};
1727+
1728+
group_acc.update_batch(
1729+
&val_with_orderings,
1730+
&[0, 1, 2, 1],
1731+
Some(&BooleanArray::from(vec![true, true, false, true])),
1732+
3,
1733+
)?;
1734+
1735+
let state = group_acc.state(EmitTo::All)?;
1736+
1737+
let expected_state: Vec<Arc<dyn Array>> = vec![
1738+
Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1739+
Arc::new(Int64Array::from(vec![Some(1), Some(-6), None])),
1740+
Arc::new(BooleanArray::from(vec![true, true, false])),
1741+
];
1742+
assert_eq!(state, expected_state);
1743+
1744+
group_acc.merge_batch(
1745+
&state,
1746+
&[0, 1, 2],
1747+
Some(&BooleanArray::from(vec![true, false, false])),
1748+
3,
1749+
)?;
1750+
1751+
val_with_orderings.clear();
1752+
val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1753+
val_with_orderings.push(Arc::new(Int64Array::from(vec![66, 6])));
1754+
1755+
group_acc.update_batch(&val_with_orderings, &[1, 2], None, 4)?;
1756+
1757+
let binding = group_acc.evaluate(EmitTo::All)?;
1758+
let eval_result = binding.as_any().downcast_ref::<Int64Array>().unwrap();
1759+
1760+
let expect: PrimitiveArray<Int64Type> =
1761+
Int64Array::from(vec![Some(1), Some(66), Some(6), None]);
1762+
1763+
assert_eq!(eval_result, &expect);
1764+
1765+
Ok(())
1766+
}
15661767
}

0 commit comments

Comments
 (0)