From 57e768d9c269305ddf9d3f0e232dfc4059d5c7be Mon Sep 17 00:00:00 2001 From: Peter Nguyen Date: Wed, 10 Dec 2025 23:12:18 -0800 Subject: [PATCH 1/7] Implement 'retract_batch' for MedianAccumulator --- datafusion/functions-aggregate/src/median.rs | 20 +++++++- .../sqllogictest/test_files/aggregate.slt | 48 +++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index f09c544628a6f..4734991659ba6 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -289,7 +289,7 @@ impl Accumulator for MedianAccumulator { } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.all_values); + let d = self.all_values.clone(); let median = calculate_median::(d); ScalarValue::new_primitive::(median, &self.data_type) } @@ -297,6 +297,24 @@ impl Accumulator for MedianAccumulator { fn size(&self) -> usize { size_of_val(self) + self.all_values.capacity() * size_of::() } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + for v in values.iter().flatten() { + if let Some(idx) = self.all_values.iter().position(|x| *x == v) { + self.all_values.swap_remove(idx); + } else { + return Err(internal_datafusion_err!( + "attempted to retract value {v:?} that was not present" + )); + } + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } } /// The median groups accumulator accumulates the raw input values diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 1de075088db72..c1d12c06135fc 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -991,6 +991,54 @@ SELECT approx_median(col_f64_nan) FROM median_table ---- NaN +# median_sliding_window +statement ok +CREATE TABLE median_window_test ( + timestamp INT, + tags VARCHAR, + value DOUBLE +); + +statement ok +INSERT INTO median_window_test (timestamp, tags, value) VALUES +(1, 'tag1', 10.0), +(2, 'tag1', 20.0), +(3, 'tag1', 30.0), +(4, 'tag1', 40.0), +(5, 'tag1', 50.0), +(1, 'tag2', 60.0), +(2, 'tag2', 70.0), +(3, 'tag2', 80.0), +(4, 'tag2', 90.0), +(5, 'tag2', 100.0); + +query ITRR +SELECT + timestamp, + tags, + value, + median(value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS value_median_3 +FROM median_window_test +ORDER BY tags, timestamp; +---- +1 tag1 10 15 +2 tag1 20 20 +3 tag1 30 30 +4 tag1 40 40 +5 tag1 50 45 +1 tag2 60 65 +2 tag2 70 70 +3 tag2 80 80 +4 tag2 90 90 +5 tag2 100 95 + +statement ok +DROP TABLE median_window_test; + query RT select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median(arrow_cast(col_f32, 'Float16'))) from median_table; ---- From 81ced74f3a858dfe93e66a6d800cfea36ad32daf Mon Sep 17 00:00:00 2001 From: Peter Nguyen Date: Sat, 13 Dec 2025 11:34:18 -0800 Subject: [PATCH 2/7] Pass &mut instead of cloning --- datafusion/functions-aggregate/src/median.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 4734991659ba6..9a4eda2e907b1 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -289,8 +289,7 @@ impl Accumulator for MedianAccumulator { } fn evaluate(&mut self) -> Result { - let d = self.all_values.clone(); - let median = calculate_median::(d); + let median = calculate_median::(&mut self.all_values); ScalarValue::new_primitive::(median, &self.data_type) } @@ -461,8 +460,8 @@ impl GroupsAccumulator for MedianGroupsAccumulator::new().with_data_type(self.data_type.clone()); - for values in emit_group_values { - let median = calculate_median::(values); + for mut values in emit_group_values { + let median = calculate_median::(&mut values); evaluate_result_builder.append_option(median); } @@ -546,11 +545,11 @@ impl Accumulator for DistinctMedianAccumulator { } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.distinct_values.values) + let mut d = std::mem::take(&mut self.distinct_values.values) .into_iter() .map(|v| v.0) .collect::>(); - let median = calculate_median::(d); + let median = calculate_median::(&mut d); ScalarValue::new_primitive::(median, &self.data_type) } @@ -575,7 +574,7 @@ where } fn calculate_median( - mut values: Vec, + values: &mut [T::Native], ) -> Option { let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); From 5771f83e03de5ee6e25fd4983455b3eb8a017c81 Mon Sep 17 00:00:00 2001 From: Peter Nguyen Date: Sat, 13 Dec 2025 12:11:55 -0800 Subject: [PATCH 3/7] Add slt test with 'UNBOUNDED PRECEDING/FOLLOWING --- .../sqllogictest/test_files/aggregate.slt | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index c1d12c06135fc..13052f6e06c26 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1036,6 +1036,40 @@ ORDER BY tags, timestamp; 4 tag2 90 90 5 tag2 100 95 +query ITRRRR +SELECT + timestamp, + tags, + value, + median(value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS value_median_unbounded_preceding, + median(value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS value_median_unbounded_both, + median(value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING + ) AS value_median_unbounded_following +FROM median_window_test +ORDER BY tags, timestamp; +---- +1 tag1 10 10 30 30 +2 tag1 20 15 30 35 +3 tag1 30 20 30 40 +4 tag1 40 25 30 45 +5 tag1 50 30 30 50 +1 tag2 60 60 80 80 +2 tag2 70 65 80 85 +3 tag2 80 70 80 90 +4 tag2 90 75 80 95 +5 tag2 100 80 80 100 + statement ok DROP TABLE median_window_test; From 637c0b94bd336a9f24a8754a5ae51494580cb4f2 Mon Sep 17 00:00:00 2001 From: Peter Nguyen Date: Sat, 13 Dec 2025 12:13:20 -0800 Subject: [PATCH 4/7] Fix expected results in mod.rs 'window_using_aggregates' and add test for the bug in aggregate.slt --- datafusion/core/tests/dataframe/mod.rs | 36 +++++++++---------- .../sqllogictest/test_files/aggregate.slt | 35 ++++++++++++++++++ 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 52b30326809e9..e543ae0701828 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1102,26 +1102,26 @@ async fn window_using_aggregates() -> Result<()> { | first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ | | | | | | | | 1 | -85 | - | -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 | - | -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 | - | -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 | - | -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 | - | -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 | - | -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 | + | -85 | -101 | 14 | -12 | -12 | 83 | -101 | 4 | -54 | + | -85 | -101 | 17 | -25 | -25 | 83 | -101 | 5 | -31 | + | -85 | -12 | 10 | -32 | -34 | 83 | -85 | 3 | 13 | + | -85 | -25 | 3 | -56 | -56 | -25 | -85 | 1 | -5 | + | -85 | -31 | 18 | -29 | -28 | 83 | -101 | 5 | 36 | + | -85 | -38 | 16 | -25 | -25 | 83 | -101 | 4 | 65 | | -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 | - | -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 | - | -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 | - | -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 | - | -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 | - | -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 | + | -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 | + | -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 | + | -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 | + | -85 | -56 | 2 | -70 | 57 | -56 | -85 | 1 | -25 | + | -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 | | -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 | - | -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 | - | -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 | - | -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 | - | -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 | - | -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 | - | -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 | - | -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 | + | -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 | + | -85 | 13 | 11 | -25 | -25 | 83 | -85 | 3 | 13 | + | -85 | 14 | 12 | -12 | -12 | 83 | -85 | 3 | 17 | + | -85 | 17 | 13 | -11 | -8 | 83 | -85 | 4 | -101 | + | -85 | 45 | 8 | -34 | -34 | 83 | -85 | 3 | -72 | + | -85 | 65 | 17 | -17 | -18 | 83 | -101 | 5 | -101 | + | -85 | 83 | 5 | -25 | -25 | 83 | -85 | 2 | -48 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ "### ); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 13052f6e06c26..8eaa2e29f2bf6 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1073,6 +1073,41 @@ ORDER BY tags, timestamp; statement ok DROP TABLE median_window_test; +# median_non_sliding_window_unbounded_preceding_to_1_preceding +statement ok +CREATE TABLE median_non_sliding_test ( + row_num INT, + value INT +); + +statement ok +INSERT INTO median_non_sliding_test (row_num, value) VALUES +(1, 10), +(2, 20), +(3, 30), +(4, 40), +(5, 50); + +query III +SELECT + row_num, + value, + median(value) OVER ( + ORDER BY row_num + ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING + ) AS median_so_far +FROM median_non_sliding_test +ORDER BY row_num; +---- +1 10 NULL +2 20 10 +3 30 15 +4 40 20 +5 50 25 + +statement ok +DROP TABLE median_non_sliding_test; + query RT select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median(arrow_cast(col_f32, 'Float16'))) from median_table; ---- From 839e54e4167eaece50ddcbe35749b73d3b062b5c Mon Sep 17 00:00:00 2001 From: Peter Nguyen Date: Sat, 13 Dec 2025 12:15:42 -0800 Subject: [PATCH 5/7] cargo fmt --- datafusion/functions-aggregate/src/median.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 9a4eda2e907b1..512cdfa05639f 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -573,9 +573,7 @@ where .unwrap() } -fn calculate_median( - values: &mut [T::Native], -) -> Option { +fn calculate_median(values: &mut [T::Native]) -> Option { let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); let len = values.len(); From 768e95e0e583bd45466f50141fd7edcb5b0cb74b Mon Sep 17 00:00:00 2001 From: Peter Nguyen Date: Sat, 13 Dec 2025 12:24:11 -0800 Subject: [PATCH 6/7] Remove redundant test --- .../sqllogictest/test_files/aggregate.slt | 36 +------------------ 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 8eaa2e29f2bf6..f5bc705044f25 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1036,6 +1036,7 @@ ORDER BY tags, timestamp; 4 tag2 90 90 5 tag2 100 95 +# median_non_sliding_window query ITRRRR SELECT timestamp, @@ -1073,41 +1074,6 @@ ORDER BY tags, timestamp; statement ok DROP TABLE median_window_test; -# median_non_sliding_window_unbounded_preceding_to_1_preceding -statement ok -CREATE TABLE median_non_sliding_test ( - row_num INT, - value INT -); - -statement ok -INSERT INTO median_non_sliding_test (row_num, value) VALUES -(1, 10), -(2, 20), -(3, 30), -(4, 40), -(5, 50); - -query III -SELECT - row_num, - value, - median(value) OVER ( - ORDER BY row_num - ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING - ) AS median_so_far -FROM median_non_sliding_test -ORDER BY row_num; ----- -1 10 NULL -2 20 10 -3 30 15 -4 40 20 -5 50 25 - -statement ok -DROP TABLE median_non_sliding_test; - query RT select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median(arrow_cast(col_f32, 'Float16'))) from median_table; ---- From 1b710cc753efafdf529deb3e40760f70a635d896 Mon Sep 17 00:00:00 2001 From: Peter Nguyen Date: Sun, 14 Dec 2025 15:29:21 -0800 Subject: [PATCH 7/7] Speed up retract_batch using a hash map --- datafusion/functions-aggregate/src/median.rs | 37 +++++++++++++++----- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index 512cdfa05639f..29b8857254dd3 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -53,6 +53,7 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumu use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; use datafusion_macros::user_doc; +use std::collections::HashMap; make_udaf_expr_and_func!( Median, @@ -298,15 +299,35 @@ impl Accumulator for MedianAccumulator { } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - for v in values.iter().flatten() { - if let Some(idx) = self.all_values.iter().position(|x| *x == v) { - self.all_values.swap_remove(idx); - } else { - return Err(internal_datafusion_err!( - "attempted to retract value {v:?} that was not present" - )); + let mut to_remove: HashMap = HashMap::new(); + + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + *to_remove.entry(v).or_default() += 1; + } + } + + let mut i = 0; + while i < self.all_values.len() { + let k = ScalarValue::new_primitive::( + Some(self.all_values[i]), + &self.data_type, + )?; + if let Some(count) = to_remove.get_mut(&k) + && *count > 0 + { + self.all_values.swap_remove(i); + *count -= 1; + if *count == 0 { + to_remove.remove(&k); + if to_remove.is_empty() { + break; + } + } } + i += 1; } Ok(()) }