From 13def2c1efa6f6eb93fdabcd3c4d4e29c2a7e413 Mon Sep 17 00:00:00 2001 From: Marko Grujic Date: Fri, 26 Jan 2024 14:57:12 +0100 Subject: [PATCH] Add test hitting the former overflow panic --- datafusion/common/src/stats.rs | 7 +- datafusion/physical-plan/src/joins/utils.rs | 197 +++++++++++--------- 2 files changed, 114 insertions(+), 90 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 7ad8992ca9ae..a10e05a55c64 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -48,14 +48,15 @@ impl Precision { /// Transform the value in this [`Precision`] object, if one exists, using /// the given function. Preserves the exactness state. - pub fn map(self, f: F) -> Precision + pub fn map(self, f: F) -> Precision where - F: Fn(T) -> T, + F: Fn(T) -> U, + U: Debug + Clone + PartialEq + Eq + PartialOrd, { match self { Precision::Exact(val) => Precision::Exact(f(val)), Precision::Inexact(val) => Precision::Inexact(f(val)), - _ => self, + _ => Precision::::Absent, } } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 9b65496187a5..cd987ab40d45 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1473,6 +1473,7 @@ mod tests { use arrow::error::{ArrowError, Result as ArrowResult}; use arrow_schema::SortOptions; + use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> { @@ -1640,25 +1641,26 @@ mod tests { } fn create_column_stats( - min: Option, - max: Option, - distinct_count: Option, + min: Precision, + max: Precision, + distinct_count: Precision, + null_count: Precision, ) -> ColumnStatistics { ColumnStatistics { - distinct_count: distinct_count - .map(Precision::Inexact) - .unwrap_or(Precision::Absent), - min_value: min - .map(|size| Precision::Inexact(ScalarValue::from(size))) - .unwrap_or(Precision::Absent), - max_value: max - .map(|size| Precision::Inexact(ScalarValue::from(size))) - .unwrap_or(Precision::Absent), - ..Default::default() + distinct_count, + min_value: min.map(ScalarValue::from), + max_value: max.map(ScalarValue::from), + null_count, } } - type PartialStats = (usize, Option, Option, Option); + type PartialStats = ( + usize, + Precision, + Precision, + Precision, + Precision, + ); // This is mainly for validating the all edge cases of the estimation, but // more advanced (and real world test cases) are below where we need some control @@ -1675,133 +1677,156 @@ mod tests { // // distinct(left) == NaN, distinct(right) == NaN ( - (10, Some(1), Some(10), None), - (10, Some(1), Some(10), None), - Some(Precision::Inexact(10)), + (10, Inexact(1), Inexact(10), Absent, Absent), + (10, Inexact(1), Inexact(10), Absent, Absent), + Some(Inexact(10)), ), // range(left) > range(right) ( - (10, Some(6), Some(10), None), - (10, Some(8), Some(10), None), - Some(Precision::Inexact(20)), + (10, Inexact(6), Inexact(10), Absent, Absent), + (10, Inexact(8), Inexact(10), Absent, Absent), + Some(Inexact(20)), ), // range(right) > range(left) ( - (10, Some(8), Some(10), None), - (10, Some(6), Some(10), None), - Some(Precision::Inexact(20)), + (10, Inexact(8), Inexact(10), Absent, Absent), + (10, Inexact(6), Inexact(10), Absent, Absent), + Some(Inexact(20)), ), // range(left) > len(left), range(right) > len(right) ( - (10, Some(1), Some(15), None), - (20, Some(1), Some(40), None), - Some(Precision::Inexact(10)), + (10, Inexact(1), Inexact(15), Absent, Absent), + (20, Inexact(1), Inexact(40), Absent, Absent), + Some(Inexact(10)), ), // When we have distinct count. ( - (10, Some(1), Some(10), Some(10)), - (10, Some(1), Some(10), Some(10)), - Some(Precision::Inexact(10)), + (10, Inexact(1), Inexact(10), Inexact(10), Absent), + (10, Inexact(1), Inexact(10), Inexact(10), Absent), + Some(Inexact(10)), ), // distinct(left) > distinct(right) ( - (10, Some(1), Some(10), Some(5)), - (10, Some(1), Some(10), Some(2)), - Some(Precision::Inexact(20)), + (10, Inexact(1), Inexact(10), Inexact(5), Absent), + (10, Inexact(1), Inexact(10), Inexact(2), Absent), + Some(Inexact(20)), ), // distinct(right) > distinct(left) ( - (10, Some(1), Some(10), Some(2)), - (10, Some(1), Some(10), Some(5)), - Some(Precision::Inexact(20)), + (10, Inexact(1), Inexact(10), Inexact(2), Absent), + (10, Inexact(1), Inexact(10), Inexact(5), Absent), + Some(Inexact(20)), ), // min(left) < 0 (range(left) > range(right)) ( - (10, Some(-5), Some(5), None), - (10, Some(1), Some(5), None), - Some(Precision::Inexact(10)), + (10, Inexact(-5), Inexact(5), Absent, Absent), + (10, Inexact(1), Inexact(5), Absent, Absent), + Some(Inexact(10)), ), // min(right) < 0, max(right) < 0 (range(right) > range(left)) ( - (10, Some(-25), Some(-20), None), - (10, Some(-25), Some(-15), None), - Some(Precision::Inexact(10)), + (10, Inexact(-25), Inexact(-20), Absent, Absent), + (10, Inexact(-25), Inexact(-15), Absent, Absent), + Some(Inexact(10)), ), // range(left) < 0, range(right) >= 0 // (there isn't a case where both left and right ranges are negative // so one of them is always going to work, this just proves negative // ranges with bigger absolute values are not are not accidentally used). ( - (10, Some(-10), Some(0), None), - (10, Some(0), Some(10), Some(5)), - Some(Precision::Inexact(10)), + (10, Inexact(-10), Inexact(0), Absent, Absent), + (10, Inexact(0), Inexact(10), Inexact(5), Absent), + Some(Inexact(10)), ), // range(left) = 1, range(right) = 1 ( - (10, Some(1), Some(1), None), - (10, Some(1), Some(1), None), - Some(Precision::Inexact(100)), + (10, Inexact(1), Inexact(1), Absent, Absent), + (10, Inexact(1), Inexact(1), Absent, Absent), + Some(Inexact(100)), ), // // Edge cases // ========== // // No column level stats. - ((10, None, None, None), (10, None, None, None), None), + ( + (10, Absent, Absent, Absent, Absent), + (10, Absent, Absent, Absent, Absent), + None, + ), // No min or max (or both). - ((10, None, None, Some(3)), (10, None, None, Some(3)), None), ( - (10, Some(2), None, Some(3)), - (10, None, Some(5), Some(3)), + (10, Absent, Absent, Inexact(3), Absent), + (10, Absent, Absent, Inexact(3), Absent), + None, + ), + ( + (10, Inexact(2), Absent, Inexact(3), Absent), + (10, Absent, Inexact(5), Inexact(3), Absent), None, ), ( - (10, None, Some(3), Some(3)), - (10, Some(1), None, Some(3)), + (10, Absent, Inexact(3), Inexact(3), Absent), + (10, Inexact(1), Absent, Inexact(3), Absent), + None, + ), + ( + (10, Absent, Inexact(3), Absent, Absent), + (10, Inexact(1), Absent, Absent, Absent), None, ), - ((10, None, Some(3), None), (10, Some(1), None, None), None), // Non overlapping min/max (when exact=False). ( - (10, Some(0), Some(10), None), - (10, Some(11), Some(20), None), - Some(Precision::Inexact(0)), + (10, Inexact(0), Inexact(10), Absent, Absent), + (10, Inexact(11), Inexact(20), Absent, Absent), + Some(Inexact(0)), ), ( - (10, Some(11), Some(20), None), - (10, Some(0), Some(10), None), - Some(Precision::Inexact(0)), + (10, Inexact(11), Inexact(20), Absent, Absent), + (10, Inexact(0), Inexact(10), Absent, Absent), + Some(Inexact(0)), ), // distinct(left) = 0, distinct(right) = 0 ( - (10, Some(1), Some(10), Some(0)), - (10, Some(1), Some(10), Some(0)), + (10, Inexact(1), Inexact(10), Inexact(0), Absent), + (10, Inexact(1), Inexact(10), Inexact(0), Absent), None, ), + // Inexact row count < exact null count with absent distinct count + ( + (0, Inexact(1), Inexact(10), Absent, Exact(5)), + (10, Inexact(1), Inexact(10), Absent, Absent), + Some(Inexact(0)), + ), ]; for (left_info, right_info, expected_cardinality) in cases { let left_num_rows = left_info.0; - let left_col_stats = - vec![create_column_stats(left_info.1, left_info.2, left_info.3)]; + let left_col_stats = vec![create_column_stats( + left_info.1, + left_info.2, + left_info.3, + left_info.4, + )]; let right_num_rows = right_info.0; let right_col_stats = vec![create_column_stats( right_info.1, right_info.2, right_info.3, + right_info.4, )]; assert_eq!( estimate_inner_join_cardinality( Statistics { - num_rows: Precision::Inexact(left_num_rows), - total_byte_size: Precision::Absent, + num_rows: Inexact(left_num_rows), + total_byte_size: Absent, column_statistics: left_col_stats.clone(), }, Statistics { - num_rows: Precision::Inexact(right_num_rows), - total_byte_size: Precision::Absent, + num_rows: Inexact(right_num_rows), + total_byte_size: Absent, column_statistics: right_col_stats.clone(), }, ), @@ -1819,9 +1844,7 @@ mod tests { ); assert_eq!( - partial_join_stats - .clone() - .map(|s| Precision::Inexact(s.num_rows)), + partial_join_stats.clone().map(|s| Inexact(s.num_rows)), expected_cardinality.clone() ); assert_eq!( @@ -1837,13 +1860,13 @@ mod tests { #[test] fn test_inner_join_cardinality_multiple_column() -> Result<()> { let left_col_stats = vec![ - create_column_stats(Some(0), Some(100), Some(100)), - create_column_stats(Some(100), Some(500), Some(150)), + create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent), + create_column_stats(Inexact(100), Inexact(500), Inexact(150), Absent), ]; let right_col_stats = vec![ - create_column_stats(Some(0), Some(100), Some(50)), - create_column_stats(Some(100), Some(500), Some(200)), + create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent), + create_column_stats(Inexact(100), Inexact(500), Inexact(200), Absent), ]; // We have statistics about 4 columns, where the highest distinct @@ -1921,15 +1944,15 @@ mod tests { ]; let left_col_stats = vec![ - create_column_stats(Some(0), Some(100), Some(100)), - create_column_stats(Some(0), Some(500), Some(500)), - create_column_stats(Some(1000), Some(10000), None), + create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent), + create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent), + create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent), ]; let right_col_stats = vec![ - create_column_stats(Some(0), Some(100), Some(50)), - create_column_stats(Some(0), Some(2000), Some(2500)), - create_column_stats(Some(0), Some(100), None), + create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent), + create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent), + create_column_stats(Inexact(0), Inexact(100), Absent, Absent), ]; for (join_type, expected_num_rows) in cases { @@ -1970,15 +1993,15 @@ mod tests { // Join on a=c, x=y (ignores b/d) where x and y does not intersect let left_col_stats = vec![ - create_column_stats(Some(0), Some(100), Some(100)), - create_column_stats(Some(0), Some(500), Some(500)), - create_column_stats(Some(1000), Some(10000), None), + create_column_stats(Inexact(0), Inexact(100), Inexact(100), Absent), + create_column_stats(Inexact(0), Inexact(500), Inexact(500), Absent), + create_column_stats(Inexact(1000), Inexact(10000), Absent, Absent), ]; let right_col_stats = vec![ - create_column_stats(Some(0), Some(100), Some(50)), - create_column_stats(Some(0), Some(2000), Some(2500)), - create_column_stats(Some(0), Some(100), None), + create_column_stats(Inexact(0), Inexact(100), Inexact(50), Absent), + create_column_stats(Inexact(0), Inexact(2000), Inexact(2500), Absent), + create_column_stats(Inexact(0), Inexact(100), Absent, Absent), ]; let join_on = vec![