Skip to content

Commit

Permalink
Add test hitting the former overflow panic
Browse files Browse the repository at this point in the history
  • Loading branch information
gruuya committed Jan 26, 2024
1 parent 81781ff commit 13def2c
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 90 deletions.
7 changes: 4 additions & 3 deletions datafusion/common/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,15 @@ impl<T: Debug + Clone + PartialEq + Eq + PartialOrd> Precision<T> {

/// Transform the value in this [`Precision`] object, if one exists, using
/// the given function. Preserves the exactness state.
pub fn map<F>(self, f: F) -> Precision<T>
pub fn map<U, F>(self, f: F) -> Precision<U>
where
F: Fn(T) -> T,
F: Fn(T) -> U,
U: Debug + Clone + PartialEq + Eq + PartialOrd,
{
match self {
Precision::Exact(val) => Precision::Exact(f(val)),
Precision::Inexact(val) => Precision::Inexact(f(val)),
_ => self,
_ => Precision::<U>::Absent,
}
}

Expand Down
197 changes: 110 additions & 87 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,7 @@ mod tests {
use arrow::error::{ArrowError, Result as ArrowResult};
use arrow_schema::SortOptions;

use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};

fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
Expand Down Expand Up @@ -1640,25 +1641,26 @@ mod tests {
}

fn create_column_stats(
min: Option<i64>,
max: Option<i64>,
distinct_count: Option<usize>,
min: Precision<i64>,
max: Precision<i64>,
distinct_count: Precision<usize>,
null_count: Precision<usize>,
) -> ColumnStatistics {
ColumnStatistics {
distinct_count: distinct_count
.map(Precision::Inexact)
.unwrap_or(Precision::Absent),
min_value: min
.map(|size| Precision::Inexact(ScalarValue::from(size)))
.unwrap_or(Precision::Absent),
max_value: max
.map(|size| Precision::Inexact(ScalarValue::from(size)))
.unwrap_or(Precision::Absent),
..Default::default()
distinct_count,
min_value: min.map(ScalarValue::from),
max_value: max.map(ScalarValue::from),
null_count,
}
}

type PartialStats = (usize, Option<i64>, Option<i64>, Option<usize>);
type PartialStats = (
usize,
Precision<i64>,
Precision<i64>,
Precision<usize>,
Precision<usize>,
);

// This is mainly for validating the all edge cases of the estimation, but
// more advanced (and real world test cases) are below where we need some control
Expand All @@ -1675,133 +1677,156 @@ mod tests {
//
// distinct(left) == NaN, distinct(right) == NaN
(
(10, Some(1), Some(10), None),
(10, Some(1), Some(10), None),
Some(Precision::Inexact(10)),
(10, Inexact(1), Inexact(10), Absent, Absent),
(10, Inexact(1), Inexact(10), Absent, Absent),
Some(Inexact(10)),
),
// range(left) > range(right)
(
(10, Some(6), Some(10), None),
(10, Some(8), Some(10), None),
Some(Precision::Inexact(20)),
(10, Inexact(6), Inexact(10), Absent, Absent),
(10, Inexact(8), Inexact(10), Absent, Absent),
Some(Inexact(20)),
),
// range(right) > range(left)
(
(10, Some(8), Some(10), None),
(10, Some(6), Some(10), None),
Some(Precision::Inexact(20)),
(10, Inexact(8), Inexact(10), Absent, Absent),
(10, Inexact(6), Inexact(10), Absent, Absent),
Some(Inexact(20)),
),
// range(left) > len(left), range(right) > len(right)
(
(10, Some(1), Some(15), None),
(20, Some(1), Some(40), None),
Some(Precision::Inexact(10)),
(10, Inexact(1), Inexact(15), Absent, Absent),
(20, Inexact(1), Inexact(40), Absent, Absent),
Some(Inexact(10)),
),
// When we have distinct count.
(
(10, Some(1), Some(10), Some(10)),
(10, Some(1), Some(10), Some(10)),
Some(Precision::Inexact(10)),
(10, Inexact(1), Inexact(10), Inexact(10), Absent),
(10, Inexact(1), Inexact(10), Inexact(10), Absent),
Some(Inexact(10)),
),
// distinct(left) > distinct(right)
(
(10, Some(1), Some(10), Some(5)),
(10, Some(1), Some(10), Some(2)),
Some(Precision::Inexact(20)),
(10, Inexact(1), Inexact(10), Inexact(5), Absent),
(10, Inexact(1), Inexact(10), Inexact(2), Absent),
Some(Inexact(20)),
),
// distinct(right) > distinct(left)
(
(10, Some(1), Some(10), Some(2)),
(10, Some(1), Some(10), Some(5)),
Some(Precision::Inexact(20)),
(10, Inexact(1), Inexact(10), Inexact(2), Absent),
(10, Inexact(1), Inexact(10), Inexact(5), Absent),
Some(Inexact(20)),
),
// min(left) < 0 (range(left) > range(right))
(
(10, Some(-5), Some(5), None),
(10, Some(1), Some(5), None),
Some(Precision::Inexact(10)),
(10, Inexact(-5), Inexact(5), Absent, Absent),
(10, Inexact(1), Inexact(5), Absent, Absent),
Some(Inexact(10)),
),
// min(right) < 0, max(right) < 0 (range(right) > range(left))
(
(10, Some(-25), Some(-20), None),
(10, Some(-25), Some(-15), None),
Some(Precision::Inexact(10)),
(10, Inexact(-25), Inexact(-20), Absent, Absent),
(10, Inexact(-25), Inexact(-15), Absent, Absent),
Some(Inexact(10)),
),
// range(left) < 0, range(right) >= 0
// (there isn't a case where both left and right ranges are negative
// so one of them is always going to work, this just proves negative
// ranges with bigger absolute values are not are not accidentally used).
(
(10, Some(-10), Some(0), None),
(10, Some(0), Some(10), Some(5)),
Some(Precision::Inexact(10)),
(10, Inexact(-10), Inexact(0), Absent, Absent),
(10, Inexact(0), Inexact(10), Inexact(5), Absent),
Some(Inexact(10)),
),
// range(left) = 1, range(right) = 1
(
(10, Some(1), Some(1), None),
(10, Some(1), Some(1), None),
Some(Precision::Inexact(100)),
(10, Inexact(1), Inexact(1), Absent, Absent),
(10, Inexact(1), Inexact(1), Absent, Absent),
Some(Inexact(100)),
),
//
// Edge cases
// ==========
//
// No column level stats.
((10, None, None, None), (10, None, None, None), None),
(
(10, Absent, Absent, Absent, Absent),
(10, Absent, Absent, Absent, Absent),
None,
),
// No min or max (or both).
((10, None, None, Some(3)), (10, None, None, Some(3)), None),
(
(10, Some(2), None, Some(3)),
(10, None, Some(5), Some(3)),
(10, Absent, Absent, Inexact(3), Absent),
(10, Absent, Absent, Inexact(3), Absent),
None,
),
(
(10, Inexact(2), Absent, Inexact(3), Absent),
(10, Absent, Inexact(5), Inexact(3), Absent),
None,
),
(
(10, None, Some(3), Some(3)),
(10, Some(1), None, Some(3)),
(10, Absent, Inexact(3), Inexact(3), Absent),
(10, Inexact(1), Absent, Inexact(3), Absent),
None,
),
(
(10, Absent, Inexact(3), Absent, Absent),
(10, Inexact(1), Absent, Absent, Absent),
None,
),
((10, None, Some(3), None), (10, Some(1), None, None), None),
// Non overlapping min/max (when exact=False).
(
(10, Some(0), Some(10), None),
(10, Some(11), Some(20), None),
Some(Precision::Inexact(0)),
(10, Inexact(0), Inexact(10), Absent, Absent),
(10, Inexact(11), Inexact(20), Absent, Absent),
Some(Inexact(0)),
),
(
(10, Some(11), Some(20), None),
(10, Some(0), Some(10), None),
Some(Precision::Inexact(0)),
(10, Inexact(11), Inexact(20), Absent, Absent),
(10, Inexact(0), Inexact(10), Absent, Absent),
Some(Inexact(0)),
),
// distinct(left) = 0, distinct(right) = 0
(
(10, Some(1), Some(10), Some(0)),
(10, Some(1), Some(10), Some(0)),
(10, Inexact(1), Inexact(10), Inexact(0), Absent),
(10, Inexact(1), Inexact(10), Inexact(0), Absent),
None,
),
// Inexact row count < exact null count with absent distinct count
(
(0, Inexact(1), Inexact(10), Absent, Exact(5)),
(10, Inexact(1), Inexact(10), Absent, Absent),
Some(Inexact(0)),
),
];

for (left_info, right_info, expected_cardinality) in cases {
let left_num_rows = left_info.0;
let left_col_stats =
vec![create_column_stats(left_info.1, left_info.2, left_info.3)];
let left_col_stats = vec![create_column_stats(
left_info.1,
left_info.2,
left_info.3,
left_info.4,
)];

let right_num_rows = right_info.0;
let right_col_stats = vec![create_column_stats(
right_info.1,
right_info.2,
right_info.3,
right_info.4,
)];

assert_eq!(
estimate_inner_join_cardinality(
Statistics {
num_rows: Precision::Inexact(left_num_rows),
total_byte_size: Precision::Absent,
num_rows: Inexact(left_num_rows),
total_byte_size: Absent,
column_statistics: left_col_stats.clone(),
},
Statistics {
num_rows: Precision::Inexact(right_num_rows),
total_byte_size: Precision::Absent,
num_rows: Inexact(right_num_rows),
total_byte_size: Absent,
column_statistics: right_col_stats.clone(),
},
),
Expand All @@ -1819,9 +1844,7 @@ mod tests {
);

assert_eq!(
partial_join_stats
.clone()
.map(|s| Precision::Inexact(s.num_rows)),
partial_join_stats.clone().map(|s| Inexact(s.num_rows)),
expected_cardinality.clone()
);
assert_eq!(
Expand All @@ -1837,13 +1860,13 @@ mod tests {
#[test]
fn test_inner_join_cardinality_multiple_column() -> Result<()> {
let left_col_stats = vec![
create_column_stats(Some(0), Some(100), Some(100)),
create_column_stats(Some(100), Some(500), Some(150)),
create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
create_column_stats(Inexact(100), Inexact(500), Inexact(150), Absent),
];

let right_col_stats = vec![
create_column_stats(Some(0), Some(100), Some(50)),
create_column_stats(Some(100), Some(500), Some(200)),
create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
create_column_stats(Inexact(100), Inexact(500), Inexact(200), Absent),
];

// We have statistics about 4 columns, where the highest distinct
Expand Down Expand Up @@ -1921,15 +1944,15 @@ mod tests {
];

let left_col_stats = vec![
create_column_stats(Some(0), Some(100), Some(100)),
create_column_stats(Some(0), Some(500), Some(500)),
create_column_stats(Some(1000), Some(10000), None),
create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
];

let right_col_stats = vec![
create_column_stats(Some(0), Some(100), Some(50)),
create_column_stats(Some(0), Some(2000), Some(2500)),
create_column_stats(Some(0), Some(100), None),
create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
];

for (join_type, expected_num_rows) in cases {
Expand Down Expand Up @@ -1970,15 +1993,15 @@ mod tests {
// Join on a=c, x=y (ignores b/d) where x and y does not intersect

let left_col_stats = vec![
create_column_stats(Some(0), Some(100), Some(100)),
create_column_stats(Some(0), Some(500), Some(500)),
create_column_stats(Some(1000), Some(10000), None),
create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent),
create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent),
create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent),
];

let right_col_stats = vec![
create_column_stats(Some(0), Some(100), Some(50)),
create_column_stats(Some(0), Some(2000), Some(2500)),
create_column_stats(Some(0), Some(100), None),
create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent),
create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent),
create_column_stats(Inexact(0), Inexact(100), Absent, Absent),
];

let join_on = vec![
Expand Down

0 comments on commit 13def2c

Please sign in to comment.