Merged
36 changes: 18 additions & 18 deletions datafusion/core/tests/dataframe/mod.rs
@@ -1102,26 +1102,26 @@ async fn window_using_aggregates() -> Result<()> {
| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
| | | | | | | | 1 | -85 |
| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |
| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |
| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |
| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |
| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |
| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |
| -85 | -101 | 14 | -12 | -12 | 83 | -101 | 4 | -54 |
| -85 | -101 | 17 | -25 | -25 | 83 | -101 | 5 | -31 |
Comment on lines +1105 to +1106
Contributor Author:

I found that this test was returning incorrect results (instead of raising an error) due to the bug I explained in another comment. The results here were fixed by updating evaluate() to pass a &mut to the state instead of consuming it with std::mem::take().
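A small standalone sketch of that failure mode (a plain Vec<i64> standing in for the accumulator state, not the actual MedianAccumulator): once evaluate() consumes the state with std::mem::take, any later evaluation of the same accumulator sees nothing.

```rust
// Hypothetical stand-in: a Vec<i64> plays the role of the accumulator state.
fn evaluate_with_take(state: &mut Vec<i64>) -> usize {
    // Consumes the accumulated values, leaving `state` empty.
    std::mem::take(state).len()
}

fn evaluate_with_borrow(state: &mut Vec<i64>) -> usize {
    // Reads the accumulated values in place.
    state.len()
}

fn main() {
    let mut state = vec![10, 20, 30];
    assert_eq!(evaluate_with_take(&mut state), 3);
    assert_eq!(evaluate_with_take(&mut state), 0); // state was consumed

    let mut state = vec![10, 20, 30];
    assert_eq!(evaluate_with_borrow(&mut state), 3);
    assert_eq!(evaluate_with_borrow(&mut state), 3); // state survives re-evaluation
}
```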

| -85 | -12 | 10 | -32 | -34 | 83 | -85 | 3 | 13 |
| -85 | -25 | 3 | -56 | -56 | -25 | -85 | 1 | -5 |
| -85 | -31 | 18 | -29 | -28 | 83 | -101 | 5 | 36 |
| -85 | -38 | 16 | -25 | -25 | 83 | -101 | 4 | 65 |
| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |
| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |
| -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 |
| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |
| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |
| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |
| -85 | -48 | 6 | -35 | -36 | 83 | -85 | 2 | -43 |
| -85 | -5 | 4 | -37 | -40 | -5 | -85 | 1 | 83 |
| -85 | -54 | 15 | -17 | -18 | 83 | -101 | 4 | -38 |
| -85 | -56 | 2 | -70 | 57 | -56 | -85 | 1 | -25 |
Contributor:

I find this interesting: how do we get -70 for the approx median but 57 for the median? 🤔

Contributor Author:

Great catch. I looked into it, and it seems like it's wrapping around due to integer overflow while taking the average of the middle two values (since the count is even).

low: [-85], high: -56, median: 57 datatype: Int8

-85 + -56 = -141 -> wraparound to 115
Then 115 / 2 -> 57.5 -> 57 (truncated due to integer type)
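A minimal reproduction of that wraparound in plain Rust (i8 standing in for Arrow's Int8; that the sum happens in the native type is an assumption based on the numbers above):

```rust
fn main() {
    let low: i8 = -85;
    let high: i8 = -56;
    // -85 + -56 = -141 does not fit in i8 and wraps around to 115.
    let sum = low.wrapping_add(high);
    assert_eq!(sum, 115);
    // Integer division then truncates 57.5 down to 57, the observed median.
    assert_eq!(sum / 2, 57);
}
```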

What's our desired behavior in this case? We could promote to a larger datatype to perform the calculation. Also is it intentional to return the value as a truncated integer instead of a float?

Contributor:

Regarding overflow, perhaps we should raise a separate issue to discuss/track this, as it does seem like incorrect behaviour.

We could do the same for the truncated integer behaviour; there was a recent issue asking about this, for reference: #18867 (comment)

Contributor Author:

Filed: #19322

| -85 | -72 | 9 | -43 | -43 | 83 | -85 | 3 | -12 |
| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |
| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |
| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |
| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |
| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |
| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |
| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |
| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |
| -85 | 13 | 11 | -17 | -18 | 83 | -85 | 3 | 14 |
| -85 | 13 | 11 | -25 | -25 | 83 | -85 | 3 | 13 |
| -85 | 14 | 12 | -12 | -12 | 83 | -85 | 3 | 17 |
| -85 | 17 | 13 | -11 | -8 | 83 | -85 | 4 | -101 |
| -85 | 45 | 8 | -34 | -34 | 83 | -85 | 3 | -72 |
| -85 | 65 | 17 | -17 | -18 | 83 | -101 | 5 | -101 |
| -85 | 83 | 5 | -25 | -25 | 83 | -85 | 2 | -48 |
+-------------+----------+-----------------+---------------+--------+-----+------+----+------+
"###
);
54 changes: 45 additions & 9 deletions datafusion/functions-aggregate/src/median.rs
@@ -53,6 +53,7 @@ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumu
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::filtered_null_mask;
use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
use datafusion_macros::user_doc;
use std::collections::HashMap;

make_udaf_expr_and_func!(
Median,
@@ -289,14 +290,51 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let d = std::mem::take(&mut self.all_values);
let median = calculate_median::<T>(d);
let median = calculate_median::<T>(&mut self.all_values);
ScalarValue::new_primitive::<T>(median, &self.data_type)
}

fn size(&self) -> usize {
size_of_val(self) + self.all_values.capacity() * size_of::<T::Native>()
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let mut to_remove: HashMap<ScalarValue, usize> = HashMap::new();
Contributor:

This seems like a good optimization with minimal added complexity.


let arr = &values[0];
// Count how many copies of each retracted value need to be removed.
for i in 0..arr.len() {
let v = ScalarValue::try_from_array(arr, i)?;
if !v.is_null() {
*to_remove.entry(v).or_default() += 1;
}
}

// Single pass over the stored values, dropping matches in place.
let mut i = 0;
while i < self.all_values.len() {
let k = ScalarValue::new_primitive::<T>(
Some(self.all_values[i]),
&self.data_type,
)?;
if let Some(count) = to_remove.get_mut(&k)
&& *count > 0
{
self.all_values.swap_remove(i);
*count -= 1;
if *count == 0 {
to_remove.remove(&k);
if to_remove.is_empty() {
break;
}
}
// swap_remove moved the last element into slot `i`; re-check it
// before advancing.
continue;
}
i += 1;
}
Ok(())
}

fn supports_retract_batch(&self) -> bool {
true
}
}
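For reference, a self-contained toy that mirrors the count-then-swap_remove retract strategy above on plain i64 values (illustrative names, not DataFusion's Accumulator API), driven the way a sliding frame would drive it:

```rust
use std::collections::HashMap;

struct ToyMedian {
    all_values: Vec<i64>,
}

impl ToyMedian {
    fn update_batch(&mut self, values: &[i64]) {
        self.all_values.extend_from_slice(values);
    }

    fn retract_batch(&mut self, values: &[i64]) {
        // Count how many copies of each value must be removed.
        let mut to_remove: HashMap<i64, usize> = HashMap::new();
        for v in values {
            *to_remove.entry(*v).or_default() += 1;
        }
        // Single pass over the stored values, removing matches in place.
        let mut i = 0;
        while i < self.all_values.len() {
            let v = self.all_values[i];
            if let Some(count) = to_remove.get_mut(&v) {
                self.all_values.swap_remove(i);
                *count -= 1;
                if *count == 0 {
                    to_remove.remove(&v);
                    if to_remove.is_empty() {
                        break;
                    }
                }
                continue; // re-check the element swapped into slot `i`
            }
            i += 1;
        }
    }

    fn evaluate(&mut self) -> Option<i64> {
        let len = self.all_values.len();
        if len == 0 {
            return None;
        }
        self.all_values.sort_unstable();
        Some(self.all_values[len / 2]) // odd-length frames only, for brevity
    }
}

fn main() {
    let mut acc = ToyMedian { all_values: Vec::new() };
    acc.update_batch(&[10, 20, 30]); // frame covers rows 1..=3
    assert_eq!(acc.evaluate(), Some(20));
    acc.update_batch(&[40]);  // frame slides forward: row 4 enters,
    acc.retract_batch(&[10]); // row 1 leaves
    assert_eq!(acc.evaluate(), Some(30));
}
```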

/// The median groups accumulator accumulates the raw input values
@@ -443,8 +481,8 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator for MedianGroupsAccumulator<T
// Calculate median for each group
let mut evaluate_result_builder =
PrimitiveBuilder::<T>::new().with_data_type(self.data_type.clone());
for values in emit_group_values {
let median = calculate_median::<T>(values);
for mut values in emit_group_values {
let median = calculate_median::<T>(&mut values);
evaluate_result_builder.append_option(median);
}

@@ -528,11 +566,11 @@ impl<T: ArrowNumericType + Debug> Accumulator for DistinctMedianAccumulator<T> {
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let d = std::mem::take(&mut self.distinct_values.values)
let mut d = std::mem::take(&mut self.distinct_values.values)
.into_iter()
.map(|v| v.0)
.collect::<Vec<_>>();
let median = calculate_median::<T>(d);
let median = calculate_median::<T>(&mut d);
ScalarValue::new_primitive::<T>(median, &self.data_type)
}

@@ -556,9 +594,7 @@
.unwrap()
}

fn calculate_median<T: ArrowNumericType>(
mut values: Vec<T::Native>,
) -> Option<T::Native> {
fn calculate_median<T: ArrowNumericType>(values: &mut [T::Native]) -> Option<T::Native> {
let cmp = |x: &T::Native, y: &T::Native| x.compare(*y);

let len = values.len();
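The rest of calculate_median is cut off by the diff view; a sketch of what a slice-based median of this shape can look like (i64 standing in for T::Native, with the even case widened to i128 to sidestep the Int8-style overflow discussed above — an assumption, not the actual DataFusion body):

```rust
fn calculate_median(values: &mut [i64]) -> Option<i64> {
    let len = values.len();
    if len == 0 {
        return None;
    }
    if len % 2 == 1 {
        // Partial sort: the element at len / 2 ends up in its sorted position.
        let (_, mid, _) = values.select_nth_unstable(len / 2);
        Some(*mid)
    } else {
        // Upper middle via partial sort; lower middle is the max of the left half.
        let (below, upper, _) = values.select_nth_unstable(len / 2);
        let hi = *upper;
        let lo = *below.iter().max()?;
        // Widen before averaging so the sum cannot wrap around.
        Some(((lo as i128 + hi as i128) / 2) as i64)
    }
}

fn main() {
    assert_eq!(calculate_median(&mut [3, 1, 2]), Some(2));
    assert_eq!(calculate_median(&mut [-85, -56]), Some(-70)); // no wraparound
    assert_eq!(calculate_median(&mut []), None);
}
```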
83 changes: 83 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -991,6 +991,89 @@ SELECT approx_median(col_f64_nan) FROM median_table
----
NaN

# median_sliding_window
statement ok
CREATE TABLE median_window_test (
timestamp INT,
tags VARCHAR,
value DOUBLE
);

statement ok
INSERT INTO median_window_test (timestamp, tags, value) VALUES
(1, 'tag1', 10.0),
(2, 'tag1', 20.0),
(3, 'tag1', 30.0),
(4, 'tag1', 40.0),
(5, 'tag1', 50.0),
(1, 'tag2', 60.0),
(2, 'tag2', 70.0),
(3, 'tag2', 80.0),
(4, 'tag2', 90.0),
(5, 'tag2', 100.0);

query ITRR
SELECT
timestamp,
tags,
value,
median(value) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
Contributor:

I recommend testing different window frames, like UNBOUNDED PRECEDING/FOLLOWING.

Contributor Author:

I wasn't familiar with these before, but this was a great idea! It helped me find and understand a bug.

) AS value_median_3
FROM median_window_test
ORDER BY tags, timestamp;
----
1 tag1 10 15
2 tag1 20 20
3 tag1 30 30
4 tag1 40 40
5 tag1 50 45
1 tag2 60 65
2 tag2 70 70
3 tag2 80 80
4 tag2 90 90
5 tag2 100 95

# median_non_sliding_window
query ITRRRR
SELECT
timestamp,
tags,
value,
median(value) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
) AS value_median_unbounded_preceding,
median(value) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS value_median_unbounded_both,
Comment on lines +1039 to +1054
Contributor Author:

For UNBOUNDED FOLLOWING, an error is raised when retract_batch() isn't implemented. I found that queries with UNBOUNDED PRECEDING do not trigger this and instead return incorrect results. I assume this is a bug, right? If so, I can file a ticket.

For example, if you remove the UNBOUNDED FOLLOWING case right below my comment here, and try the query on main, I get this diff instead of an error.

Results Diff

```
[Diff] (-expected|+actual)
    1 tag1 10 10 30
-   2 tag1 20 15 30
-   3 tag1 30 20 30
-   4 tag1 40 25 30
-   5 tag1 50 30 30
+   2 tag1 20 20 30
+   3 tag1 30 30 30
+   4 tag1 40 40 30
+   5 tag1 50 50 30
    1 tag2 60 60 80
-   2 tag2 70 65 80
-   3 tag2 80 70 80
-   4 tag2 90 75 80
-   5 tag2 100 80 80
+   2 tag2 70 70 80
+   3 tag2 80 80 80
+   4 tag2 90 90 80
+   5 tag2 100 100 80
```

Contributor:

There is this quite informative comment which seems to explain why this is the case:

// Accumulators that have window frame startings different
// than `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to
// implement retract_batch method in order to run correctly
// currently in DataFusion.
//
// If this `retract_batches` is not present, there is no way
// to calculate result correctly. For example, the query
//
// ```sql
// SELECT
// SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a
// FROM
// t
// ```
//
// 1. First sum value will be the sum of rows between `[0, 1)`,
//
// 2. Second sum value will be the sum of rows between `[0, 2)`
//
// 3. Third sum value will be the sum of rows between `[1, 3)`, etc.
//
// Since the accumulator keeps the running sum:
//
// 1. First sum we add to the state sum value between `[0, 1)`
//
// 2. Second sum we add to the state sum value between `[1, 2)`
// (`[0, 1)` is already in the state sum, hence running sum will
// cover `[0, 2)` range)
//
// 3. Third sum we add to the state sum value between `[2, 3)`
// (`[0, 2)` is already in the state sum). Also we need to
// retract values between `[0, 1)` by this way we can obtain sum
// between [1, 3) which is indeed the appropriate range.
//
// When we use `UNBOUNDED PRECEDING` in the query starting
// index will always be 0 for the desired range, and hence the
// `retract_batch` method will not be called. In this case
// having retract_batch is not a requirement.
//
// This approach is a bit different from the window function
// approach. In window functions (when they use a window frame)
// they get all the desired range during evaluation.
if !accumulator.supports_retract_batch() {
return not_impl_err!(
"Aggregate can not be used as a sliding accumulator because \
`retract_batch` is not implemented: {}",
self.name
);
}

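A toy of the running-sum bookkeeping this comment describes, for a `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING` frame (plain arrays and i64 sums, not DataFusion's executor):

```rust
fn main() {
    let a = [3_i64, 1, 4, 1, 5];
    let mut running_sum = 0_i64;
    let mut prev_start = 0usize;
    let mut prev_end = 0usize;

    for i in 0..a.len() {
        // Frame for row i: [i - 1, i + 1], clamped to the partition bounds.
        let start = i.saturating_sub(1);
        let end = (i + 2).min(a.len());

        // update_batch: add the rows that just entered the frame.
        for &v in &a[prev_end..end] {
            running_sum += v;
        }
        // retract_batch: remove the rows that just left the frame.
        for &v in &a[prev_start..start] {
            running_sum -= v;
        }

        // evaluate: the running sum now covers exactly [start, end).
        assert_eq!(running_sum, a[start..end].iter().sum::<i64>());

        prev_start = start;
        prev_end = end;
    }
}
```

With UNBOUNDED PRECEDING, `start` stays 0 forever, so the retract loop never runs — which is why retract_batch is only required for frames whose start moves.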
Contributor:

I think we should file a ticket; the previous impl should be able to handle unbounded preceding, as @Jefffrey explained, and the inconsistent results likely indicate a bug.

Contributor Author:

Ah, so you're saying unbounded preceding is supposed to work even without retract_batch() implemented. I was originally under the impression that it wasn't, but that makes total sense now.

In that case, I think this PR already fixes the bug, so there's no need to submit an issue for that. I mentioned in this comment that passing a &mut instead of clearing the state with take() (81ced74) fixes the results in the mod.rs test. I've verified this by copying that change (81ced74) over to main and testing it, and the results for that test change. It's completely unrelated to the new support for retract_batch(). We just have an integer overflow issue remaining, which I've submitted an issue for.

Contributor:

Yes, this makes sense. I realized that the root cause is already known, and it's not possible for it to cause issues elsewhere.

median(value) OVER (
PARTITION BY tags
ORDER BY timestamp
ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
) AS value_median_unbounded_following
FROM median_window_test
ORDER BY tags, timestamp;
----
1 tag1 10 10 30 30
2 tag1 20 15 30 35
3 tag1 30 20 30 40
4 tag1 40 25 30 45
5 tag1 50 30 30 50
1 tag2 60 60 80 80
2 tag2 70 65 80 85
3 tag2 80 70 80 90
4 tag2 90 75 80 95
5 tag2 100 80 80 100

statement ok
DROP TABLE median_window_test;

query RT
select approx_median(arrow_cast(col_f32, 'Float16')), arrow_typeof(approx_median(arrow_cast(col_f32, 'Float16'))) from median_table;
----