|
18 | 18 | //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] |
19 | 19 |
|
20 | 20 | use arrow::array::{ |
21 | | - new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray, |
| 21 | + make_array, new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, |
| 22 | + StructArray, |
22 | 23 | }; |
23 | 24 | use arrow::compute::{filter, SortOptions}; |
24 | 25 | use arrow::datatypes::{DataType, Field, FieldRef, Fields}; |
25 | 26 |
|
26 | 27 | use datafusion_common::cast::as_list_array; |
| 28 | +use datafusion_common::scalar::copy_array_data; |
27 | 29 | use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; |
28 | 30 | use datafusion_common::{exec_err, ScalarValue}; |
29 | 31 | use datafusion_common::{internal_err, Result}; |
@@ -319,7 +321,11 @@ impl Accumulator for ArrayAggAccumulator { |
319 | 321 | }; |
320 | 322 |
|
321 | 323 | if !val.is_empty() { |
322 | | - self.values.push(val); |
| 324 | + // The ArrayRef might be holding a reference to its original input buffer, so |
| 325 | + // storing it here directly copied/compacted avoids over accounting memory |
| 326 | + // not used here. |
| 327 | + self.values |
| 328 | + .push(make_array(copy_array_data(&val.to_data()))); |
323 | 329 | } |
324 | 330 |
|
325 | 331 | Ok(()) |
@@ -429,7 +435,8 @@ impl Accumulator for DistinctArrayAggAccumulator { |
429 | 435 | if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { |
430 | 436 | for i in 0..val.len() { |
431 | 437 | if nulls.is_none_or(|nulls| nulls.is_valid(i)) { |
432 | | - self.values.insert(ScalarValue::try_from_array(val, i)?); |
| 438 | + self.values |
| 439 | + .insert(ScalarValue::try_from_array(val, i)?.compacted()); |
433 | 440 | } |
434 | 441 | } |
435 | 442 | } |
@@ -558,8 +565,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { |
558 | 565 | if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { |
559 | 566 | for i in 0..val.len() { |
560 | 567 | if nulls.is_none_or(|nulls| nulls.is_valid(i)) { |
561 | | - self.values.push(ScalarValue::try_from_array(val, i)?); |
562 | | - self.ordering_values.push(get_row_at_idx(ord, i)?) |
| 568 | + self.values |
| 569 | + .push(ScalarValue::try_from_array(val, i)?.compacted()); |
| 570 | + self.ordering_values.push( |
| 571 | + get_row_at_idx(ord, i)? |
| 572 | + .into_iter() |
| 573 | + .map(|v| v.compacted()) |
| 574 | + .collect(), |
| 575 | + ) |
563 | 576 | } |
564 | 577 | } |
565 | 578 | } |
@@ -722,6 +735,7 @@ impl OrderSensitiveArrayAggAccumulator { |
722 | 735 | #[cfg(test)] |
723 | 736 | mod tests { |
724 | 737 | use super::*; |
| 738 | + use arrow::array::{ListBuilder, StringBuilder}; |
725 | 739 | use arrow::datatypes::{FieldRef, Schema}; |
726 | 740 | use datafusion_common::cast::as_generic_string_array; |
727 | 741 | use datafusion_common::internal_err; |
@@ -988,6 +1002,56 @@ mod tests { |
988 | 1002 | Ok(()) |
989 | 1003 | } |
990 | 1004 |
|
| 1005 | + #[test] |
| 1006 | + fn does_not_over_account_memory() -> Result<()> { |
| 1007 | + let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?; |
| 1008 | + |
| 1009 | + acc1.update_batch(&[data(["a", "c", "b"])])?; |
| 1010 | + acc2.update_batch(&[data(["b", "c", "a"])])?; |
| 1011 | + acc1 = merge(acc1, acc2)?; |
| 1012 | + |
| 1013 | + // without compaction, the size is 2652. |
| 1014 | + assert_eq!(acc1.size(), 732); |
| 1015 | + |
| 1016 | + Ok(()) |
| 1017 | + } |
| 1018 | + #[test] |
| 1019 | + fn does_not_over_account_memory_distinct() -> Result<()> { |
| 1020 | + let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string() |
| 1021 | + .distinct() |
| 1022 | + .build_two()?; |
| 1023 | + |
| 1024 | + acc1.update_batch(&[string_list_data([ |
| 1025 | + vec!["a", "b", "c"], |
| 1026 | + vec!["d", "e", "f"], |
| 1027 | + ])])?; |
| 1028 | + acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?; |
| 1029 | + acc1 = merge(acc1, acc2)?; |
| 1030 | + |
| 1031 | + // without compaction, the size is 16660 |
| 1032 | + assert_eq!(acc1.size(), 1660); |
| 1033 | + |
| 1034 | + Ok(()) |
| 1035 | + } |
| 1036 | + |
| 1037 | + #[test] |
| 1038 | + fn does_not_over_account_memory_ordered() -> Result<()> { |
| 1039 | + let mut acc = ArrayAggAccumulatorBuilder::string() |
| 1040 | + .order_by_col("col", SortOptions::new(false, false)) |
| 1041 | + .build()?; |
| 1042 | + |
| 1043 | + acc.update_batch(&[string_list_data([ |
| 1044 | + vec!["a", "b", "c"], |
| 1045 | + vec!["c", "d", "e"], |
| 1046 | + vec!["b", "c", "d"], |
| 1047 | + ])])?; |
| 1048 | + |
| 1049 | + // without compaction, the size is 17112 |
| 1050 | + assert_eq!(acc.size(), 2080); |
| 1051 | + |
| 1052 | + Ok(()) |
| 1053 | + } |
| 1054 | + |
991 | 1055 | struct ArrayAggAccumulatorBuilder { |
992 | 1056 | return_field: FieldRef, |
993 | 1057 | distinct: bool, |
@@ -1066,6 +1130,15 @@ mod tests { |
1066 | 1130 | .collect() |
1067 | 1131 | } |
1068 | 1132 |
|
| 1133 | + fn string_list_data<'a>(data: impl IntoIterator<Item = Vec<&'a str>>) -> ArrayRef { |
| 1134 | + let mut builder = ListBuilder::new(StringBuilder::new()); |
| 1135 | + for string_list in data.into_iter() { |
| 1136 | + builder.append_value(string_list.iter().map(Some).collect::<Vec<_>>()); |
| 1137 | + } |
| 1138 | + |
| 1139 | + Arc::new(builder.finish()) |
| 1140 | + } |
| 1141 | + |
1069 | 1142 | fn data<T, const N: usize>(list: [T; N]) -> ArrayRef |
1070 | 1143 | where |
1071 | 1144 | ScalarValue: From<T>, |
|
0 commit comments