diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index 66d1380e09c3..2c0e73080719 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -25,6 +25,7 @@ use crate::physical_optimizer::test_utils::{ }; use crate::physical_optimizer::test_utils::{parquet_exec_with_sort, trim_plan_display}; +use crate::sql::ExplainNormalizer; use arrow::compute::SortOptions; use datafusion::config::ConfigOptions; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -32,9 +33,13 @@ use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{CsvSource, FileScanConfig, ParquetSource}; use datafusion::datasource::source::DataSourceExec; +use datafusion::execution::SessionStateBuilder; +use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; +use datafusion::prelude::SessionContext; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; +use datafusion_execution::config::SessionConfig; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; use datafusion_physical_expr::PhysicalExpr; @@ -42,9 +47,13 @@ use datafusion_physical_expr::{ expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr, }; use datafusion_physical_expr_common::sort_expr::LexRequirement; +use datafusion_physical_optimizer::coalesce_batches::CoalesceBatches; use datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; +use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; use datafusion_physical_optimizer::output_requirements::OutputRequirements; +use datafusion_physical_optimizer::projection_pushdown::ProjectionPushdown; +use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -62,6 +71,7 @@ use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::ExecutionPlanProperties; use datafusion_physical_plan::PlanProperties; use datafusion_physical_plan::{displayable, DisplayAs, DisplayFormatType, Statistics}; +use futures::StreamExt; /// Models operators like BoundedWindowExec that require an input /// ordering but is easy to construct @@ -3154,3 +3164,94 @@ fn optimize_away_unnecessary_repartition2() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn apply_enforce_distribution_multiple_times() -> Result<()> { + // Create a configuration + let config = SessionConfig::new(); + let ctx = SessionContext::new_with_config(config); + let testdata = datafusion::test_util::arrow_test_data(); + let csv_file = format!("{testdata}/csv/aggregate_test_100.csv"); + // Create table schema and data + let sql = format!( + "CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 BIGINT UNSIGNED NOT NULL, + c10 VARCHAR NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL + ) + STORED AS CSV + LOCATION '{csv_file}' + OPTIONS ('format.has_header' 'true')" + ); + + ctx.sql(sql.as_str()).await?; + + let df = ctx.sql("SELECT * FROM(SELECT * FROM aggregate_test_100 UNION ALL SELECT * FROM aggregate_test_100) ORDER BY c13 LIMIT 5").await?; + let logical_plan = df.logical_plan().clone(); + let analyzed_logical_plan = ctx.state().analyzer().execute_and_check( + logical_plan, + ctx.state().config_options(), + |_, _| (), + )?; + let optimized_logical_plan = ctx.state().optimizer().optimize( + analyzed_logical_plan, + &ctx.state(), + |_, _| (), + )?; + + let planner = DefaultPhysicalPlanner::default(); + let session_state = SessionStateBuilder::new() + .with_config(ctx.copied_config()) + .with_default_features() + // The second `EnforceDistribution` should be run with `OutputRequirements` to reproduce the bug. + .with_physical_optimizer_rule(Arc::new(OutputRequirements::new_add_mode())) + .with_physical_optimizer_rule(Arc::new(EnforceDistribution::new())) // -- Add enforce distribution rule again + .with_physical_optimizer_rule(Arc::new(OutputRequirements::new_remove_mode())) + .build(); + let optimized_physical_plan = planner + .create_physical_plan(&optimized_logical_plan, &session_state) + .await?; + + let normalizer = ExplainNormalizer::new(); + let actual = format!( + "{}", + displayable(optimized_physical_plan.as_ref()).indent(true) + ) + .trim() + .lines() + // normalize paths + .map(|s| normalizer.normalize(s)) + .collect::>(); + // Test the optimized plan is correct (after twice `EnforceDistribution`) + // The `fetch` is maintained after the second `EnforceDistribution` + let expected = vec![ + "SortExec: TopK(fetch=5), expr=[c13@12 ASC NULLS LAST], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " UnionExec", + " SortExec: TopK(fetch=5), expr=[c13@12 ASC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], file_type=csv, has_header=true", + " SortExec: TopK(fetch=5), expr=[c13@12 ASC NULLS LAST], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], file_type=csv, has_header=true", + ]; + assert_eq!( + expected, actual, + "expected:\n{expected:#?}\nactual:\n\n{actual:#?}\n" + ); + + let mut results = optimized_physical_plan.execute(0, ctx.task_ctx().clone())?; + + let batch = results.next().await.unwrap()?; + // Without the fix of https://github.com/apache/datafusion/pull/14207, the number of rows will be 10 + assert_eq!(batch.num_rows(), 5); + Ok(()) +} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 03c4ad7c013e..69412a8be903 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -243,7 +243,7 @@ pub struct ExplainNormalizer { } impl ExplainNormalizer { - fn new() -> Self { + pub(crate) fn new() -> Self { let mut replacements = vec![]; let mut push_path = |path: PathBuf, key: &str| { @@ -266,7 +266,7 @@ impl ExplainNormalizer { Self { replacements } } - fn normalize(&self, s: impl Into) -> String { + pub(crate) fn normalize(&self, s: impl Into) -> String { let mut s = s.into(); for (from, to) in &self.replacements { s = s.replace(from, to); diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 5e76edad1f56..54507e8fa1fc 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -21,7 +21,8 @@ //! according to the configuration), this rule increases partition counts in //! the physical plan. -use std::fmt::Debug; +use std::fmt; +use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; use crate::optimizer::PhysicalOptimizerRule; @@ -862,7 +863,11 @@ fn add_roundrobin_on_top( let new_plan = Arc::new(repartition) as _; - Ok(DistributionContext::new(new_plan, true, vec![input])) + Ok(DistributionContext::new( + new_plan, + DistributionData::new(true), + vec![input], + )) } else { // Partition is not helpful, we already have desired number of partitions. Ok(input) @@ -920,7 +925,11 @@ fn add_hash_on_top( .with_preserve_order(); let plan = Arc::new(repartition) as _; - return Ok(DistributionContext::new(plan, true, vec![input])); + return Ok(DistributionContext::new( + plan, + DistributionData::new(true), + vec![input], + )); } Ok(input) @@ -932,12 +941,16 @@ fn add_hash_on_top( /// # Arguments /// /// * `input`: Current node. +/// * `fetch`: Possible fetch value /// /// # Returns /// /// Updated node with an execution plan, where desired single /// distribution is satisfied by adding [`SortPreservingMergeExec`]. -fn add_spm_on_top(input: DistributionContext) -> DistributionContext { +fn add_spm_on_top( + input: DistributionContext, + fetch: &mut Option, +) -> DistributionContext { // Add SortPreservingMerge only when partition count is larger than 1. if input.plan.output_partitioning().partition_count() > 1 { // When there is an existing ordering, we preserve ordering @@ -949,19 +962,22 @@ fn add_spm_on_top(input: DistributionContext) -> DistributionContext { let should_preserve_ordering = input.plan.output_ordering().is_some(); let new_plan = if should_preserve_ordering { - Arc::new(SortPreservingMergeExec::new( - input - .plan - .output_ordering() - .unwrap_or(&LexOrdering::default()) - .clone(), - Arc::clone(&input.plan), - )) as _ + Arc::new( + SortPreservingMergeExec::new( + input + .plan + .output_ordering() + .unwrap_or(&LexOrdering::default()) + .clone(), + Arc::clone(&input.plan), + ) + .with_fetch(fetch.take()), + ) as _ } else { Arc::new(CoalescePartitionsExec::new(Arc::clone(&input.plan))) as _ }; - DistributionContext::new(new_plan, true, vec![input]) + DistributionContext::new(new_plan, DistributionData::new(true), vec![input]) } else { input } @@ -987,16 +1003,22 @@ fn add_spm_on_top(input: DistributionContext) -> DistributionContext { fn remove_dist_changing_operators( mut distribution_context: DistributionContext, ) -> Result { + let mut fetch = None; while is_repartition(&distribution_context.plan) || is_coalesce_partitions(&distribution_context.plan) || is_sort_preserving_merge(&distribution_context.plan) { + if is_sort_preserving_merge(&distribution_context.plan) { + if let Some(child_fetch) = distribution_context.plan.fetch() { + fetch = Some(fetch.map_or(child_fetch, |f: usize| f.min(child_fetch))); + } + } // All of above operators have a single child. First child is only child. // Remove any distribution changing operators at the beginning: distribution_context = distribution_context.children.swap_remove(0); // Note that they will be re-inserted later on if necessary or helpful. } - + distribution_context.data.fetch = fetch; Ok(distribution_context) } @@ -1021,21 +1043,25 @@ fn remove_dist_changing_operators( fn replace_order_preserving_variants( mut context: DistributionContext, ) -> Result { - context.children = context - .children - .into_iter() - .map(|child| { - if child.data { - replace_order_preserving_variants(child) - } else { - Ok(child) - } - }) - .collect::>>()?; + let mut children = vec![]; + let mut fetch = None; + for child in context.children.into_iter() { + if child.data.has_dist_changing { + let mut child = replace_order_preserving_variants(child)?; + fetch = child.data.fetch.take(); + children.push(child); + } else { + children.push(child); + } + } + context.children = children; if is_sort_preserving_merge(&context.plan) { + // Keep the fetch value of the SortPreservingMerge operator, maybe it will be used later. + let fetch = context.plan.fetch(); let child_plan = Arc::clone(&context.children[0].plan); context.plan = Arc::new(CoalescePartitionsExec::new(child_plan)); + context.data.fetch = fetch; return Ok(context); } else if let Some(repartition) = context.plan.as_any().downcast_ref::() @@ -1049,6 +1075,7 @@ fn replace_order_preserving_variants( } } + context.data.fetch = fetch; context.update_plan_from_children() } @@ -1188,9 +1215,10 @@ pub fn ensure_distribution( // Remove unnecessary repartition from the physical plan if any let DistributionContext { mut plan, - data, + mut data, children, } = remove_dist_changing_operators(dist_context)?; + let mut fetch = data.fetch.take(); if let Some(exec) = plan.as_any().downcast_ref::() { if let Some(updated_window) = get_best_fitting_window( @@ -1255,7 +1283,7 @@ pub fn ensure_distribution( // Satisfy the distribution requirement if it is unmet. match &requirement { Distribution::SinglePartition => { - child = add_spm_on_top(child); + child = add_spm_on_top(child, &mut fetch); } Distribution::HashPartitioned(exprs) => { if add_roundrobin { @@ -1288,9 +1316,10 @@ pub fn ensure_distribution( .equivalence_properties() .ordering_satisfy_requirement(&required_input_ordering); if (!ordering_satisfied || !order_preserving_variants_desirable) - && child.data + && child.data.has_dist_changing { child = replace_order_preserving_variants(child)?; + let fetch = child.data.fetch.take(); // If ordering requirements were satisfied before repartitioning, // make sure ordering requirements are still satisfied after. if ordering_satisfied { @@ -1298,12 +1327,12 @@ pub fn ensure_distribution( child = add_sort_above_with_check( child, required_input_ordering.clone(), - None, + fetch, ); } } // Stop tracking distribution changing operators - child.data = false; + child.data.has_dist_changing = false; } else { // no ordering requirement match requirement { @@ -1361,22 +1390,67 @@ pub fn ensure_distribution( } else { plan.with_new_children(children_plans)? }; + let mut optimized_distribution_ctx = + DistributionContext::new(Arc::clone(&plan), data.clone(), children); + + // If `fetch` was not consumed, it means that there was `SortPreservingMergeExec` with fetch before + // It was removed by `remove_dist_changing_operators` + // and we need to add it back. + if fetch.is_some() { + let plan = Arc::new( + SortPreservingMergeExec::new( + plan.output_ordering() + .unwrap_or(&LexOrdering::default()) + .clone(), + plan, + ) + .with_fetch(fetch.take()), + ); + optimized_distribution_ctx = + DistributionContext::new(plan, data, vec![optimized_distribution_ctx]); + } + + Ok(Transformed::yes(optimized_distribution_ctx)) +} - Ok(Transformed::yes(DistributionContext::new( - plan, data, children, - ))) +/// Distribution context that tracks distribution changing operators and fetch limits +#[derive(Debug, Clone, Default)] +pub struct DistributionData { + /// Whether this node contains distribution changing operators + pub has_dist_changing: bool, + /// /// Limit which must be applied to any sort preserving merge that is created + pub fetch: Option, +} + +impl DistributionData { + fn new(has_dist_changing: bool) -> Self { + Self { + has_dist_changing, + fetch: None, + } + } +} + +impl Display for DistributionData { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "(has_dist_changing: {}, fetch: {:?})", + self.has_dist_changing, self.fetch + ) + } } /// Keeps track of distribution changing operators (like `RepartitionExec`, /// `SortPreservingMergeExec`, `CoalescePartitionsExec`) and their ancestors. /// Using this information, we can optimize distribution of the plan if/when /// necessary. -pub type DistributionContext = PlanContext; +pub type DistributionContext = PlanContext; fn update_children(mut dist_context: DistributionContext) -> Result { for child_context in dist_context.children.iter_mut() { let child_plan_any = child_context.plan.as_any(); - child_context.data = + child_context.data.has_dist_changing = if let Some(repartition) = child_plan_any.downcast_ref::() { !matches!( repartition.partitioning(), @@ -1386,14 +1460,14 @@ fn update_children(mut dist_context: DistributionContext) -> Result() || child_plan_any.is::() || child_context.plan.children().is_empty() - || child_context.children[0].data + || child_context.children[0].data.has_dist_changing || child_context .plan .required_input_distribution() .iter() .zip(child_context.children.iter()) .any(|(required_dist, child_context)| { - child_context.data + child_context.data.has_dist_changing && matches!( required_dist, Distribution::UnspecifiedDistribution @@ -1402,7 +1476,7 @@ fn update_children(mut dist_context: DistributionContext) -> Result