diff --git a/datafusion/core/src/physical_optimizer/sort_enforcement.rs b/datafusion/core/src/physical_optimizer/sort_enforcement.rs index 52463b4bdc09..2d785e920a26 100644 --- a/datafusion/core/src/physical_optimizer/sort_enforcement.rs +++ b/datafusion/core/src/physical_optimizer/sort_enforcement.rs @@ -589,71 +589,84 @@ mod tests { Ok(()) } + /// Runs the sort enforcement optimizer and asserts the plan + /// against the original and expected plans + /// + /// `$EXPECTED_PLAN_LINES`: input plan + /// `$EXPECTED_OPTIMIZED_PLAN_LINES`: optimized plan + /// `$PLAN`: the plan to optimized + /// + macro_rules! assert_optimized { + ($EXPECTED_PLAN_LINES: expr, $EXPECTED_OPTIMIZED_PLAN_LINES: expr, $PLAN: expr) => { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let physical_plan = $PLAN; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + + let expected_plan_lines: Vec<&str> = $EXPECTED_PLAN_LINES + .iter().map(|s| *s).collect(); + + assert_eq!( + expected_plan_lines, actual, + "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected_optimized_lines: Vec<&str> = $EXPECTED_OPTIMIZED_PLAN_LINES + .iter().map(|s| *s).collect(); + + // Run the actual optimizer + let optimized_physical_plan = + EnforceSorting::new().optimize(physical_plan, state.config_options())?; + + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected_optimized_lines, actual, + "\n**Optimized Plan Mismatch\n\nexpected:\n\n{expected_optimized_lines:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + }; + } + #[tokio::test] async fn test_remove_unnecessary_sort() -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); let schema = create_test_schema()?; - let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) - as Arc; - let sort_exprs = vec![PhysicalSortExpr { - expr: col("non_nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }]; - let sort_exec = Arc::new(SortExec::try_new(sort_exprs, source, None)?) - as Arc; - let sort_exprs = vec![PhysicalSortExpr { - expr: col("nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }]; - let physical_plan = Arc::new(SortExec::try_new(sort_exprs, sort_exec, None)?) - as Arc; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let expected = { - vec![ - "SortExec: [nullable_col@0 ASC]", - " SortExec: [non_nullable_col@1 ASC]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - let actual_len = actual.len(); - let actual_trim_last = &actual[..actual_len - 1]; - assert_eq!( - expected, actual_trim_last, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let optimized_physical_plan = - EnforceSorting::new().optimize(physical_plan, state.config_options())?; - let formatted = displayable(optimized_physical_plan.as_ref()) - .indent() - .to_string(); - let expected = { vec!["SortExec: [nullable_col@0 ASC]"] }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - let actual_len = actual.len(); - let actual_trim_last = &actual[..actual_len - 1]; - assert_eq!( - expected, actual_trim_last, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + let source = memory_exec(&schema); + let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); + let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], input); + + let expected_input = vec![ + "SortExec: [nullable_col@0 ASC]", + " SortExec: [non_nullable_col@1 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + let expected_optimized = vec![ + "SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } #[tokio::test] async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); let schema = create_test_schema()?; - let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) - as Arc; - let sort_exprs = vec![PhysicalSortExpr { - expr: col("non_nullable_col", source.schema().as_ref()).unwrap(), - options: SortOptions { + let source = memory_exec(&schema); + + let sort_exprs = vec![sort_expr_options( + "non_nullable_col", + &source.schema(), + SortOptions { descending: true, nulls_first: true, }, - }]; - let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) - as Arc; + )]; + let sort = sort_exec(sort_exprs.clone(), source); + let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Count), @@ -664,32 +677,33 @@ mod tests { Arc::new(WindowFrame::new(true)), schema.as_ref(), )?], - sort_exec.clone(), - sort_exec.schema(), + sort.clone(), + sort.schema(), vec![], Some(sort_exprs), )?) as Arc; - let sort_exprs = vec![PhysicalSortExpr { - expr: col("non_nullable_col", window_agg_exec.schema().as_ref()).unwrap(), - options: SortOptions { + + let sort_exprs = vec![sort_expr_options( + "non_nullable_col", + &window_agg_exec.schema(), + SortOptions { descending: false, nulls_first: false, }, - }]; - let sort_exec = Arc::new(SortExec::try_new( - sort_exprs.clone(), - window_agg_exec, - None, - )?) as Arc; + )]; + + let sort = sort_exec(sort_exprs.clone(), window_agg_exec); + // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before - let filter_exec = Arc::new(FilterExec::try_new( + let filter = filter_exec( Arc::new(NotExpr::new( col("non_nullable_col", schema.as_ref()).unwrap(), )), - sort_exec, - )?) as Arc; + sort, + ); + // let filter_exec = sort_exec; - let window_agg_exec = Arc::new(WindowAggExec::try_new( + let physical_plan = Arc::new(WindowAggExec::try_new( vec![create_window_expr( &WindowFunction::AggregateFunction(AggregateFunction::Count), "count".to_owned(), @@ -699,214 +713,147 @@ mod tests { Arc::new(WindowFrame::new(true)), schema.as_ref(), )?], - filter_exec.clone(), - filter_exec.schema(), + filter.clone(), + filter.schema(), vec![], Some(sort_exprs), )?) as Arc; - let physical_plan = window_agg_exec; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let expected = { - vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", - " FilterExec: NOT non_nullable_col@1", - " SortExec: [non_nullable_col@1 ASC NULLS LAST]", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", - " SortExec: [non_nullable_col@1 DESC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let optimized_physical_plan = - EnforceSorting::new().optimize(physical_plan, state.config_options())?; - let formatted = displayable(optimized_physical_plan.as_ref()) - .indent() - .to_string(); - let expected = { - vec![ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", - " FilterExec: NOT non_nullable_col@1", - " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", - " SortExec: [non_nullable_col@1 DESC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + + let expected_input = vec![ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " FilterExec: NOT non_nullable_col@1", + " SortExec: [non_nullable_col@1 ASC NULLS LAST]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " SortExec: [non_nullable_col@1 DESC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + + let expected_optimized = vec![ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", + " FilterExec: NOT non_nullable_col@1", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " SortExec: [non_nullable_col@1 DESC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } #[tokio::test] async fn test_add_required_sort() -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); let schema = create_test_schema()?; - let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) - as Arc; - let sort_exprs = vec![PhysicalSortExpr { - expr: col("nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }]; - let physical_plan = Arc::new(SortPreservingMergeExec::new(sort_exprs, source)) - as Arc; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let expected = { vec!["SortPreservingMergeExec: [nullable_col@0 ASC]"] }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - let actual_len = actual.len(); - let actual_trim_last = &actual[..actual_len - 1]; - assert_eq!( - expected, actual_trim_last, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let optimized_physical_plan = - EnforceSorting::new().optimize(physical_plan, state.config_options())?; - let formatted = displayable(optimized_physical_plan.as_ref()) - .indent() - .to_string(); - let expected = { - vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: [nullable_col@0 ASC]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - let actual_len = actual.len(); - let actual_trim_last = &actual[..actual_len - 1]; - assert_eq!( - expected, actual_trim_last, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + let source = memory_exec(&schema); + + let sort_exprs = vec![sort_expr("nullable_col", &schema)]; + + let physical_plan = sort_preserving_merge_exec(sort_exprs, source); + + let expected_input = vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + let expected_optimized = vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } #[tokio::test] async fn test_remove_unnecessary_sort1() -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); let schema = create_test_schema()?; - let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) - as Arc; - let sort_exprs = vec![PhysicalSortExpr { - expr: col("nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }]; - let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) - as Arc; - let sort_preserving_merge_exec = - Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) - as Arc; - let sort_exprs = vec![PhysicalSortExpr { - expr: col("nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }]; - let sort_exec = Arc::new(SortExec::try_new( - sort_exprs.clone(), - sort_preserving_merge_exec, - None, - )?) as Arc; - let sort_preserving_merge_exec = - Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) - as Arc; - let physical_plan = sort_preserving_merge_exec; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let expected = { - vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: [nullable_col@0 ASC]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: [nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let optimized_physical_plan = - EnforceSorting::new().optimize(physical_plan, state.config_options())?; - let formatted = displayable(optimized_physical_plan.as_ref()) - .indent() - .to_string(); - let expected = { - vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: [nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + let source = memory_exec(&schema); + let sort_exprs = vec![sort_expr("nullable_col", &schema)]; + let sort = sort_exec(sort_exprs.clone(), source); + let spm = sort_preserving_merge_exec(sort_exprs, sort); + + let sort_exprs = vec![sort_expr("nullable_col", &schema)]; + let sort = sort_exec(sort_exprs.clone(), spm); + let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); + let expected_input = vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + let expected_optimized = vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } #[tokio::test] async fn test_change_wrong_sorting() -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); let schema = create_test_schema()?; - let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) - as Arc; + let source = memory_exec(&schema); let sort_exprs = vec![ - PhysicalSortExpr { - expr: col("nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: col("non_nullable_col", schema.as_ref()).unwrap(), - options: SortOptions::default(), - }, + sort_expr("nullable_col", &schema), + sort_expr("non_nullable_col", &schema), ]; - let sort_exec = Arc::new(SortExec::try_new( - vec![sort_exprs[0].clone()], - source, - None, - )?) as Arc; - let sort_preserving_merge_exec = - Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) - as Arc; - let physical_plan = sort_preserving_merge_exec; - let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - let expected = { - vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: [nullable_col@0 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); - let optimized_physical_plan = - EnforceSorting::new().optimize(physical_plan, state.config_options())?; - let formatted = displayable(optimized_physical_plan.as_ref()) - .indent() - .to_string(); - let expected = { - vec![ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " MemoryExec: partitions=0, partition_sizes=[]", - ] - }; - let actual: Vec<&str> = formatted.trim().lines().collect(); - assert_eq!( - expected, actual, - "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" - ); + let sort = sort_exec(vec![sort_exprs[0].clone()], source); + let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); + let expected_input = vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + let expected_optimized = vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan); Ok(()) } + + /// make PhysicalSortExpr with default options + fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { + sort_expr_options(name, schema, SortOptions::default()) + } + + /// PhysicalSortExpr with specified options + fn sort_expr_options( + name: &str, + schema: &Schema, + options: SortOptions, + ) -> PhysicalSortExpr { + PhysicalSortExpr { + expr: col(name, schema).unwrap(), + options, + } + } + + fn memory_exec(schema: &SchemaRef) -> Arc { + Arc::new(MemoryExec::try_new(&[], schema.clone(), None).unwrap()) + } + + fn sort_exec( + sort_exprs: impl IntoIterator, + input: Arc, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortExec::try_new(sort_exprs, input, None).unwrap()) + } + + fn sort_preserving_merge_exec( + sort_exprs: impl IntoIterator, + input: Arc, + ) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) + } + + fn filter_exec( + predicate: Arc, + input: Arc, + ) -> Arc { + Arc::new(FilterExec::try_new(predicate, input).unwrap()) + } }