diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 52b30326809e9..e543ae0701828 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1102,26 +1102,26 @@ async fn window_using_aggregates() -> Result<()> { | first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ | | | | | | | | 1 | -85 | - | -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 | - | -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 | - | -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 | - | -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 | - | -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 | - | -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 | + | -85 | -101 | 14 | -12 | -12 | 83 | -101 | 4 | -54 | + | -85 | -101 | 17 | -25 | -25 | 83 | -101 | 5 | -31 | + | -85 | -12 | 10 | -32 | -34 | 83 | -85 | 3 | 13 | + | -85 | -25 | 3 | -56 | -56 | -25 | -85 | 1 | -5 | + | -85 | -31 | 18 | -29 | -28 | 83 | -101 | 5 | 36 | + | -85 | -38 | 16 | -25 | -25 | 83 | -101 | 4 | 65 | | -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 | - | -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 | - | -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 | - | -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 | - | -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 | - | -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 | + | -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 | + | -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 | + | -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 | + | -85 | -56 | 2 | -70 | 57 | -56 | -85 | 1 | -25 | + | -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 | | -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 | - | -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 | - | -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 | - | -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 | - | -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 | - | -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 | - | -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 | - | -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 | + | -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 | + | -85 | 13 | 11 | -25 | -25 | 83 | -85 | 3 | 13 | + | -85 | 14 | 12 | -12 | -12 | 83 | -85 | 3 | 17 | + | -85 | 17 | 13 | -11 | -8 | 83 | -85 | 4 | -101 | + | -85 | 45 | 8 | -34 | -34 | 83 | -85 | 3 | -72 | + | -85 | 65 | 17 | -17 | -18 | 83 | -101 | 5 | -101 | + | -85 | 83 | 5 | -25 | -25 | 83 | -85 | 2 | -48 | +-------------+----------+-----------------+---------------+--------+-----+------+----+------+ "### ); diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index f09c544628a6f..29b8857254dd3 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -53,6 +53,7 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumu use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask; use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer; use datafusion_macros::user_doc; +use std::collections::HashMap; make_udaf_expr_and_func!( Median, @@ -289,14 +290,51 @@ impl Accumulator for MedianAccumulator { } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.all_values); - let median = calculate_median::(d); + let median = calculate_median::(&mut self.all_values); ScalarValue::new_primitive::(median, &self.data_type) } fn size(&self) -> usize { size_of_val(self) + self.all_values.capacity() * size_of::() } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let mut to_remove: HashMap = HashMap::new(); + + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + *to_remove.entry(v).or_default() += 1; + } + } + + let mut i = 0; + while i < self.all_values.len() { + let k = ScalarValue::new_primitive::( + Some(self.all_values[i]), + &self.data_type, + )?; + if let Some(count) = to_remove.get_mut(&k) + && *count > 0 + { + self.all_values.swap_remove(i); + *count -= 1; + if *count == 0 { + to_remove.remove(&k); + if to_remove.is_empty() { + break; + } + } + } + i += 1; + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } } /// The median groups accumulator accumulates the raw input values @@ -443,8 +481,8 @@ impl GroupsAccumulator for MedianGroupsAccumulator::new().with_data_type(self.data_type.clone()); - for values in emit_group_values { - let median = calculate_median::(values); + for mut values in emit_group_values { + let median = calculate_median::(&mut values); evaluate_result_builder.append_option(median); } @@ -528,11 +566,11 @@ impl Accumulator for DistinctMedianAccumulator { } fn evaluate(&mut self) -> Result { - let d = std::mem::take(&mut self.distinct_values.values) + let mut d = std::mem::take(&mut self.distinct_values.values) .into_iter() .map(|v| v.0) .collect::>(); - let median = calculate_median::(d); + let median = calculate_median::(&mut d); ScalarValue::new_primitive::(median, &self.data_type) } @@ -556,9 +594,7 @@ where .unwrap() } -fn calculate_median( - mut values: Vec, -) -> Option { +fn calculate_median(values: &mut [T::Native]) -> Option { let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); let len = values.len(); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 1de075088db72..f5bc705044f25 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -991,6 +991,89 @@ SELECT approx_median(col_f64_nan) FROM median_table ---- NaN +# median_sliding_window +statement ok +CREATE TABLE median_window_test ( + timestamp INT, + tags VARCHAR, + value DOUBLE +); + +statement ok +INSERT INTO median_window_test (timestamp, tags, value) VALUES +(1, 'tag1', 10.0), +(2, 'tag1', 20.0), +(3, 'tag1', 30.0), +(4, 'tag1', 40.0), +(5, 'tag1', 50.0), +(1, 'tag2', 60.0), +(2, 'tag2', 70.0), +(3, 'tag2', 80.0), +(4, 'tag2', 90.0), +(5, 'tag2', 100.0); + +query ITRR +SELECT + timestamp, + tags, + value, + median(value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS value_median_3 +FROM median_window_test +ORDER BY tags, timestamp; +---- +1 tag1 10 15 +2 tag1 20 20 +3 tag1 30 30 +4 tag1 40 40 +5 tag1 50 45 +1 tag2 60 65 +2 tag2 70 70 +3 tag2 80 80 +4 tag2 90 90 +5 tag2 100 95 + +# median_non_sliding_window +query ITRRRR +SELECT + timestamp, + tags, + value, + median(value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + ) AS value_median_unbounded_preceding, + median(value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS value_median_unbounded_both, + median(value) OVER ( + PARTITION BY tags + ORDER BY timestamp + ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING + ) AS value_median_unbounded_following +FROM median_window_test +ORDER BY tags, timestamp; +---- +1 tag1 10 10 30 30 +2 tag1 20 15 30 35 +3 tag1 30 20 30 40 +4 tag1 40 25 30 45 +5 tag1 50 30 30 50 +1 tag2 60 60 80 80 +2 tag2 70 65 80 85 +3 tag2 80 70 80 90 +4 tag2 90 75 80 95 +5 tag2 100 80 80 100 + +statement ok +DROP TABLE median_window_test; + query RT select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median(arrow_cast(col_f32, 'Float16'))) from median_table; ----