diff --git a/datafusion/core/src/physical_optimizer/join_pipeline_selection.rs b/datafusion/core/src/physical_optimizer/join_pipeline_selection.rs index 08e7d9001e80..08d87700ca15 100644 --- a/datafusion/core/src/physical_optimizer/join_pipeline_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_pipeline_selection.rs @@ -3,115 +3,130 @@ use std::sync::Arc; -use crate::physical_optimizer::utils::{is_hash_join, is_nested_loop_join, is_sort}; +use crate::datasource::physical_plan::is_plan_streaming; +use crate::physical_optimizer::join_selection::{ + statistical_join_selection_cross_join, statistical_join_selection_hash_join, + swap_join_according_to_unboundedness, swap_join_type, swap_reverting_projection, +}; +use crate::physical_optimizer::utils::{ + is_aggregate, is_cross_join, is_hash_join, is_nested_loop_join, is_sort, is_window, +}; +use crate::physical_plan::aggregates::{ + get_working_mode, AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; +use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; +use crate::physical_plan::joins::utils::{swap_filter, JoinFilter, JoinOn}; use crate::physical_plan::joins::{ - HashJoinExec, NestedLoopJoinExec, SlidingHashJoinExec, SlidingNestedLoopJoinExec, - SortMergeJoinExec, StreamJoinPartitionMode, + CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionedHashJoinExec, + SlidingHashJoinExec, SlidingNestedLoopJoinExec, SortMergeJoinExec, + StreamJoinPartitionMode, SymmetricHashJoinExec, }; +use crate::physical_plan::projection::ProjectionExec; +use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use arrow_schema::SortOptions; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; -use datafusion_common::{JoinType, Result}; +use datafusion_common::{internal_err, DataFusionError, JoinType, Result}; +use datafusion_physical_expr::expressions::{Column, LastValue}; use datafusion_physical_expr::utils::{ - get_indices_of_matching_sort_exprs_with_order_eq, + collect_columns, get_indices_of_matching_sort_exprs_with_order_eq, ordering_satisfy_requirement_concrete, }; -use datafusion_physical_expr::PhysicalSortRequirement; +use datafusion_physical_expr::{ + AggregateExpr, PhysicalSortExpr, PhysicalSortRequirement, +}; use datafusion_physical_plan::joins::prunability::{ is_filter_expr_prunable, separate_columns_of_filter_expression, }; -use datafusion_physical_plan::joins::utils::swap_join_type; +use datafusion_physical_plan::joins::utils::swap_join_on; use datafusion_physical_plan::joins::{ - swap_sliding_hash_join, swap_sliding_nested_loop_join, swap_sort_merge_join, + swap_sliding_hash_join, swap_sliding_nested_loop_join, +}; +use datafusion_physical_plan::windows::{ + get_best_fitting_window, BoundedWindowAggExec, WindowAggExec, }; use itertools::{iproduct, izip, Itertools}; -/// This object is used within the JoinSelection rule to track the closest -/// [`HashJoinExec`] or [`NestedLoopJoinExec`] descendant(s) for every child of a plan. +/// This function swaps the inputs of the given SMJ operator. +fn swap_sort_merge_join( + hash_join: &HashJoinExec, + keys: Vec<(Column, Column)>, + sort_options: Vec, +) -> Result> { + let left = hash_join.left(); + let right = hash_join.right(); + let swapped_join_type = swap_join_type(hash_join.join_type); + if matches!(swapped_join_type, JoinType::RightSemi) { + return Err(DataFusionError::Plan( + "RightSemi is not supported for SortMergeJoin".to_owned(), + )); + } + // Sort option will remain same since each tuple of keys from both side will have exactly same + // SortOptions. + let new_join = SortMergeJoinExec::try_new( + right.clone(), + left.clone(), + swap_join_on(&keys), + swapped_join_type, + sort_options, + hash_join.null_equals_null, + )?; + if matches!( + hash_join.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightAnti + ) { + Ok(Arc::new(new_join)) + } else { + // TODO avoid adding ProjectionExec again and again, only adding Final Projection + let proj = ProjectionExec::try_new( + swap_reverting_projection(&left.schema(), &right.schema()), + Arc::new(new_join), + )?; + Ok(Arc::new(proj)) + } +} + +/// Represents the current state of an execution plan in the context of query optimization or analysis. +/// +/// `PlanState` acts as a wrapper around a vector of execution plans (`plans`), offering utility methods +/// to work with these plans and their children, and to apply transformations. This structure is instrumental +/// in manipulating and understanding the flow of execution in the context of database query optimization. +/// +/// It also implements the `TreeNode` trait which provides methods for working with trees of nodes, allowing +/// for recursive operations on the execution plans and their children. #[derive(Debug, Clone)] -pub struct PlanWithCorrespondingHashJoin { - pub(crate) plan: Arc, - // For every child, we keep a subtree of `ExecutionPlan`s starting from the - // child until the `HashJoinExec`(s) that affect the output ordering of the - // child. If the child has no connection to any `HashJoinExec`, simply store - // `None` (and not a subtree). - hash_join_onwards: Vec>, +pub struct PlanState { + pub(crate) plans: Vec>, } -impl PlanWithCorrespondingHashJoin { +impl PlanState { + /// Creates a new `PlanState` instance from a given execution plan. + /// + /// # Parameters + /// - `plan`: The execution plan to be wrapped by this state. pub fn new(plan: Arc) -> Self { - let length = plan.children().len(); - PlanWithCorrespondingHashJoin { - plan, - hash_join_onwards: vec![None; length], - } - } - - pub fn new_from_children_nodes( - children_nodes: Vec, - parent_plan: Arc, - ) -> Result { - let children_plans = children_nodes - .iter() - .map(|item| item.plan.clone()) - .collect(); - let hash_join_onwards = children_nodes - .into_iter() - .enumerate() - .map(|(idx, item)| { - let plan = item.plan; - if plan.children().is_empty() { - // Plan has no children, there is nothing to propagate. - None - } else if (is_hash_join(&plan) || is_nested_loop_join(&plan)) - && item.hash_join_onwards.iter().all(|e| e.is_none()) - { - Some(MultipleExecTree::new(vec![plan], idx, vec![])) - } else { - let required_orderings = plan.required_input_ordering(); - let flags = plan.maintains_input_order(); - let children = - izip!(flags, item.hash_join_onwards, required_orderings) - .filter_map(|(maintains, element, required_ordering)| { - if (required_ordering.is_none() && maintains) - || is_hash_join(&plan) - || is_nested_loop_join(&plan) - || is_sort(&plan) - { - element - } else { - None - } - }) - .collect::>(); - if children.is_empty() { - None - } else { - Some(MultipleExecTree::new(vec![plan], idx, children)) - } - } - }) - .collect(); - let plan = with_new_children_if_necessary(parent_plan, children_plans)?.into(); - Ok(PlanWithCorrespondingHashJoin { - plan, - hash_join_onwards, - }) + PlanState { plans: vec![plan] } } - pub fn children(&self) -> Vec { - self.plan + /// Returns the children of the execution plan as a vector of `PlanState`. + /// + /// Each child represents a subsequent step or dependency in the execution flow. + pub fn children(&self) -> Vec { + self.plans[0] .children() .into_iter() - .map(|child| PlanWithCorrespondingHashJoin::new(child)) + .map(|child| PlanState::new(child)) .collect() } } -impl TreeNode for PlanWithCorrespondingHashJoin { +impl TreeNode for PlanState { fn apply_children(&self, op: &mut F) -> Result where F: FnMut(&Self) -> Result, @@ -127,6 +142,34 @@ impl TreeNode for PlanWithCorrespondingHashJoin { Ok(VisitRecursion::Continue) } + /// Transforms the children of the current execution plan using a provided transformation function. + /// + /// This method first retrieves the children of the current execution plan. If there are no children, + /// it returns the current plan state without any changes. Otherwise, it applies the given transformation + /// function to each child, creating a new set of possible execution plans. + /// + /// The method then constructs a cartesian product of all possible child plans and combines them with + /// the original plan to generate new execution plans that have the transformed children. This is particularly + /// useful in scenarios where multiple transformations or optimizations can be applied, and one wants to + /// explore all possible combinations. + /// + /// # Type Parameters + /// + /// * `F`: A closure type that takes a `PlanState` and returns a `Result`. This closure is used to + /// transform the children of the current execution plan. + /// + /// # Parameters + /// + /// * `transform`: The transformation function to be applied to each child of the execution plan. + /// + /// # Returns + /// + /// * A `Result` containing the transformed `PlanState`. + /// + /// # Errors + /// + /// Returns an error if the transformation fails on any of the children or if there's an issue combining the + /// original plan with the transformed children. fn map_children(self, transform: F) -> Result where F: FnMut(Self) -> Result, @@ -135,522 +178,1240 @@ impl TreeNode for PlanWithCorrespondingHashJoin { if children.is_empty() { Ok(self) } else { + // Transform the children nodes: let children_nodes = children .into_iter() .map(transform) .collect::>>()?; - PlanWithCorrespondingHashJoin::new_from_children_nodes( - children_nodes, - self.plan, - ) - } - } -} + // Get the cartesian product of all possible children: + let possible_children = children_nodes + .into_iter() + .map(|v| v.plans.into_iter()) + .multi_cartesian_product(); -/// This object implements a tree that we use while keeping track of paths -/// leading to [`HashJoinExec`]s. -#[derive(Debug, Clone)] -pub struct MultipleExecTree { - /// The `ExecutionPlan`s associated with this node. - pub plans: Vec>, - /// Child index of the plan in its parent. - pub idx: usize, - /// Children of the plan that would need updating if we remove leaf executors. - pub children: Vec, -} + // Combine the plans with the possible children: + let new_plans = iproduct!(self.plans.into_iter(), possible_children) + .map(|(plan, children)| { + let plan = with_new_children_if_necessary(plan, children)?.into(); + Ok(plan) + }) + .collect::>>()?; -impl MultipleExecTree { - /// Create new Exec tree. - pub fn new( - plans: Vec>, - idx: usize, - children: Vec, - ) -> Self { - MultipleExecTree { - plans, - idx, - children, + Ok(PlanState { plans: new_plans }) } } } -/// Creates a variety of join cases for a given subtree of plans within a local ordering constraint. +/// Examines the provided `PlanState` and selectively optimizes certain types of execution plans, +/// especially joins, to maintain or improve the performance of the query pipeline. +/// The function inspects each plan in the `PlanState` and applies relevant optimizations +/// based on the type of plan and its characteristics. /// -/// This function recursively traverses the provided `MultipleExecTree` (a tree structure of execution plans), -/// and generates new execution plans that are candidate to preserve required order. -/// -/// For each child there will be more than one possible plan that will be propagated to upper. -/// Imagine a parent plan that have 2 child. If obe child has 3 possible plans and other has 2, -/// we propagate 6 possible plans (cartesian product) to the top. -/// -/// The generation of new plans is guided by the `ConfigOptions` provided, which might contain specific rules -/// and preferences for the execution plans. +/// The primary focus of this function is to handle hash joins, cross joins, nested loop joins, +/// aggregates, sorts, windows, and other plans with inherent sorting requirements. /// /// # Arguments /// -/// * `hash_join_onwards`: A mutable reference to a `MultipleExecTree` object, which contains the current subtree -/// of execution plans that need to be processed. -/// -/// * `config_options`: A reference to a `ConfigOptions` object, which provides the configuration settings that guide -/// the creation of new execution plans. +/// * `requirements`: The `PlanState` containing a collection of execution plans to be optimized. +/// * `config_options`: Configuration options that might influence the optimization strategies. /// /// # Returns /// -/// This function returns a `Result` that contains a `Vec` of `Arc` objects if the execution -/// plans are successfully created. Each `ExecutionPlan` within the `Vec` is a unique execution plan that are candidate to preserve required order. -/// If any error occurs during the creation of execution plans, -/// the function returns an `Err`. -fn create_join_cases_local_preserve_order( - hash_join_onwards: &mut MultipleExecTree, +/// * A `Result` containing a `Transformed` variant indicating the outcome of the transformation. +/// - `Transformed::Yes`: Indicates successful transformation, and contains the new `PlanState` with optimized plans. +/// - `Transformed::No`: Would indicate that no transformation took place, but it's not returned by this function. +pub fn select_joins_to_preserve_pipeline( + requirements: PlanState, config_options: &ConfigOptions, -) -> Result>> { - // Clone all the plans in the given subtree: - let plans = hash_join_onwards.plans.clone(); - - // Iterate over the children of the subtree, update their plans recursively: - for item in &mut hash_join_onwards.children { - item.plans = create_join_cases_local_preserve_order(item, config_options)?; - } - - // If there is at least one plan... - if let Some(first_plan) = plans.get(0) { - // Create a vector of child possibilities for each child index. - // If a child with the same index exists in `hash_join_onwards.children`, - // use its plans. Otherwise, use the child from the first plan. - let get_children: Vec<_> = first_plan - .children() - .into_iter() - .enumerate() - .map(|(child_index, child)| { - if let Some(item) = hash_join_onwards - .children - .iter() - .find(|item| item.idx == child_index) - { - item.plans.clone() - } else { - vec![child] - } - }) - .collect(); +) -> Result> { + let PlanState { plans, .. } = requirements; - // Get the cartesian product of all possible children: - let possible_children = get_children - .into_iter() - .map(|v| v.into_iter()) - .multi_cartesian_product(); - - // Compute new plans by combining each plan with each possible set of children. - // Replace the original plans with these new ones: - hash_join_onwards.plans = iproduct!(plans.into_iter(), possible_children) - .map(|(plan, children)| { - let plan = with_new_children_if_necessary(plan, children)?.into(); - Ok(plan) - }) - .collect::>>()?; - } - - // For each plan, check if it is a `HashJoinExec`. If so, check if we can convert - // it. If not, use the original plan. - let plans: Vec<_> = hash_join_onwards - .plans + let new_plans: Result> = plans .iter() - .flat_map(|plan| { - if let Some(hash_join) = plan.as_any().downcast_ref::() { - match check_hash_join_convertable(hash_join, config_options) { - Ok(Some(result)) => result, - _ => vec![plan.clone()], + .map(|plan| { + // Pattern match against the specific types of plans we want to optimize + match plan { + // Handle hash join optimizations + _ if is_hash_join(plan) => handle_hash_join_exec(plan, config_options), + // Handle cross joins + _ if is_cross_join(plan) => handle_cross_join(plan), + // Handle nested loop joins + _ if is_nested_loop_join(plan) => { + handle_nested_loop_join_exec(plan, config_options) } - } else if let Some(nested_loop_join) = - plan.as_any().downcast_ref::() - { - match check_nested_loop_join_convertable(nested_loop_join, config_options) + // If the plan requires a specific input order (like aggregates or sorts), handle them. + // Also handle any plan that inherently has a required input order + _ if is_aggregate(plan) + | is_sort(plan) + | plan.required_input_ordering().iter().any(|e| e.is_some()) => { - Ok(Some(result)) => result, - _ => vec![plan.clone()], - } - } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let sort_child = sort_exec.input().clone(); - let child_requirement = sort_exec - .expr() - .iter() - .cloned() - .map(PhysicalSortRequirement::from) - .collect::>(); - if sort_child - .output_ordering() - .map_or(false, |provided_ordering| { - ordering_satisfy_requirement_concrete( - provided_ordering, - &child_requirement, - || sort_child.equivalence_properties(), - || sort_child.ordering_equivalence_properties(), - ) + let handled_plan = match plan { + _ if is_aggregate(plan) => { + handle_aggregate_exec(plan, config_options) + } + _ if is_sort(plan) => handle_sort_exec(plan), + _ if is_window(plan) => handle_window_execs(plan), + _ => handle_sort_requirements(plan), + }; + + // Filter out plans that have errors in execution. For bounded and unbounded cases, + // we shouldn't propagate invalid plans to the further analysis. + handled_plan.map(|plans| { + plans + .into_iter() + .filter(|p| is_plan_streaming(p).is_ok()) + .collect() }) - { - vec![sort_child.clone()] - } else { - vec![] } - } else { - vec![plan.clone()] + // If none of the above conditions match, simply clone the plan + _ => Ok(vec![plan.clone()]), } }) + .flatten_ok() // Flatten the results to remove nesting .collect(); - Ok(plans) + Ok(Transformed::Yes(PlanState { plans: new_plans? })) } -/// Checks if a given `HashJoinExec` can be converted into another form of execution plans, -/// primarily into `SortMergeJoinExec` while still preserving the required conditions. -/// -/// This function first extracts key properties from the `HashJoinExec` and then determines -/// the possibility of conversion based on these properties. The conversion mainly involves -/// replacing `HashJoinExec` with `SortMergeJoinExec` for cases where both left and right children -/// are unbounded, and the output ordering can be satisfied. +/// Ensures that the given `plan` satisfies its required input ordering. /// -/// If conversion is feasible, this function creates the corresponding `SortMergeJoinExec` instances, -/// handles the possibility of swapping join types, and returns these new plans. +/// The function checks if the child plans can satisfy the parent's input +/// ordering requirements. If they can, the plan is deemed valid; otherwise, +/// an empty vector is returned. /// /// # Arguments /// -/// * `hash_join`: A reference to a `HashJoinExec` object, which is the hash join execution plan -/// that we are checking for convertibility. -/// -/// * `_config_options`: A reference to a `ConfigOptions` object, which provides the configuration settings. -/// However, these options are not used in the current function. +/// * `plan` - The execution plan to evaluate. /// /// # Returns /// -/// This function returns a `Result` that contains an `Option`. If the `HashJoinExec` can be converted, -/// the `Option` will contain a `Vec` of `Arc` objects, each representing a new execution plan. -/// If conversion is not possible, the function returns `None`. Any errors that occur during the conversion process -/// will result in an `Err` being returned. -fn check_hash_join_convertable( - hash_join: &HashJoinExec, - config_options: &ConfigOptions, -) -> Result>>> { - // To perform the prunability analysis correctly, the columns from the left table - // and the columns from the right table must be on the different sides of the join. - let filter = hash_join - .filter() - .map(|filter| separate_columns_of_filter_expression(filter.clone())); - let (on_left, on_right): (Vec<_>, Vec<_>) = hash_join.on.iter().cloned().unzip(); - let left_order = hash_join.left().output_ordering(); - let right_order = hash_join.right().output_ordering(); - let is_left_streaming = is_plan_streaming(hash_join.left()); - let is_right_streaming = is_plan_streaming(hash_join.right()); - match ( - is_left_streaming, - is_right_streaming, - filter, - left_order, - right_order, - ) { - (true, true, Some(filter), Some(left_order), Some(right_order)) => { - let (left_prunable, right_prunable) = is_filter_expr_prunable( - &filter, - Some(left_order[0].clone()), - Some(right_order[0].clone()), - || hash_join.left().equivalence_properties(), - || hash_join.left().ordering_equivalence_properties(), - || hash_join.right().equivalence_properties(), - || hash_join.right().ordering_equivalence_properties(), - )?; +/// Returns a vector containing the original plan if it is valid, or an empty vector otherwise. +fn handle_sort_requirements( + plan: &Arc, +) -> Result>> { + // Extract children and requirements from the plan for clarity + let children = plan.children(); + let requirements = plan.required_input_ordering(); - if left_prunable && right_prunable { - let mode = if config_options.optimizer.repartition_joins { - StreamJoinPartitionMode::Partitioned + // Check the streaming status and validity of the plan with its sources + let maybe_streaming = is_plan_streaming(plan); + + let plan_is_valid = match maybe_streaming { + // If plan is streaming, check if all children satisfy their requirements + Ok(true) => izip!(children, requirements).all(|(child, maybe_requirement)| { + // If there's no requirement, it's automatically satisfied + if let Some(requirement) = maybe_requirement { + if let Some(child_output_ordering) = child.output_ordering() { + ordering_satisfy_requirement_concrete( + child_output_ordering, + &requirement, + || child.equivalence_properties(), + || child.ordering_equivalence_properties(), + ) } else { - StreamJoinPartitionMode::SinglePartition - }; - let sliding_hash_join = Arc::new(SlidingHashJoinExec::try_new( - hash_join.left.clone(), - hash_join.right.clone(), - hash_join.on.clone(), - filter.clone(), - &hash_join.join_type, - hash_join.null_equals_null, - left_order.to_vec(), - right_order.to_vec(), - mode, - )?); - let reversed_sliding_hash_join = - swap_sliding_hash_join(&sliding_hash_join)?; - Ok(Some(vec![sliding_hash_join, reversed_sliding_hash_join])) - } else { - Ok(None) - } - } - (true, true, None, Some(left_order), Some(right_order)) => { - // Get left key(s)' sort options: - let left_satisfied = get_indices_of_matching_sort_exprs_with_order_eq( - left_order, - &on_left, - &hash_join.left().equivalence_properties(), - &hash_join.left().ordering_equivalence_properties(), - ); - // Get right key(s)' sort options: - let right_satisfied = get_indices_of_matching_sort_exprs_with_order_eq( - right_order, - &on_right, - &hash_join.right().equivalence_properties(), - &hash_join.right().ordering_equivalence_properties(), - ); - - if let ( - Some((left_satisfied, left_indices)), - Some((right_satisfied, right_indices)), - ) = (left_satisfied, right_satisfied) - { - // Check if the indices are equal and the sort options are aligned: - if left_indices == right_indices - && left_satisfied - .iter() - .zip(right_satisfied.iter()) - .all(|(l, r)| l == r) - { - let adjusted_keys = left_indices - .iter() - .map(|index| hash_join.on[*index].clone()) - .collect::>(); - let mut plans: Vec> = vec![]; - // SortMergeJoin does not support RightSemi - if !matches!(hash_join.join_type, JoinType::RightSemi) { - plans.push(Arc::new(SortMergeJoinExec::try_new( - hash_join.left.clone(), - hash_join.right.clone(), - adjusted_keys.clone(), - hash_join.join_type, - left_satisfied, - hash_join.null_equals_null, - )?)) - } - if !matches!(swap_join_type(hash_join.join_type), JoinType::RightSemi) - { - plans.push(swap_sort_merge_join( - hash_join, - adjusted_keys, - right_satisfied, - )?); - } - return Ok(Some(plans)); + // If child doesn't provide output ordering, requirement isn't satisfied + false } - } - Ok(None) - } - _ => Ok(None), - } -} - -fn check_nested_loop_join_convertable( - nested_loop_join: &NestedLoopJoinExec, - _config_options: &ConfigOptions, -) -> Result>>> { - // To perform the prunability analysis correctly, the columns from the left table - // and the columns from the right table must be on the different sides of the join. - let filter = nested_loop_join - .filter() - .map(|filter| separate_columns_of_filter_expression(filter.clone())); - let left_order = nested_loop_join.left().output_ordering(); - let right_order = nested_loop_join.right().output_ordering(); - let is_left_streaming = is_plan_streaming(nested_loop_join.left()); - let is_right_streaming = is_plan_streaming(nested_loop_join.right()); - match ( - is_left_streaming, - is_right_streaming, - filter, - left_order, - right_order, - ) { - (true, true, Some(filter), Some(left_order), Some(right_order)) => { - let (left_prunable, right_prunable) = is_filter_expr_prunable( - &filter, - Some(left_order[0].clone()), - Some(right_order[0].clone()), - || nested_loop_join.left().equivalence_properties(), - || nested_loop_join.left().ordering_equivalence_properties(), - || nested_loop_join.right().equivalence_properties(), - || nested_loop_join.right().ordering_equivalence_properties(), - )?; - if left_prunable && right_prunable { - let sliding_nested_loop_join = - Arc::new(SlidingNestedLoopJoinExec::try_new( - nested_loop_join.left().clone(), - nested_loop_join.right().clone(), - filter.clone(), - nested_loop_join.join_type(), - left_order.to_vec(), - right_order.to_vec(), - )?); - let reversed_sliding_nested_loop_join = - swap_sliding_nested_loop_join(&sliding_nested_loop_join)?; - Ok(Some(vec![ - sliding_nested_loop_join, - reversed_sliding_nested_loop_join, - ])) } else { - Ok(None) + true } - } - _ => Ok(None), - } + }), + // If plan is not streaming but valid with its sources, it's considered valid + Ok(false) => true, + // If there was an error determining the streaming status, plan is invalid + Err(_) => false, + }; + + // Return the plan if it's valid, or an empty vector otherwise + Ok(if plan_is_valid { + vec![plan.clone()] + } else { + vec![] + }) } -/// Generates and filters a set of execution plans based on a specified ordering requirement. +/// Handles a potential hash join conversion based on the given execution plan and configuration. /// -/// This function leverages the `create_join_cases_local_preserve_order` function to generate -/// a variety of execution plans from a given `MultipleExecTree`. It then filters these plans -/// using the `find_suitable_plans` function to only keep those that meet the required ordering. +/// This function checks if the provided execution `plan` can be converted into a hash join +/// based on certain criteria derived from the `config_options`. If it's convertible, +/// the converted plans are returned. If not, the original plan is returned. /// /// # Arguments /// -/// * `hash_join_onward`: A mutable reference to a `MultipleExecTree` object, which contains -/// the current subtree of execution plans to be processed. -/// -/// * `required_ordering`: A reference to a slice of `PhysicalSortRequirement` objects, -/// which define the desired ordering for the resulting execution plans. -/// -/// * `config_options`: A reference to a `ConfigOptions` object, which provides the configuration -/// settings that guide the creation of new execution plans. +/// * `plan` - The execution plan to evaluate for potential hash join conversion. +/// * `config_options` - The configuration options that influence the decision for conversion. /// /// # Returns /// -/// This function returns a `Result` that contains a `Vec` of `Arc` objects -/// if the execution plans meeting the ordering requirements are successfully created. -/// If any error occurs during this process, the function returns an `Err`. - -fn get_meeting_the_plan_with_required_order( - hash_join_onward: &mut MultipleExecTree, - required_ordering: &[PhysicalSortRequirement], +/// Returns a vector containing the converted plan(s) if the given `plan` can be converted +/// into a hash join. If conversion isn't feasible, the original plan is returned in the vector. +fn handle_hash_join_exec( + plan: &Arc, config_options: &ConfigOptions, ) -> Result>> { - let possible_plans = - create_join_cases_local_preserve_order(hash_join_onward, config_options)?; - Ok(find_suitable_plans(possible_plans, required_ordering)) + if let Some(result) = check_hash_join_convertable(plan, config_options)? { + return Ok(result); + } + Ok(vec![plan.clone()]) } -/// Filters a list of execution plans to keep only those that meet a specified ordering requirement. +/// Handles a potential cross join conversion based on the given execution plan. /// -/// This function iterates through the provided list of `ExecutionPlan` objects and filters out those -/// that do not satisfy the required ordering. The function leverages the `ordering_satisfy_requirement_concrete` -/// function to check if the provided ordering of an execution plan meets the required ordering. +/// This function checks if the provided execution `plan` can be converted from a cross join +/// based on certain criteria. If it's convertible, the converted plans are returned. +/// If not, the original plan is returned. /// /// # Arguments /// -/// * `plans`: A vector of `Arc` objects. These are the initial execution plans that -/// need to be filtered. +/// * `plan` - The execution plan to evaluate for potential cross join conversion. +/// +/// # Returns +/// +/// Returns a vector containing the converted plan(s) if the given `plan` can be converted +/// from a cross join. If conversion isn't feasible, the original plan is returned in the vector. +fn handle_cross_join( + plan: &Arc, +) -> Result>> { + if let Some(result) = check_cross_join_convertable(plan)? { + return Ok(result); + } + Ok(vec![plan.clone()]) +} + +// TODO: This will be improved with new mechanisms. +pub fn cost_of_the_plan(plan: &Arc) -> usize { + let children_cost: usize = plan.children().iter().map(cost_of_the_plan).sum(); + let plan_cost = if plan.as_any().is::() { + 1 + } else { + 0 + }; + plan_cost + children_cost +} + +/// Handles potential modifications to an aggregate execution plan. +/// +/// This function evaluates the provided execution `plan` for potential modifications +/// related to the aggregate execution. If modifications can be applied, the modified plans +/// are returned; otherwise, the original plan is returned. +/// +/// The purpose of this function is to optimize and adapt the aggregate execution plan +/// based on the given configuration options and inherent properties of the plan itself. +/// +/// # Arguments /// -/// * `required_ordering`: A reference to a slice of `PhysicalSortRequirement` objects, -/// which define the desired ordering that the execution plans should satisfy. +/// * `plan` - The execution plan to evaluate for potential aggregate modifications. +/// * `config_options` - The configuration options that may affect the plan modifications. /// /// # Returns /// -/// The function returns a vector of `Arc` objects. Each `ExecutionPlan` in the vector -/// is a plan from the input vector that satisfies the required ordering. -fn find_suitable_plans( - plans: Vec>, - required_ordering: &[PhysicalSortRequirement], -) -> Vec> { - plans - .into_iter() - .filter(|plan| { - plan.output_ordering().map_or(false, |provided_ordering| { - ordering_satisfy_requirement_concrete( - provided_ordering, - required_ordering, - || plan.equivalence_properties(), - || plan.ordering_equivalence_properties(), - ) - }) - }) - .collect() +/// Returns a vector containing the modified plan(s) if the given `plan` can be optimized +/// as an aggregate execution. If no modifications are feasible, the original plan is returned in the vector. +fn handle_aggregate_exec( + plan: &Arc, + config_options: &ConfigOptions, +) -> Result>> { + // Attempt to downcast the execution plan to an AggregateExec + if let Some(aggregation_exec) = plan.as_any().downcast_ref::() { + // Extract the input plan for the aggregation + let input_plan = aggregation_exec.input(); + + // Find the closest join that can be changed based on the aggregate and input plans + let possible_children_plans = + find_closest_join_and_change(aggregation_exec, input_plan, config_options)?; + + // Extract the group by and aggregate expressions from the aggregate execution + let group_by = aggregation_exec.group_by(); + let aggr_expr = aggregation_exec.aggr_expr(); + + // Select the best aggregate streaming plan based on possible children plans and expressions + let children_plans_plans = select_best_aggregate_streaming_plan( + possible_children_plans, + group_by, + aggr_expr, + )?; + + // If there are no optimized children plans, return the original plan + // Otherwise, modify the plan with the new children plans and return + return if children_plans_plans.is_empty() { + Ok(vec![plan.clone()]) + } else { + children_plans_plans + .into_iter() + .map(|child| plan.clone().with_new_children(vec![child])) + .collect() + }; + } + + // If the provided execution plan isn't an AggregateExec, return the original plan + Ok(vec![plan.clone()]) } -/// This function chooses the "best" (i.e. cost optimal) plan among all the -/// feasible plans given in `possible_plans`. +/// Selects the best aggregate streaming plan from a list of possible plans. /// -/// TODO: Until the Datafusion can identify execution costs of the plans, we -/// are selecting the first feasible plan. -fn select_best_streaming_plan( +/// Evaluates a list of potential execution plans to determine which ones +/// are suitable for streamable aggregation based on several criteria: +/// - Whether a plan has seen a `PartitionedHashJoinExec` +/// - Whether the aggregation result of the plan is valid +/// - The validity of all group by and aggregate expressions +/// - Whether the plan supports streamable aggregates +/// - Whether the child plan can be executed without errors +/// +/// # Parameters +/// - `possible_plans`: A list of potential execution plans to evaluate. +/// - `group_by`: The physical group-by expression to consider. +/// - `aggr_expr`: A list of aggregate expressions to evaluate. +/// +/// # Returns +/// Returns a `Result` containing a list of suitable execution plans, +/// or an error if a plan fails the validation. +fn select_best_aggregate_streaming_plan( possible_plans: Vec>, -) -> Option> { + group_by: &PhysicalGroupBy, + aggr_expr: &[Arc], +) -> Result>> { possible_plans .into_iter() - .find(|plan| is_plan_streaming(plan)) + .filter_map(|plan| { + // Flag to track if PartitionedHashJoinExec is encountered + let mut has_seen_phj = false; + + // Ensure the aggregation result of the plan is valid + let phj_is_valid = check_the_aggregation_result_is_valid(&plan, &mut has_seen_phj); + if !phj_is_valid { + return Some(internal_err!("PartitionHashJoin cannot be interrupt by another unallowed executor.")); + } + + // If PartitionedHashJoinExec is encountered, check expressions validity + let agg_valid_results = if has_seen_phj { + let schema = plan.schema(); + + // All group by expressions should probe the right side of a partitioned hash join + let group_valid = group_by.expr().iter().all(|(expr, _)| { + collect_columns(expr).iter().all(|col| { + schema.field(col.index()) + .metadata() + .get("PartitionedHashJoinExec") + .map_or(false, |v| v.eq("JoinSide::Right")) + }) + }); + + // All aggregate expressions should belong to the left side of a partitioned hash join + let aggr_valid = aggr_expr.iter().all(|expr| { + expr.expressions().iter().all(|expr| { + collect_columns(expr).iter().all(|col| { + schema.field(col.index()) + .metadata() + .get("PartitionedHashJoinExec") + .map_or(false, |v| v.eq("JoinSide::Left")) + }) + }) + }); + + group_valid && aggr_valid + } else { + true + }; + + // Plan should support streamable aggregates and be streamable itself + if get_working_mode(&plan, group_by).is_some() + && is_plan_streaming(&plan).is_ok() + && agg_valid_results { + Some(Ok(plan)) + } else { + None + } + }) + .collect() } -// Check if the given `plan` processes infinite data in a streaming fashion. -fn is_plan_streaming(plan: &Arc) -> bool { - let children_unbounded_output = plan +/// Attempts to locate the nearest `HashJoinExec` within the provided execution `plan` +/// and then modifies it according to the given aggregate plan (`agg_plan`). +/// +/// # Arguments +/// +/// * `agg_plan`: Reference to an `AggregateExec` plan, which provides context for the optimization. +/// * `plan`: Reference to the current node in the execution plan tree being analyzed. +/// * `config_options`: Configuration options that may influence the optimization decisions. +/// +/// # Returns +/// +/// Returns a `Result` containing a vector of possible modified execution plans (`Arc`). +/// The function may generate multiple alternative plans based on the possible children configurations. +/// +/// The function will return the original plan encapsulated in a vector if: +/// 1. The plan node isn't a `HashJoinExec`. +/// 2. Modifying the `HashJoinExec` according to the `agg_plan` isn't feasible or allowed. +/// +/// In case of any error during the optimization process, an error variant of `Result` will be returned. +/// +/// # Example +/// +/// If the `plan` contains a `HashJoinExec` that can be influenced by `agg_plan`, +/// this function will generate optimized versions of the plan and return them. +/// Otherwise, it will go deeper into the tree, recursively trying to find a suitable `HashJoinExec` +/// and do the necessary modifications. +fn find_closest_join_and_change( + agg_plan: &AggregateExec, + plan: &Arc, + config_options: &ConfigOptions, +) -> Result>> { + // Attempt to cast the current node to a HashJoinExec + if let Some(hash_join) = plan.as_any().downcast_ref::() { + // If the current node is a HashJoinExec, then try to modify it based on the agg_plan. + return check_hash_join_aggregate_partial_hash_join( + agg_plan, + hash_join, + config_options, + ) + .and_then(|opt| opt.map_or_else(|| Ok(vec![plan.clone()]), Ok)); + } + + // Attempt to cast the current node to a SymmetricHashJoin + if let Some(sym_hash_join) = plan.as_any().downcast_ref::() { + // If the current node is a HashJoinExec, then try to modify it based on the agg_plan. + return check_symmetric_hash_join_aggregate_partial_hash_join( + agg_plan, + sym_hash_join, + config_options, + ) + .and_then(|opt| opt.map_or_else(|| Ok(vec![plan.clone()]), Ok)); + } + + // If the current plan node is not allowed for modification, return the original plan. + if !is_allowed(plan) { + return Ok(vec![plan.clone()]); + } + + // For each child of the current plan, recursively attempt to locate and modify a HashJoinExec. + let calculated_children_possibilities = plan .children() .iter() - .map(is_plan_streaming) - .collect::>(); - plan.unbounded_output(&children_unbounded_output) - .unwrap_or(false) + .map(|child| find_closest_join_and_change(agg_plan, child, config_options)) + .collect::>>>()?; + + // Generate all possible combinations of children based on the modified child plans. + // This is used to create alternative versions of the current plan with different child configurations. + calculated_children_possibilities + .into_iter() + .map(|v| v.into_iter()) + .multi_cartesian_product() + .map(|children| { + // For each combination of children, try to create a new version of the current plan. + with_new_children_if_necessary(plan.clone(), children).map(|t| t.into()) + }) + .collect() } -/// This subrule tries to modify joins in order to preserve output ordering(s). -/// This will enable downstream rules, such as `EnforceSorting`, to optimize -/// away costly pipeline-breaking sort operations. -pub fn select_joins_to_preserve_order_subrule( - requirements: PlanWithCorrespondingHashJoin, +/// Handles the optimization for the nested loop join execution plan. +/// +/// This function checks if the given plan can be converted as a nested loop join based on the +/// provided configuration options. +/// +/// # Arguments +/// * `plan` - The execution plan to be checked and possibly converted. +/// * `config_options` - The configuration options that may affect the nested loop join conversion. +/// +/// # Returns +/// * A `Result` containing a vector of execution plans after processing. +fn handle_nested_loop_join_exec( + plan: &Arc, config_options: &ConfigOptions, -) -> Result> { - // If there are no child nodes, return as is: - if requirements.plan.children().is_empty() { - return Ok(Transformed::No(requirements)); - } - let PlanWithCorrespondingHashJoin { - plan, - mut hash_join_onwards, - } = requirements; - - // If the plan has a required ordering: - if plan.required_input_ordering().iter().any(|e| e.is_some()) { - let mut children = plan.children(); - let mut is_transformed = false; - - // Jointly iterate over child nodes and required orderings: - for (child, required_ordering, hash_join_onward) in izip!( - children.iter_mut(), - plan.required_input_ordering().iter(), - hash_join_onwards.iter_mut() - ) { - let required_ordering = match required_ordering { - Some(req) => req, - None => continue, - }; - let hash_join_onward = match hash_join_onward { - Some(hj) => hj, - None => continue, - }; - // Get possible plans meeting the ordering requirements: - let possible_plans = get_meeting_the_plan_with_required_order( - hash_join_onward, - required_ordering, - config_options, - )?; - // If there is a plan that is more optimal, choose it: - if let Some(plan) = select_best_streaming_plan(possible_plans) { - *child = plan; - is_transformed = true; - } - } - // If a transformation has occurred, return the new plan: - if is_transformed { - return Ok(Transformed::Yes(PlanWithCorrespondingHashJoin { - plan: plan.with_new_children(children)?, - hash_join_onwards, - })); +) -> Result>> { + if let Some(nested_loop_join) = plan.as_any().downcast_ref::() { + if let Some(result) = + check_nested_loop_join_convertable(nested_loop_join, config_options)? + { + return Ok(result); } - }; - // If no transformation is possible or required, return as is: - Ok(Transformed::No(PlanWithCorrespondingHashJoin { - plan, - hash_join_onwards, - })) + } + Ok(vec![plan.clone()]) +} + +/// Examines the provided execution plan to determine if it is a window aggregation +/// (`BoundedWindowAggExec` or `WindowAggExec`). If it is, the function extracts the window +/// corresponding to the input plan. If not, it returns the original plan. +/// +/// This function facilitates the extraction and processing of window functions in the context +/// of an execution plan. +/// +/// # Arguments +/// +/// * `plan`: The execution plan to inspect for window aggregations. +/// +/// # Returns +/// +/// * A `Result` containing a `Vec` with either: +/// - The extracted window if the provided plan is a window aggregation. +/// - The original execution plan otherwise. +fn handle_window_execs( + plan: &Arc, +) -> Result>> { + let new_window = if let Some(exec) = + plan.as_any().downcast_ref::() + { + get_best_fitting_window(exec.window_expr(), exec.input(), &exec.partition_keys)? + } else if let Some(exec) = plan.as_any().downcast_ref::() { + get_best_fitting_window(exec.window_expr(), exec.input(), &exec.partition_keys)? + } else { + None + }; + if let Some(window) = new_window { + return Ok(vec![window]); + } + Ok(vec![plan.clone()]) +} + +/// Handles the optimization for the sort execution plan. +/// +/// This function checks if the given plan can satisfy the required sorting order. If the sorting +/// requirement is already met, it returns the child of the sort plan, else an empty vector. +/// +/// # Arguments +/// * `plan` - The execution plan to be checked and possibly converted. +/// +/// # Returns +/// * A `Result` containing a vector of execution plans after processing. +fn handle_sort_exec( + plan: &Arc, +) -> Result>> { + if let Some(sort_exec) = plan.as_any().downcast_ref::() { + let sort_child = sort_exec.input().clone(); + let child_requirement = + PhysicalSortRequirement::from_sort_exprs(sort_exec.expr()); + return if sort_exec.fetch().is_none() + && sort_child + .output_ordering() + .map_or(false, |provided_ordering| { + ordering_satisfy_requirement_concrete( + provided_ordering, + &child_requirement, + || sort_child.equivalence_properties(), + || sort_child.ordering_equivalence_properties(), + ) + }) + { + Ok(vec![sort_child.clone()]) + // If the plan is OK with bounded data, we can continue without deleting the possible plan. + } else if let Ok(false) = is_plan_streaming(plan) { + Ok(vec![plan.clone()]) + } else { + Ok(vec![]) + }; + } + Ok(vec![plan.clone()]) +} + +/// Processes an aggregate plan with a partial hash join for optimization. +/// +/// This function examines the aggregate and hash join plans, and applies specific transformations +/// to optimize the execution depending on the configurations and the state of the plans. +/// +/// # Arguments +/// * `parent_plan` - The aggregate execution plan that acts as a parent to the hash join plan. +/// * `hash_join` - The hash join execution plan to be processed. +/// * `config_options` - The configuration options that may affect the optimization process. +/// +/// # Returns +/// * A `Result` containing an `Option` of a vector of execution plans after processing. +/// Returns `None` if no optimization could be applied. +fn check_hash_join_aggregate_partial_hash_join( + parent_plan: &AggregateExec, + hash_join: &HashJoinExec, + _config_options: &ConfigOptions, +) -> Result>>> { + // Based on various properties of the join and the data streams, determine the + // best way to process the hash join. + replace_with_partial_hash_join( + parent_plan, + hash_join.left(), + hash_join.right(), + hash_join.filter(), + hash_join.on.clone(), + hash_join.join_type, + hash_join.null_equals_null, + ) +} + +fn check_symmetric_hash_join_aggregate_partial_hash_join( + parent_plan: &AggregateExec, + sym_hash_join: &SymmetricHashJoinExec, + _config_options: &ConfigOptions, +) -> Result>>> { + // Based on various properties of the join and the data streams, determine the + // best way to process the hash join. + replace_with_partial_hash_join( + parent_plan, + sym_hash_join.left(), + sym_hash_join.right(), + sym_hash_join.filter(), + sym_hash_join.on().to_vec(), + *sym_hash_join.join_type(), + sym_hash_join.null_equals_null(), + ) +} + +/// Attempts to replace the current join execution plan with a more optimized partitioned hash join +/// based on specific conditions. This optimization specifically targets scenarios where both input +/// streams are unbounded and have filters and orders provided. If certain criteria are met, the function +/// returns a `PartitionedHashJoinExec` or a swapped variant with an added projection. +/// +/// # Arguments +/// +/// * `parent_plan`: The aggregate execution plan which is the parent node of the current join. +/// * `left_child`: The left child node of the current join execution plan. +/// * `right_child`: The right child node of the current join execution plan. +/// * `filter`: The join filter applied, if any. +/// * `on`: Specifies which columns the join condition should be based on. +/// * `join_type`: Type of join (e.g., Inner, Left, Right). +/// * `null_equals_null`: A boolean flag indicating if null values should be considered equal during the join. +/// +/// # Returns +/// +/// * A `Result` containing an `Option` with a `Vec` of potential optimized execution plans (`Arc`). +/// - Returns an optimized plan if conditions are met. +/// - Returns `None` if no optimization is applied. +#[allow(clippy::too_many_arguments)] +fn replace_with_partial_hash_join( + parent_plan: &AggregateExec, + left_child: &Arc, + right_child: &Arc, + filter: Option<&JoinFilter>, + on: JoinOn, + join_type: JoinType, + null_equals_null: bool, +) -> Result>>> { + let left_order = left_child.output_ordering(); + let right_order = right_child.output_ordering(); + let is_left_unbounded = is_plan_streaming(left_child).unwrap_or(false); + let is_right_unbounded = is_plan_streaming(right_child).unwrap_or(false); + // To perform the prunability analysis correctly, the columns from the left table + // and the columns from the right table must be on the different sides of the join. + let filter = + filter.map(|filter| separate_columns_of_filter_expression(filter.clone())); + match ( + is_left_unbounded, + is_right_unbounded, + filter.as_ref(), + left_order, + right_order, + ) { + // Both streams are unbounded, and filter with orders are present + (true, true, Some(filter), Some(left_order), Some(right_order)) => { + // Check if filter expressions can be pruned based on the data orders + let (build_prunable, probe_prunable) = is_filter_expr_prunable( + filter, + Some(left_order[0].clone()), + Some(right_order[0].clone()), + || left_child.equivalence_properties(), + || left_child.ordering_equivalence_properties(), + || right_child.equivalence_properties(), + || right_child.ordering_equivalence_properties(), + )?; + + let group_by = parent_plan.group_by(); + let mode = parent_plan.mode(); + let aggr_expr = parent_plan.aggr_expr(); + // TODO: Implement FIRST_VALUE convert into LAST_VALUE. + let fetch_per_key = + if aggr_expr.iter().all(|expr| expr.as_any().is::()) { + 1 + } else { + return Ok(None); + }; + + // If probe side can be pruned, apply specific optimization for that case + if probe_prunable + && matches!(mode, AggregateMode::Partial) + && matches!(join_type, JoinType::Inner | JoinType::Right) + && group_by.null_expr().is_empty() + { + // Create a new partitioned hash join plan + let partitioned_hash_join = Arc::new(PartitionedHashJoinExec::try_new( + left_child.clone(), + right_child.clone(), + on.clone(), + filter.clone(), + &join_type, + null_equals_null, + left_order.to_vec(), + right_order.to_vec(), + fetch_per_key, + )?); + return Ok(Some(vec![partitioned_hash_join])); + // If build side can be pruned, apply specific optimization for that case + } else if build_prunable + && matches!(mode, AggregateMode::Partial) + && matches!(join_type, JoinType::Inner | JoinType::Left) + && group_by.null_expr().is_empty() + { + // Create a new join plan with swapped sides + let new_join = PartitionedHashJoinExec::try_new( + right_child.clone(), + left_child.clone(), + swap_join_on(&on), + swap_filter(filter), + &swap_join_type(join_type), + null_equals_null, + right_order.to_vec(), + left_order.to_vec(), + fetch_per_key, + )?; + // Create a new projection plan + let proj = ProjectionExec::try_new( + swap_reverting_projection( + &left_child.schema(), + &right_child.schema(), + ), + Arc::new(new_join), + )?; + return Ok(Some(vec![Arc::new(proj)])); + } + Ok(None) + } + // In all other cases, no specific optimization is applied + _ => Ok(None), + } +} + +/// Checks if a given `HashJoinExec` can be converted into another form of execution plans, +/// primarily into `SortMergeJoinExec` while still preserving the required conditions. +/// +/// This function first extracts key properties from the `HashJoinExec` and then determines +/// the possibility of conversion based on these properties. The conversion mainly involves +/// replacing `HashJoinExec` with `SortMergeJoinExec` for cases where both left and right children +/// are unbounded, and the output ordering can be satisfied. +/// +/// If conversion is feasible, this function creates the corresponding `SortMergeJoinExec` instances, +/// handles the possibility of swapping join types, and returns these new plans. +/// +/// # Arguments +/// +/// * `hash_join`: A reference to a `HashJoinExec` object, which is the hash join execution plan +/// that we are checking for convertibility. +/// +/// * `_config_options`: A reference to a `ConfigOptions` object, which provides the configuration settings. +/// However, these options are not used in the current function. +/// +/// # Returns +/// +/// This function returns a `Result` that contains an `Option`. If the `HashJoinExec` can be converted, +/// the `Option` will contain a `Vec` of `Arc` objects, each representing a new execution plan. +/// If conversion is not possible, the function returns `None`. Any errors that occur during the conversion process +/// will result in an `Err` being returned. +fn check_hash_join_convertable( + plan: &Arc, + config_options: &ConfigOptions, +) -> Result>>> { + if let Some(hash_join) = plan.as_any().downcast_ref::() { + // To perform the prunability analysis correctly, the columns from the left table + // and the columns from the right table must be on the different sides of the join. + let filter = hash_join + .filter() + .map(|filter| separate_columns_of_filter_expression(filter.clone())); + let (on_left, on_right): (Vec<_>, Vec<_>) = hash_join.on.iter().cloned().unzip(); + let left = hash_join.left(); + let right = hash_join.right(); + let mode = if config_options.optimizer.repartition_joins { + StreamJoinPartitionMode::Partitioned + } else { + StreamJoinPartitionMode::SinglePartition + }; + match ( + is_plan_streaming(left).unwrap_or(false), + is_plan_streaming(right).unwrap_or(false), + filter.as_ref(), + left.output_ordering(), + right.output_ordering(), + ) { + (true, true, Some(filter), Some(left_order), Some(right_order)) => { + handle_sliding_hash_conversion( + hash_join, + filter, + left_order, + right_order, + mode, + ) + } + (true, true, None, Some(left_order), Some(right_order)) => { + handle_sort_merge_join_creation( + hash_join, + mode, + &on_left, + &on_right, + left_order, + right_order, + ) + } + (true, true, maybe_filter, _, _) => { + Ok(Some(vec![create_symmetric_hash_join( + hash_join, + maybe_filter, + mode, + )?])) + } + (true, false, _, _, _) => { + let optimized_hash_join = if matches!( + *hash_join.join_type(), + JoinType::Inner + | JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + ) { + swap_join_according_to_unboundedness(hash_join)? + } else { + plan.clone() + }; + Ok(Some(vec![optimized_hash_join])) + } + (false, false, _, _, _) => { + let optimized_plan = if let Some(opt_plan) = + statistical_join_selection_hash_join( + hash_join, + config_options + .optimizer + .hash_join_single_partition_threshold, + )? { + opt_plan + } else { + plan.clone() + }; + Ok(Some(vec![optimized_plan])) + } + _ => Ok(None), + } + } else { + Ok(None) + } +} +/// Handles the conversion of a `HashJoinExec` into a sliding hash join execution plan or symmetric hash join +/// depending on whether the filter expression is prunable for both left and right orders. +/// +/// # Arguments +/// +/// * `hash_join`: Reference to the `HashJoinExec` being converted. +/// * `filter`: Reference to the join filter applied, after the filter rewrite. +/// * `left_order`: The order for the left side of the join. +/// * `right_order`: The order for the right side of the join. +/// * `mode`: The stream join partition mode. +/// +/// # Returns +/// +/// * A `Result` containing an `Option` with a `Vec` of execution plans (`Arc`). +fn handle_sliding_hash_conversion( + hash_join: &HashJoinExec, + filter: &JoinFilter, + left_order: &[PhysicalSortExpr], + right_order: &[PhysicalSortExpr], + mode: StreamJoinPartitionMode, +) -> Result>>> { + let (left_prunable, right_prunable) = is_filter_expr_prunable( + filter, + Some(left_order[0].clone()), + Some(right_order[0].clone()), + || hash_join.left().equivalence_properties(), + || hash_join.left().ordering_equivalence_properties(), + || hash_join.right().equivalence_properties(), + || hash_join.right().ordering_equivalence_properties(), + )?; + + if left_prunable && right_prunable { + let sliding_hash_join = Arc::new(SlidingHashJoinExec::try_new( + hash_join.left.clone(), + hash_join.right.clone(), + hash_join.on.clone(), + filter.clone(), + &hash_join.join_type, + hash_join.null_equals_null, + left_order.to_vec(), + right_order.to_vec(), + mode, + )?); + let reversed_sliding_hash_join = swap_sliding_hash_join(&sliding_hash_join)?; + Ok(Some(vec![sliding_hash_join, reversed_sliding_hash_join])) + } else { + // There is an configuration for allowing not prunable symmetric hash join. + Ok(Some(vec![create_symmetric_hash_join( + hash_join, + Some(filter), + mode, + )?])) + } +} + +/// Handles the creation of a `SortMergeJoinExec` from a `HashJoinExec`. +/// +/// # Arguments +/// +/// * `hash_join`: Reference to the `HashJoinExec` being converted. +/// * `mode`: The stream join partition mode. +/// * `on_left`: The columns on the left side of the join. +/// * `on_right`: The columns on the right side of the join. +/// * `left_order`: The order for the left side of the join. +/// * `right_order`: The order for the right side of the join. +/// +/// # Returns +/// +/// * A `Result` containing an `Option` with a `Vec` of execution plans (`Arc`). +fn handle_sort_merge_join_creation( + hash_join: &HashJoinExec, + mode: StreamJoinPartitionMode, + on_left: &[Column], + on_right: &[Column], + left_order: &[PhysicalSortExpr], + right_order: &[PhysicalSortExpr], +) -> Result>>> { + // Get left key(s)' sort options: + let left_satisfied = get_indices_of_matching_sort_exprs_with_order_eq( + left_order, + on_left, + &hash_join.left().equivalence_properties(), + &hash_join.left().ordering_equivalence_properties(), + ); + // Get right key(s)' sort options: + let right_satisfied = get_indices_of_matching_sort_exprs_with_order_eq( + right_order, + on_right, + &hash_join.right().equivalence_properties(), + &hash_join.right().ordering_equivalence_properties(), + ); + let mut plans: Vec> = vec![]; + if let ( + Some((left_satisfied, left_indices)), + Some((right_satisfied, right_indices)), + ) = (left_satisfied, right_satisfied) + { + // Check if the indices are equal and the sort options are aligned: + if left_indices == right_indices + && left_satisfied + .iter() + .zip(right_satisfied.iter()) + .all(|(l, r)| l == r) + { + let adjusted_keys = left_indices + .iter() + .map(|index| hash_join.on[*index].clone()) + .collect::>(); + + // SortMergeJoin does not support RightSemi + if !matches!(hash_join.join_type, JoinType::RightSemi) { + plans.push(Arc::new(SortMergeJoinExec::try_new( + hash_join.left.clone(), + hash_join.right.clone(), + adjusted_keys.clone(), + hash_join.join_type, + left_satisfied, + hash_join.null_equals_null, + )?)) + } + if !matches!(swap_join_type(hash_join.join_type), JoinType::RightSemi) { + plans.push(swap_sort_merge_join( + hash_join, + adjusted_keys, + right_satisfied, + )?); + } + } + } + plans.push(create_symmetric_hash_join( + hash_join, + hash_join.filter(), + mode, + )?); + Ok(Some(plans)) +} + +/// Creates a symmetric hash join execution plan from a `HashJoinExec`. +/// +/// # Arguments +/// +/// * `hash_join`: Reference to the `HashJoinExec` being converted. +/// * `mode`: The stream join partition mode. +/// +/// # Returns +/// +/// * A `Result` containing the execution plan (`Arc`). +fn create_symmetric_hash_join( + hash_join: &HashJoinExec, + filter: Option<&JoinFilter>, + mode: StreamJoinPartitionMode, +) -> Result> { + let plan = Arc::new(SymmetricHashJoinExec::try_new( + hash_join.left().clone(), + hash_join.right().clone(), + hash_join.on().to_vec(), + filter.cloned(), + hash_join.join_type(), + hash_join.null_equals_null(), + mode, + )?) as _; + Ok(plan) +} + +/// Checks if a given execution plan is convertible to a cross join. +/// +/// # Arguments +/// +/// * `plan`: Reference to the `ExecutionPlan` being checked. +/// +/// # Returns +/// +/// * A `Result` containing an `Option` with a `Vec` of execution plans (`Arc`). +fn check_cross_join_convertable( + plan: &Arc, +) -> Result>>> { + if let Some(cross_join) = plan.as_any().downcast_ref::() { + let optimized_plan = + if let Some(opt_plan) = statistical_join_selection_cross_join(cross_join)? { + opt_plan + } else { + plan.clone() + }; + Ok(Some(vec![optimized_plan])) + } else { + Ok(None) + } +} + +/// Checks if a nested loop join is convertible, and if so, converts it. +/// +/// # Arguments +/// +/// * `nested_loop_join`: Reference to the `NestedLoopJoinExec` being checked and potentially converted. +/// * `_config_options`: Configuration options. +/// +/// # Returns +/// +/// * A `Result` containing an `Option` with a `Vec` of execution plans (`Arc`). +fn check_nested_loop_join_convertable( + nested_loop_join: &NestedLoopJoinExec, + _config_options: &ConfigOptions, +) -> Result>>> { + // To perform the prunability analysis correctly, the columns from the left table + // and the columns from the right table must be on the different sides of the join. + let filter = nested_loop_join + .filter() + .map(|filter| separate_columns_of_filter_expression(filter.clone())); + let left_order = nested_loop_join.left().output_ordering(); + let right_order = nested_loop_join.right().output_ordering(); + let is_left_streaming = is_plan_streaming(nested_loop_join.left())?; + let is_right_streaming = is_plan_streaming(nested_loop_join.right())?; + match ( + is_left_streaming, + is_right_streaming, + filter, + left_order, + right_order, + ) { + (true, true, Some(filter), Some(left_order), Some(right_order)) => { + let (left_prunable, right_prunable) = is_filter_expr_prunable( + &filter, + Some(left_order[0].clone()), + Some(right_order[0].clone()), + || nested_loop_join.left().equivalence_properties(), + || nested_loop_join.left().ordering_equivalence_properties(), + || nested_loop_join.right().equivalence_properties(), + || nested_loop_join.right().ordering_equivalence_properties(), + )?; + if left_prunable && right_prunable { + let sliding_nested_loop_join = + Arc::new(SlidingNestedLoopJoinExec::try_new( + nested_loop_join.left().clone(), + nested_loop_join.right().clone(), + filter.clone(), + nested_loop_join.join_type(), + left_order.to_vec(), + right_order.to_vec(), + )?); + let reversed_sliding_nested_loop_join = + swap_sliding_nested_loop_join(&sliding_nested_loop_join)?; + Ok(Some(vec![ + sliding_nested_loop_join, + reversed_sliding_nested_loop_join, + ])) + } else { + Ok(None) + } + } + _ => Ok(None), + } +} + +/// Determines if the given execution plan node type is allowed before encountering +/// a `PartitionedHashJoinExec` node in the plan tree. +/// +/// # Parameters +/// * `plan`: The execution plan node to check. +/// +/// # Returns +/// * Returns `true` if the node is one of the following types: +/// * `CoalesceBatchesExec` +/// * `SortExec` +/// * `CoalescePartitionsExec` +/// * `SortPreservingMergeExec` +/// * `ProjectionExec` +/// * `RepartitionExec` +/// * `PartitionedHashJoinExec` +/// * Otherwise, returns `false`. +/// +fn is_allowed(plan: &Arc) -> bool { + let as_any = plan.as_any(); + as_any.is::() + || as_any.is::() + || as_any.is::() + || as_any.is::() + || as_any.is::() + || as_any.is::() + || as_any.is::() +} + +/// Recursively checks the execution plan from the given node downward to determine +/// if any unallowed executors exist before a `PartitionedHashJoinExec` node. +/// +/// The function traverses the tree in a depth-first manner. If it encounters an unallowed +/// executor before it reaches a `PartitionedHashJoinExec`, it immediately stops and returns +/// `false`. If a `PartitionedHashJoinExec` node is found before encountering any unallowed +/// executors, the function returns `true`. If the function completes traversal without +/// finding a `PartitionedHashJoinExec`, it still considers it as a valid path and returns +/// `true`. +/// +/// # Examples +/// +/// Considering the tree structure: +/// +/// ```plaintext +/// [ +/// "ProjectionExec: expr=[c_custkey@0 as c_custkey, n_nationkey@2 as n_nationkey]", +/// ... +/// " PartitionedHashJoinExec: join_type=Inner, on=[(c_nationkey@1, n_regionkey@1)], filter=n_nationkey@1 > c_custkey@0", +/// ... +/// ] +/// ``` +/// This will return `true` since there's a valid path with only allowed executors before +/// `PartitionedHashJoinExec`. +/// +/// For the tree structure: +/// +/// ```plaintext +///[ +/// "ProjectionExec: expr=[c_custkey@0 as c_custkey, n_nationkey@2 as n_nationkey]", +/// " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c_nationkey@1, n_regionkey@1)], filter=n_nationkey@1 > c_custkey@0", +/// " ProjectionExec: expr=[c_custkey@1 as c_custkey, c_nationkey@2 as c_nationkey]", +/// " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(o_orderkey@0, c_custkey@0)]", +/// " ProjectionExec: expr=[o_orderkey@0 as o_orderkey]", +/// " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(o_orderdate@1, l_shipdate@1)], filter=l_orderkey@1 < o_orderkey@0 - 10 AND l_orderkey@1 > o_orderkey@0 + 10", +/// " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/orders.csv]]}, projection=[o_orderkey, o_orderdate], output_ordering=[o_orderkey@0 ASC NULLS LAST], has_header=true", +/// " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_orderkey, l_shipdate], output_ordering=[l_orderkey@0 ASC NULLS LAST], has_header=true", +/// " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/customer.csv]]}, projection=[c_custkey, c_nationkey], output_ordering=[c_custkey@0 ASC NULLS LAST], has_header=true", +/// " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/nation.csv]]}, projection=[n_nationkey, n_regionkey], output_ordering=[n_nationkey@0 ASC NULLS LAST], has_header=true", +///] +/// ``` +/// This will return `true` because no `PartitionedHashJoinExec` is present. +/// +/// For another tree structure: +/// +/// ```plaintext +/// [ +/// "ProjectionExec: expr=[c_custkey@0 as c_custkey, n_nationkey@1 as n_nationkey]", +/// ... +/// " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(o_orderdate@3, l_shipdate@1)], filter=l_orderkey@1 < o_orderkey@0 - 10 AND l_orderkey@1 > o_orderkey@0 + 10", +/// ... +/// " PartitionedHashJoinExec: join_type=Inner, on=[(c_nationkey@1, n_regionkey@1)], filter=n_nationkey@1 > c_custkey@0", +/// ... +/// ] +/// ``` +/// This will return `false` since there's an unallowed executor before +/// `PartitionedHashJoinExec`. +/// +/// +/// ```plaintext +/// [ +/// "PartitionedHashJoinExec: join_type=Inner, on=[(z@2, c@2)], filter=0@0 > 1@1", +/// " StreamingTableExec: partition_sizes=0, projection=[x, y, z], infinite_source=true, output_ordering=[x@0 ASC]", +/// " PartitionedHashJoinExec: join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", +/// " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", +/// " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", +/// ] +/// +/// ``` +/// This will return `false` since there is more than one `PartitionedHashJoinExec`. +/// +/// # Parameters +/// +/// * `plan`: The node to start checking from. +/// +/// # Returns +/// +/// * `true` if the tree has only allowed executors before encountering a +/// `PartitionedHashJoinExec` or if `PartitionedHashJoinExec` is not present. +/// * `false` if there's an unallowed executor before a `PartitionedHashJoinExec`. +/// +/// # Note +/// +/// This function uses the helper function `is_allowed` to check if a node type is allowed or not. +/// +fn check_nodes( + plan: Arc, + path: &mut Vec>, + has_seen_phj: &mut bool, +) -> bool { + path.push(plan.clone()); + // If we've encountered a PartitionedHashJoinExec + if plan.as_any().is::() { + if *has_seen_phj { + // We've already seen a PHJ and encountered another one before an aggregation. + return false; + } + *has_seen_phj = true; + + // Check if all nodes in the path are allowed before a PHJ. + for node in path.iter() { + if !is_allowed(node) { + return false; + } + } + } + + // If we encounter an AggregateExec + if plan.as_any().is::() { + // We can exit early if we see an AggregateExec after a PHJ or even without a prior PHJ. + return true; + } + + // If we have not returned by now, we recursively check children. + let mut is_valid = true; + for child in plan.children() { + is_valid &= check_nodes(child.clone(), path, has_seen_phj); + if !is_valid { + break; + } + } + + path.pop(); + is_valid +} + +/// Validates if the given execution plan results in a valid aggregation by traversing the plan's nodes. +/// +/// The function determines the validity based on certain conditions related to `PartitionedHashJoinExec` +/// and other allowed nodes leading up to it. If an `AggregateExec` is found without a prior `PartitionedHashJoinExec`, +/// the result is considered also valid. +/// +/// # Parameters +/// * `plan`: The root execution plan node to start the validation from. +/// * `has_seen_phj`: A mutable reference to a boolean flag that indicates whether a `PartitionedHashJoinExec` +/// node has been encountered during the traversal. This flag is updated during the process. +/// +/// # Returns +/// * Returns `true` if the aggregation result from the given execution plan is considered valid. +/// * Returns `false` otherwise. +/// +fn check_the_aggregation_result_is_valid( + plan: &Arc, + has_seen_phj: &mut bool, +) -> bool { + let mut path = Vec::new(); + check_nodes(plan.clone(), &mut path, has_seen_phj) } #[cfg(test)] @@ -662,9 +1423,12 @@ mod order_preserving_join_swap_tests { use crate::physical_optimizer::output_requirements::OutputRequirements; use crate::physical_optimizer::test_utils::{ memory_exec_with_sort, nested_loop_join_exec, not_prunable_filter, - sort_expr_options, + partial_prunable_filter, sort_expr_options, }; use crate::physical_optimizer::PhysicalOptimizerRule; + use crate::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; use crate::physical_plan::joins::utils::{ColumnIndex, JoinSide}; use crate::physical_plan::windows::create_window_expr; use crate::physical_plan::{displayable, ExecutionPlan}; @@ -681,8 +1445,10 @@ mod order_preserving_join_swap_tests { use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; use datafusion_common::Result; use datafusion_expr::{BuiltInWindowFunction, JoinType, WindowFrame, WindowFunction}; - use datafusion_physical_expr::expressions::{col, Column, NotExpr}; - use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr::expressions::{ + col, Column, FirstValue, LastValue, NotExpr, + }; + use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; // Util function to get string representation of a physical plan fn get_plan_string(plan: &Arc) -> Vec { @@ -708,6 +1474,14 @@ mod order_preserving_join_swap_tests { Ok(schema) } + fn create_test_schema3() -> Result { + let x = Field::new("x", DataType::Int32, false); + let y = Field::new("y", DataType::Int32, false); + let z = Field::new("z", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![x, y, z])); + Ok(schema) + } + fn col_indices(name: &str, schema: &Schema, side: JoinSide) -> ColumnIndex { ColumnIndex { index: schema.index_of(name).unwrap(), @@ -1484,17 +2258,18 @@ mod order_preserving_join_swap_tests { " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; + + // Since the JoinSelection rule cannot remove the SortExec with any executor, the plan is not executable. If the plan + // is not executable, we are choosing not to change it. let expected_optimized = [ - "SortExec: expr=[d@3 ASC]", - " SymmetricHashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 10", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 10", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; assert_original_plan!(expected_input, physical_plan.clone()); assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); - // TODO: Make the SymmetricHashJoin swap order preserving to use EnforceSorting before the JoinSelection rule - // assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); + // assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); Ok(()) } @@ -1831,7 +2606,174 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_not_add_sort_bounded_window_by_projection() -> Result<()> { + async fn test_not_add_sort_bounded_window_by_projection() -> Result<()> { + let left_schema = create_test_schema()?; + let right_schema = create_test_schema2()?; + let left_input = + streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); + let right_input = streaming_table_exec( + &right_schema, + Some(vec![sort_expr("d", &right_schema)]), + ); + let prunable_filter = prunable_filter( + col_indices("a", &left_schema, JoinSide::Left), + col_indices("d", &right_schema, JoinSide::Right), + ); + let on = vec![( + Column::new_with_schema("c", &left_schema)?, + Column::new_with_schema("c", &right_schema)?, + )]; + let join = hash_join_exec( + left_input, + right_input, + on, + Some(prunable_filter), + &JoinType::Inner, + )?; + let join_schema = join.schema(); + let window_sort_expr = vec![sort_expr("d", &join_schema)]; + let physical_plan = bounded_window_exec("b", window_sort_expr, join); + + let expected_input = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + let expected_optimized = [ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + assert_original_plan!(expected_input, physical_plan.clone()); + assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); + assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort_nested() -> Result<()> { + let left_schema = create_test_schema()?; + let right_schema = create_test_schema2()?; + let left_input = + streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); + let right_input = streaming_table_exec( + &right_schema, + Some(vec![sort_expr("d", &right_schema)]), + ); + let prunable_filter = prunable_filter( + col_indices("a", &left_schema, JoinSide::Left), + col_indices("d", &right_schema, JoinSide::Right), + ); + let join = nested_loop_join_exec( + left_input, + right_input, + Some(prunable_filter), + &JoinType::Inner, + )?; + let physical_plan = sort_exec(vec![sort_expr("d", &join.schema())], join); + + let expected_input = [ + "SortExec: expr=[d@3 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + let expected_optimized = [ + "SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + assert_original_plan!(expected_input, physical_plan.clone()); + assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); + assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_can_not_remove_unnecessary_sort_nested_loop() -> Result<()> { + let left_schema = create_test_schema()?; + let right_schema = create_test_schema2()?; + let left_input = + streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); + let right_input = streaming_table_exec( + &right_schema, + Some(vec![sort_expr("d", &right_schema)]), + ); + let not_prunable_filter = not_prunable_filter( + col_indices("a", &left_schema, JoinSide::Left), + col_indices("d", &right_schema, JoinSide::Right), + ); + let join = nested_loop_join_exec( + left_input, + right_input, + Some(not_prunable_filter), + &JoinType::Inner, + )?; + let physical_plan = sort_exec(vec![sort_expr("d", &join.schema())], join); + + let expected_input = [ + "SortExec: expr=[d@3 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 10", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + let expected_optimized = [ + "SortExec: expr=[d@3 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 10", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + assert_original_plan!(expected_input, physical_plan.clone()); + assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); + assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort_by_projection_nested_loop() -> Result<()> { + let left_schema = create_test_schema()?; + let right_schema = create_test_schema2()?; + let left_input = + streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); + let right_input = streaming_table_exec( + &right_schema, + Some(vec![sort_expr("d", &right_schema)]), + ); + let prunable_filter = prunable_filter( + col_indices("a", &left_schema, JoinSide::Left), + col_indices("d", &right_schema, JoinSide::Right), + ); + let join = nested_loop_join_exec( + left_input, + right_input, + Some(prunable_filter), + &JoinType::Inner, + )?; + let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join); + + let expected_input = [ + "SortExec: expr=[a@0 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + let expected_optimized = [ + "ProjectionExec: expr=[a@3 as a, b@4 as b, c@5 as c, d@0 as d, e@1 as e, c@2 as c]", + " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + ]; + assert_original_plan!(expected_input, physical_plan.clone()); + assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); + assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort_bounded_window_by_projection_nested_loop( + ) -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; let left_input = @@ -1844,30 +2786,27 @@ mod order_preserving_join_swap_tests { col_indices("a", &left_schema, JoinSide::Left), col_indices("d", &right_schema, JoinSide::Right), ); - let on = vec![( - Column::new_with_schema("c", &left_schema)?, - Column::new_with_schema("c", &right_schema)?, - )]; - let join = hash_join_exec( + let join = nested_loop_join_exec( left_input, right_input, - on, Some(prunable_filter), &JoinType::Inner, )?; let join_schema = join.schema(); let window_sort_expr = vec![sort_expr("d", &join_schema)]; - let physical_plan = bounded_window_exec("b", window_sort_expr, join); + let sort = sort_exec(window_sort_expr.clone(), join); + let physical_plan = bounded_window_exec("b", window_sort_expr, sort); let expected_input = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " SortExec: expr=[d@3 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; let expected_optimized = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; @@ -1878,7 +2817,7 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_remove_unnecessary_sort_nested() -> Result<()> { + async fn test_multilayer_joins_nested_loop() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; let left_input = @@ -1887,28 +2826,56 @@ mod order_preserving_join_swap_tests { &right_schema, Some(vec![sort_expr("d", &right_schema)]), ); - let prunable_filter = prunable_filter( + let filter = prunable_filter( col_indices("a", &left_schema, JoinSide::Left), col_indices("d", &right_schema, JoinSide::Right), ); let join = nested_loop_join_exec( left_input, right_input, - Some(prunable_filter), + Some(filter), &JoinType::Inner, )?; - let physical_plan = sort_exec(vec![sort_expr("d", &join.schema())], join); + let join_schema = join.schema(); + let window_sort_expr = vec![sort_expr("d", &join_schema)]; + let sort = sort_exec(window_sort_expr.clone(), join); + // Second layer + let left_input = + streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); + let right_input = bounded_window_exec("b", window_sort_expr, sort); + let right_schema = right_input.schema(); + + let filter = prunable_filter( + col_indices("a", &left_schema, JoinSide::Left), + col_indices("d", &right_schema, JoinSide::Right), + ); + let join = nested_loop_join_exec( + left_input, + right_input, + Some(filter), + &JoinType::Inner, + )?; + let join_schema = join.schema(); + let window_sort_expr = vec![sort_expr("d", &join_schema)]; + let physical_plan = sort_exec(window_sort_expr, join); let expected_input = [ - "SortExec: expr=[d@3 ASC]", + "SortExec: expr=[d@6 ASC]", " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortExec: expr=[d@3 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; let expected_optimized = [ "SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; assert_original_plan!(expected_input, physical_plan.clone()); assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); @@ -1917,7 +2884,7 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_can_not_remove_unnecessary_sort_nested_loop() -> Result<()> { + async fn test_multilayer_joins_with_sort_preserve_nested_loop() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; let left_input = @@ -1926,29 +2893,62 @@ mod order_preserving_join_swap_tests { &right_schema, Some(vec![sort_expr("d", &right_schema)]), ); - let not_prunable_filter = not_prunable_filter( + let filter = prunable_filter( col_indices("a", &left_schema, JoinSide::Left), col_indices("d", &right_schema, JoinSide::Right), ); let join = nested_loop_join_exec( left_input, right_input, - Some(not_prunable_filter), + Some(filter), &JoinType::Inner, )?; - let physical_plan = sort_exec(vec![sort_expr("d", &join.schema())], join); + let join_schema = join.schema(); + let window_sort_expr = vec![sort_expr("d", &join_schema)]; + let sort = sort_exec(window_sort_expr.clone(), join); + // Second layer + let left_input = + streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); + let right_input = bounded_window_exec("b", window_sort_expr, sort); + let right_schema = right_input.schema(); + + let filter = prunable_filter( + col_indices("a", &left_schema, JoinSide::Left), + col_indices("d", &right_schema, JoinSide::Right), + ); + let join = nested_loop_join_exec( + left_input, + right_input, + Some(filter), + &JoinType::Inner, + )?; + let join_schema = join.schema(); + let window_sort_expr = vec![sort_expr("d", &join_schema)]; + let sort = sort_exec(window_sort_expr, join); + let physical_plan = filter_exec( + Arc::new(NotExpr::new(col("d", join_schema.as_ref()).unwrap())), + sort, + ); let expected_input = [ - "SortExec: expr=[d@3 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 10", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + "FilterExec: NOT d@6", + " SortExec: expr=[d@6 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortExec: expr=[d@3 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; let expected_optimized = [ - "SortExec: expr=[d@3 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 10", + "FilterExec: NOT d@6", + " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; assert_original_plan!(expected_input, physical_plan.clone()); assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); @@ -1957,38 +2957,92 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_remove_unnecessary_sort_by_projection_nested_loop() -> Result<()> { + async fn test_multiple_options_for_joins_nested_loop() -> Result<()> { let left_schema = create_test_schema()?; - let right_schema = create_test_schema2()?; + let right_table_schema = create_test_schema2()?; let left_input = streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); let right_input = streaming_table_exec( - &right_schema, - Some(vec![sort_expr("d", &right_schema)]), + &right_table_schema, + Some(vec![sort_expr("d", &right_table_schema)]), ); - let prunable_filter = prunable_filter( + let filter = prunable_filter( + col_indices("a", &left_schema, JoinSide::Left), + col_indices("d", &right_table_schema, JoinSide::Right), + ); + let join = nested_loop_join_exec( + left_input, + right_input, + Some(filter), + &JoinType::Inner, + )?; + let join_schema = join.schema(); + let window_sort_expr = vec![sort_expr("d", &join_schema)]; + let sort = sort_exec(window_sort_expr.clone(), join); + // Second layer + let left_input = + streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); + let right_input = bounded_window_exec("b", window_sort_expr, sort); + let right_schema = right_input.schema(); + let filter = prunable_filter( col_indices("a", &left_schema, JoinSide::Left), col_indices("d", &right_schema, JoinSide::Right), ); let join = nested_loop_join_exec( left_input, right_input, - Some(prunable_filter), + Some(filter), &JoinType::Inner, )?; - let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join); + + // Third layer + let left_input = join.clone(); + let left_schema = join.schema(); + let right_input = streaming_table_exec( + &right_table_schema, + Some(vec![sort_expr("e", &right_table_schema)]), + ); + let filter = prunable_filter( + col_indices("a", &left_schema, JoinSide::Left), + col_indices("e", &right_table_schema, JoinSide::Right), + ); + let join = nested_loop_join_exec( + left_input, + right_input, + Some(filter), + &JoinType::Inner, + )?; + let join_schema = join.schema(); + // Third join + let window_sort_expr = vec![sort_expr("a", &join_schema)]; + let sort = sort_exec(window_sort_expr.clone(), join); + let physical_plan = bounded_window_exec("b", window_sort_expr, sort); let expected_input = [ - "SortExec: expr=[a@0 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortExec: expr=[a@0 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SortExec: expr=[d@3 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[e@1 ASC]", ]; let expected_optimized = [ - "ProjectionExec: expr=[a@3 as a, b@4 as b, c@5 as c, d@0 as d, e@1 as e, c@2 as c]", - " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " ProjectionExec: expr=[a@3 as a, b@4 as b, c@5 as c, a@6 as a, b@7 as b, c@8 as c, d@9 as d, e@10 as e, c@11 as c, count@12 as count, d@0 as d, e@1 as e, c@2 as c]", + " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[e@1 ASC]", + " ProjectionExec: expr=[a@7 as a, b@8 as b, c@9 as c, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e, c@5 as c, count@6 as count]", + " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", + " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", ]; assert_original_plan!(expected_input, physical_plan.clone()); assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); @@ -1997,8 +3051,7 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_remove_unnecessary_sort_bounded_window_by_projection_nested_loop( - ) -> Result<()> { + async fn test_not_add_sort_bounded_window_by_projection_nested_loop() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; let left_input = @@ -2019,15 +3072,13 @@ mod order_preserving_join_swap_tests { )?; let join_schema = join.schema(); let window_sort_expr = vec![sort_expr("d", &join_schema)]; - let sort = sort_exec(window_sort_expr.clone(), join); - let physical_plan = bounded_window_exec("b", window_sort_expr, sort); + let physical_plan = bounded_window_exec("b", window_sort_expr, join); let expected_input = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SortExec: expr=[d@3 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; let expected_optimized = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", @@ -2042,7 +3093,7 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_multilayer_joins_nested_loop() -> Result<()> { + async fn test_multilayer_joins_mixed() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; let left_input = @@ -2055,12 +3106,12 @@ mod order_preserving_join_swap_tests { col_indices("a", &left_schema, JoinSide::Left), col_indices("d", &right_schema, JoinSide::Right), ); - let join = nested_loop_join_exec( - left_input, - right_input, - Some(filter), - &JoinType::Inner, - )?; + let on = vec![( + Column::new_with_schema("c", &left_schema)?, + Column::new_with_schema("c", &right_schema)?, + )]; + let join = + hash_join_exec(left_input, right_input, on, Some(filter), &JoinType::Inner)?; let join_schema = join.schema(); let window_sort_expr = vec![sort_expr("d", &join_schema)]; let sort = sort_exec(window_sort_expr.clone(), join); @@ -2090,7 +3141,7 @@ mod order_preserving_join_swap_tests { " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortExec: expr=[d@3 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; @@ -2098,7 +3149,7 @@ mod order_preserving_join_swap_tests { "SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; @@ -2109,7 +3160,7 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_multilayer_joins_with_sort_preserve_nested_loop() -> Result<()> { + async fn test_multilayer_joins_with_sort_preserve_mixed() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; let left_input = @@ -2122,6 +3173,10 @@ mod order_preserving_join_swap_tests { col_indices("a", &left_schema, JoinSide::Left), col_indices("d", &right_schema, JoinSide::Right), ); + let on = vec![( + Column::new_with_schema("c", &left_schema)?, + Column::new_with_schema("c", &right_schema)?, + )]; let join = nested_loop_join_exec( left_input, right_input, @@ -2141,12 +3196,8 @@ mod order_preserving_join_swap_tests { col_indices("a", &left_schema, JoinSide::Left), col_indices("d", &right_schema, JoinSide::Right), ); - let join = nested_loop_join_exec( - left_input, - right_input, - Some(filter), - &JoinType::Inner, - )?; + let join = + hash_join_exec(left_input, right_input, on, Some(filter), &JoinType::Inner)?; let join_schema = join.schema(); let window_sort_expr = vec![sort_expr("d", &join_schema)]; let sort = sort_exec(window_sort_expr, join); @@ -2158,7 +3209,7 @@ mod order_preserving_join_swap_tests { let expected_input = [ "FilterExec: NOT d@6", " SortExec: expr=[d@6 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortExec: expr=[d@3 ASC]", @@ -2168,7 +3219,7 @@ mod order_preserving_join_swap_tests { ]; let expected_optimized = [ "FilterExec: NOT d@6", - " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", @@ -2182,7 +3233,7 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_multiple_options_for_joins_nested_loop() -> Result<()> { + async fn test_multiple_options_for_joins_mixed() -> Result<()> { let left_schema = create_test_schema()?; let right_table_schema = create_test_schema2()?; let left_input = @@ -2195,9 +3246,14 @@ mod order_preserving_join_swap_tests { col_indices("a", &left_schema, JoinSide::Left), col_indices("d", &right_table_schema, JoinSide::Right), ); - let join = nested_loop_join_exec( + let on = vec![( + Column::new_with_schema("c", &left_schema)?, + Column::new_with_schema("c", &right_table_schema)?, + )]; + let join = hash_join_exec( left_input, right_input, + on.clone(), Some(filter), &JoinType::Inner, )?; @@ -2231,12 +3287,8 @@ mod order_preserving_join_swap_tests { col_indices("a", &left_schema, JoinSide::Left), col_indices("e", &right_table_schema, JoinSide::Right), ); - let join = nested_loop_join_exec( - left_input, - right_input, - Some(filter), - &JoinType::Inner, - )?; + let join = + hash_join_exec(left_input, right_input, on, Some(filter), &JoinType::Inner)?; let join_schema = join.schema(); // Third join let window_sort_expr = vec![sort_expr("a", &join_schema)]; @@ -2246,12 +3298,12 @@ mod order_preserving_join_swap_tests { let expected_input = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortExec: expr=[a@0 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " SortExec: expr=[d@3 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[e@1 ASC]", @@ -2259,12 +3311,12 @@ mod order_preserving_join_swap_tests { let expected_optimized = [ "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", " ProjectionExec: expr=[a@3 as a, b@4 as b, c@5 as c, a@6 as a, b@7 as b, c@8 as c, d@9 as d, e@10 as e, c@11 as c, count@12 as count, d@0 as d, e@1 as e, c@2 as c]", - " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[e@1 ASC]", " ProjectionExec: expr=[a@7 as a, b@8 as b, c@9 as c, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e, c@5 as c, count@6 as count]", " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", @@ -2274,9 +3326,29 @@ mod order_preserving_join_swap_tests { assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); Ok(()) } + // + fn partial_aggregate_exec( + input: Arc, + group_by: PhysicalGroupBy, + aggr_expr: Vec>, + ) -> Arc { + let schema = input.schema(); + Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + group_by, + aggr_expr, + vec![], + vec![], + input, + schema, + ) + .unwrap(), + ) + } #[tokio::test] - async fn test_not_add_sort_bounded_window_by_projection_nested_loop() -> Result<()> { + async fn test_partitioned_hash_join() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; let left_input = @@ -2285,29 +3357,54 @@ mod order_preserving_join_swap_tests { &right_schema, Some(vec![sort_expr("d", &right_schema)]), ); - let prunable_filter = prunable_filter( - col_indices("a", &left_schema, JoinSide::Left), + let on = vec![( + Column::new_with_schema("a", &left_schema)?, + Column::new_with_schema("d", &right_schema)?, + )]; + + // Right side is prunable. + let partial_prunable_filter = partial_prunable_filter( col_indices("d", &right_schema, JoinSide::Right), + col_indices("a", &left_schema, JoinSide::Left), ); - let join = nested_loop_join_exec( + + // Waiting swap on PartitionedHashJoin. + let join = hash_join_exec( left_input, right_input, - Some(prunable_filter), + on, + Some(partial_prunable_filter), &JoinType::Inner, )?; let join_schema = join.schema(); - let window_sort_expr = vec![sort_expr("d", &join_schema)]; - let physical_plan = bounded_window_exec("b", window_sort_expr, join); + // aggregation from build side, not expecting swaping. + let aggr_expr = vec![Arc::new(LastValue::new( + col("b", &join_schema)?, + "LastValue(b)".to_string(), + DataType::Int32, + vec![PhysicalSortExpr { + expr: col("a", &join_schema)?, + options: SortOptions::default(), + }], + vec![DataType::Int32], + )) as _]; + + let groups: Vec<(Arc, String)> = + vec![(col("d", &join_schema)?, "d".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + + let physical_plan = partial_aggregate_exec(join, partial_group_by, aggr_expr); let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + "AggregateExec: mode=Partial, gby=[d@3 as d], aggr=[LastValue(b)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; let expected_optimized = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + "AggregateExec: mode=Partial, gby=[d@3 as d], aggr=[LastValue(b)], ordering_mode=FullyOrdered", + " PartitionedHashJoinExec: join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; @@ -2316,9 +3413,8 @@ mod order_preserving_join_swap_tests { assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); Ok(()) } - #[tokio::test] - async fn test_multilayer_joins_mixed() -> Result<()> { + async fn test_partitioned_hash_join_with_swap() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; let left_input = @@ -2327,56 +3423,191 @@ mod order_preserving_join_swap_tests { &right_schema, Some(vec![sort_expr("d", &right_schema)]), ); - let filter = prunable_filter( + let on = vec![( + Column::new_with_schema("a", &left_schema)?, + Column::new_with_schema("d", &right_schema)?, + )]; + + // Left side is prunable. + let partial_prunable_filter = partial_prunable_filter( col_indices("a", &left_schema, JoinSide::Left), col_indices("d", &right_schema, JoinSide::Right), ); + + // Waiting swap on PartitionedHashJoin. + let join = hash_join_exec( + left_input, + right_input, + on, + Some(partial_prunable_filter), + &JoinType::Inner, + )?; + let join_schema = join.schema(); + // aggregation from build side, not expecting swaping. + let aggr_expr = vec![Arc::new(LastValue::new( + col("e", &join_schema)?, + "LastValue(e)".to_string(), + DataType::Int32, + vec![PhysicalSortExpr { + expr: col("d", &join_schema)?, + options: SortOptions::default(), + }], + vec![DataType::Int32], + )) as _]; + + let groups: Vec<(Arc, String)> = + vec![(col("a", &join_schema)?, "a".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + + let physical_plan = partial_aggregate_exec(join, partial_group_by, aggr_expr); + + let expected_input = [ + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[LastValue(e)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + let expected_optimized = [ + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[LastValue(e)], ordering_mode=FullyOrdered", + " ProjectionExec: expr=[a@3 as a, b@4 as b, c@5 as c, d@0 as d, e@1 as e, c@2 as c]", + " PartitionedHashJoinExec: join_type=Inner, on=[(d@0, a@0)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + ]; + assert_original_plan!(expected_input, physical_plan.clone()); + assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); + assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_partitioned_hash_not_change_due_to_group_by_sides() -> Result<()> { + let left_schema = create_test_schema()?; + let right_schema = create_test_schema2()?; + let left_input = + streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); + let right_input = streaming_table_exec( + &right_schema, + Some(vec![sort_expr("d", &right_schema)]), + ); let on = vec![( - Column::new_with_schema("c", &left_schema)?, - Column::new_with_schema("c", &right_schema)?, + Column::new_with_schema("a", &left_schema)?, + Column::new_with_schema("d", &right_schema)?, )]; - let join = - hash_join_exec(left_input, right_input, on, Some(filter), &JoinType::Inner)?; + + // Right side is prunable. + let partial_prunable_filter = partial_prunable_filter( + col_indices("d", &right_schema, JoinSide::Right), + col_indices("a", &left_schema, JoinSide::Left), + ); + + // Waiting swap on PartitionedHashJoin. + let join = hash_join_exec( + left_input, + right_input, + on, + Some(partial_prunable_filter), + &JoinType::Inner, + )?; let join_schema = join.schema(); - let window_sort_expr = vec![sort_expr("d", &join_schema)]; - let sort = sort_exec(window_sort_expr.clone(), join); - // Second layer + // aggregation from build side, not expecting swaping. + let aggr_expr = vec![Arc::new(LastValue::new( + col("e", &join_schema)?, + "LastValue(e)".to_string(), + DataType::Int32, + vec![PhysicalSortExpr { + expr: col("d", &join_schema)?, + options: SortOptions::default(), + }], + vec![DataType::Int32], + )) as _]; + + let groups: Vec<(Arc, String)> = + vec![(col("a", &join_schema)?, "a".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + + let physical_plan = partial_aggregate_exec(join, partial_group_by, aggr_expr); + + let expected_input = [ + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[LastValue(e)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + let expected_optimized = [ + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[LastValue(e)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + assert_original_plan!(expected_input, physical_plan.clone()); + assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); + assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_partitioned_hash_not_change_due_to_aggr_expr() -> Result<()> { + let left_schema = create_test_schema()?; + let right_schema = create_test_schema2()?; let left_input = streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); - let right_input = bounded_window_exec("b", window_sort_expr, sort); - let right_schema = right_input.schema(); + let right_input = streaming_table_exec( + &right_schema, + Some(vec![sort_expr("d", &right_schema)]), + ); + let on = vec![( + Column::new_with_schema("a", &left_schema)?, + Column::new_with_schema("d", &right_schema)?, + )]; - let filter = prunable_filter( - col_indices("a", &left_schema, JoinSide::Left), + // Right side is prunable. + let partial_prunable_filter = partial_prunable_filter( col_indices("d", &right_schema, JoinSide::Right), + col_indices("a", &left_schema, JoinSide::Left), ); - let join = nested_loop_join_exec( + + // Waiting swap on PartitionedHashJoin. + let join = hash_join_exec( left_input, right_input, - Some(filter), + on, + Some(partial_prunable_filter), &JoinType::Inner, )?; let join_schema = join.schema(); - let window_sort_expr = vec![sort_expr("d", &join_schema)]; - let physical_plan = sort_exec(window_sort_expr, join); + // aggregation from build side, not expecting swaping. + let aggr_expr = vec![Arc::new(FirstValue::new( + col("b", &join_schema)?, + "FirstValue(b)".to_string(), + DataType::Int32, + vec![PhysicalSortExpr { + expr: col("a", &join_schema)?, + options: SortOptions::default(), + }], + vec![DataType::Int32], + )) as _]; + + let groups: Vec<(Arc, String)> = + vec![(col("d", &join_schema)?, "d".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + + let physical_plan = partial_aggregate_exec(join, partial_group_by, aggr_expr); let expected_input = [ - "SortExec: expr=[d@6 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + "AggregateExec: mode=Partial, gby=[d@3 as d], aggr=[FirstValue(b)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SortExec: expr=[d@3 ASC]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; let expected_optimized = [ - "SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + "AggregateExec: mode=Partial, gby=[d@3 as d], aggr=[FirstValue(b)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; assert_original_plan!(expected_input, physical_plan.clone()); assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); @@ -2385,7 +3616,7 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_multilayer_joins_with_sort_preserve_mixed() -> Result<()> { + async fn test_prevent_multiple_partitioned_hash_join_for_single_agg() -> Result<()> { let left_schema = create_test_schema()?; let right_schema = create_test_schema2()?; let left_input = @@ -2394,62 +3625,194 @@ mod order_preserving_join_swap_tests { &right_schema, Some(vec![sort_expr("d", &right_schema)]), ); - let filter = prunable_filter( + let on = vec![( + Column::new_with_schema("a", &left_schema)?, + Column::new_with_schema("d", &right_schema)?, + )]; + + // Right side is prunable. + let partial_prunable_join_filter = partial_prunable_filter( + col_indices("d", &right_schema, JoinSide::Right), col_indices("a", &left_schema, JoinSide::Left), + ); + + // Waiting swap on PartitionedHashJoin. + let join = hash_join_exec( + left_input, + right_input, + on, + Some(partial_prunable_join_filter), + &JoinType::Inner, + )?; + + let first_join_schema = join.schema(); + + // Second Join + let third_table_schema = create_test_schema3()?; + let third_input = streaming_table_exec( + &third_table_schema, + Some(vec![sort_expr("x", &third_table_schema)]), + ); + + // Right side is prunable. + let partial_prunable_filter = partial_prunable_filter( + col_indices("d", &first_join_schema, JoinSide::Right), + col_indices("x", &third_table_schema, JoinSide::Left), + ); + + let on = vec![( + Column::new_with_schema("z", &third_table_schema)?, + Column::new_with_schema("c", &first_join_schema)?, + )]; + + let second_join = hash_join_exec( + third_input, + join, + on, + Some(partial_prunable_filter), + &JoinType::Inner, + )?; + + let second_join_schema = second_join.schema(); + + // aggregation from build side, not expecting swaping. + let aggr_expr = vec![Arc::new(LastValue::new( + col("y", &second_join_schema)?, + "LastValue(y)".to_string(), + DataType::Int32, + vec![PhysicalSortExpr { + expr: col("x", &second_join_schema)?, + options: SortOptions::default(), + }], + vec![DataType::Int32], + )) as _]; + + let groups: Vec<(Arc, String)> = + vec![(col("d", &second_join_schema)?, "d".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + + let physical_plan = + partial_aggregate_exec(second_join, partial_group_by, aggr_expr); + + let expected_input = [ + "AggregateExec: mode=Partial, gby=[d@6 as d], aggr=[LastValue(y)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(z@2, c@2)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[x, y, z], infinite_source=true, output_ordering=[x@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + let expected_optimized = [ + "AggregateExec: mode=Partial, gby=[d@6 as d], aggr=[LastValue(y)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(z@2, c@2)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[x, y, z], infinite_source=true, output_ordering=[x@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + ]; + assert_original_plan!(expected_input, physical_plan.clone()); + assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); + assert_enforce_sorting_join_selection!(expected_optimized, physical_plan); + Ok(()) + } + + #[tokio::test] + async fn test_unified_hash_joins_sliding_hash_join() -> Result<()> { + let left_schema = create_test_schema()?; + let right_schema = create_test_schema2()?; + let left_input = + streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); + let right_input = memory_exec_with_sort( + &right_schema, + Some(vec![sort_expr("d", &right_schema)]), + ); + let on = vec![( + Column::new_with_schema("a", &left_schema)?, + Column::new_with_schema("d", &right_schema)?, + )]; + + // Right side is prunable. + let partial_prunable_join_filter = partial_prunable_filter( col_indices("d", &right_schema, JoinSide::Right), + col_indices("a", &left_schema, JoinSide::Left), + ); + + // Waiting swap on PartitionedHashJoin. + let join = hash_join_exec( + left_input, + right_input, + on, + Some(partial_prunable_join_filter), + &JoinType::Inner, + )?; + + let first_join_schema = join.schema(); + + // Second Join + let third_table_schema = create_test_schema3()?; + let third_input = streaming_table_exec( + &third_table_schema, + Some(vec![sort_expr("x", &third_table_schema)]), ); + + // Right side is prunable. + let prunable_filter = prunable_filter( + col_indices("d", &first_join_schema, JoinSide::Right), + col_indices("x", &third_table_schema, JoinSide::Left), + ); + let on = vec![( - Column::new_with_schema("c", &left_schema)?, - Column::new_with_schema("c", &right_schema)?, + Column::new_with_schema("z", &third_table_schema)?, + Column::new_with_schema("c", &first_join_schema)?, )]; - let join = nested_loop_join_exec( - left_input, - right_input, - Some(filter), + + let second_join = hash_join_exec( + third_input, + join, + on, + Some(prunable_filter), &JoinType::Inner, )?; - let join_schema = join.schema(); - let window_sort_expr = vec![sort_expr("d", &join_schema)]; - let sort = sort_exec(window_sort_expr.clone(), join); - // Second layer - let left_input = - streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); - let right_input = bounded_window_exec("b", window_sort_expr, sort); - let right_schema = right_input.schema(); - let filter = prunable_filter( - col_indices("a", &left_schema, JoinSide::Left), - col_indices("d", &right_schema, JoinSide::Right), - ); - let join = - hash_join_exec(left_input, right_input, on, Some(filter), &JoinType::Inner)?; - let join_schema = join.schema(); - let window_sort_expr = vec![sort_expr("d", &join_schema)]; - let sort = sort_exec(window_sort_expr, join); - let physical_plan = filter_exec( - Arc::new(NotExpr::new(col("d", join_schema.as_ref()).unwrap())), - sort, - ); + let second_join_schema = second_join.schema(); + + // aggregation from build side, not expecting swaping. + let aggr_expr = vec![Arc::new(LastValue::new( + col("b", &second_join_schema)?, + "LastValue(b)".to_string(), + DataType::Int32, + vec![PhysicalSortExpr { + expr: col("a", &second_join_schema)?, + options: SortOptions::default(), + }], + vec![DataType::Int32], + )) as _]; + + let groups: Vec<(Arc, String)> = + vec![(col("d", &second_join_schema)?, "d".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + + let physical_plan = + partial_aggregate_exec(second_join, partial_group_by, aggr_expr); let expected_input = [ - "FilterExec: NOT d@6", - " SortExec: expr=[d@6 ASC]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + "AggregateExec: mode=Partial, gby=[d@6 as d], aggr=[LastValue(b)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(z@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[x, y, z], infinite_source=true, output_ordering=[x@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SortExec: expr=[d@3 ASC]", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[], output_ordering=d@0 ASC", ]; let expected_optimized = [ - "FilterExec: NOT d@6", - " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + "AggregateExec: mode=Partial, gby=[d@6 as d], aggr=[LastValue(b)], ordering_mode=FullyOrdered", + " SlidingHashJoinExec: join_type=Inner, on=[(z@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", + " StreamingTableExec: partition_sizes=0, projection=[x, y, z], infinite_source=true, output_ordering=[x@0 ASC]", + " ProjectionExec: expr=[a@3 as a, b@4 as b, c@5 as c, d@0 as d, e@1 as e, c@2 as c]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d@0, a@0)], filter=0@0 > 1@1", + " MemoryExec: partitions=0, partition_sizes=[], output_ordering=d@0 ASC", " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", ]; assert_original_plan!(expected_input, physical_plan.clone()); assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); @@ -2458,93 +3821,101 @@ mod order_preserving_join_swap_tests { } #[tokio::test] - async fn test_multiple_options_for_joins_mixed() -> Result<()> { + async fn test_unified_hash_joins_partitioned_hash_join() -> Result<()> { let left_schema = create_test_schema()?; - let right_table_schema = create_test_schema2()?; + let right_schema = create_test_schema2()?; let left_input = streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); - let right_input = streaming_table_exec( - &right_table_schema, - Some(vec![sort_expr("d", &right_table_schema)]), - ); - let filter = prunable_filter( - col_indices("a", &left_schema, JoinSide::Left), - col_indices("d", &right_table_schema, JoinSide::Right), + let right_input = memory_exec_with_sort( + &right_schema, + Some(vec![sort_expr("d", &right_schema)]), ); let on = vec![( - Column::new_with_schema("c", &left_schema)?, - Column::new_with_schema("c", &right_table_schema)?, + Column::new_with_schema("a", &left_schema)?, + Column::new_with_schema("d", &right_schema)?, )]; - let join = hash_join_exec( - left_input, - right_input, - on.clone(), - Some(filter), - &JoinType::Inner, - )?; - let join_schema = join.schema(); - let window_sort_expr = vec![sort_expr("d", &join_schema)]; - let sort = sort_exec(window_sort_expr.clone(), join); - // Second layer - let left_input = - streaming_table_exec(&left_schema, Some(vec![sort_expr("a", &left_schema)])); - let right_input = bounded_window_exec("b", window_sort_expr, sort); - let right_schema = right_input.schema(); - let filter = prunable_filter( - col_indices("a", &left_schema, JoinSide::Left), + + // Right side is prunable. + let partial_prunable_join_filter = partial_prunable_filter( col_indices("d", &right_schema, JoinSide::Right), + col_indices("a", &left_schema, JoinSide::Left), ); - let join = nested_loop_join_exec( + + // Waiting swap on PartitionedHashJoin. + let join = hash_join_exec( left_input, right_input, - Some(filter), + on, + Some(partial_prunable_join_filter), &JoinType::Inner, )?; - // Third layer - let left_input = join.clone(); - let left_schema = join.schema(); - let right_input = streaming_table_exec( - &right_table_schema, - Some(vec![sort_expr("e", &right_table_schema)]), + let first_join_schema = join.schema(); + + // Second Join + let third_table_schema = create_test_schema3()?; + let third_input = streaming_table_exec( + &third_table_schema, + Some(vec![sort_expr("x", &third_table_schema)]), ); - let filter = prunable_filter( - col_indices("a", &left_schema, JoinSide::Left), - col_indices("e", &right_table_schema, JoinSide::Right), + + // Right side is prunable. + let partial_prunable_filter = partial_prunable_filter( + col_indices("d", &first_join_schema, JoinSide::Right), + col_indices("x", &third_table_schema, JoinSide::Left), ); - let join = - hash_join_exec(left_input, right_input, on, Some(filter), &JoinType::Inner)?; - let join_schema = join.schema(); - // Third join - let window_sort_expr = vec![sort_expr("a", &join_schema)]; - let sort = sort_exec(window_sort_expr.clone(), join); - let physical_plan = bounded_window_exec("b", window_sort_expr, sort); + + let on = vec![( + Column::new_with_schema("z", &third_table_schema)?, + Column::new_with_schema("c", &first_join_schema)?, + )]; + + let second_join = hash_join_exec( + third_input, + join, + on, + Some(partial_prunable_filter), + &JoinType::Inner, + )?; + + let second_join_schema = second_join.schema(); + + // aggregation from build side, not expecting swaping. + let aggr_expr = vec![Arc::new(LastValue::new( + col("y", &second_join_schema)?, + "LastValue(y)".to_string(), + DataType::Int32, + vec![PhysicalSortExpr { + expr: col("x", &second_join_schema)?, + options: SortOptions::default(), + }], + vec![DataType::Int32], + )) as _]; + + let groups: Vec<(Arc, String)> = + vec![(col("d", &second_join_schema)?, "d".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + + let physical_plan = + partial_aggregate_exec(second_join, partial_group_by, aggr_expr); let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SortExec: expr=[a@0 ASC]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " NestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SortExec: expr=[d@3 ASC]", - " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[e@1 ASC]", + "AggregateExec: mode=Partial, gby=[d@6 as d], aggr=[LastValue(y)], ordering_mode=FullyOrdered", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(z@2, c@2)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[x, y, z], infinite_source=true, output_ordering=[x@0 ASC]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, d@0)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[], output_ordering=d@0 ASC", ]; let expected_optimized = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " ProjectionExec: expr=[a@3 as a, b@4 as b, c@5 as c, a@6 as a, b@7 as b, c@8 as c, d@9 as d, e@10 as e, c@11 as c, count@12 as count, d@0 as d, e@1 as e, c@2 as c]", - " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[e@1 ASC]", - " ProjectionExec: expr=[a@7 as a, b@8 as b, c@9 as c, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e, c@5 as c, count@6 as count]", - " SlidingNestedLoopJoinExec: join_type=Inner, filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }], mode=[Sorted]", - " SlidingHashJoinExec: join_type=Inner, on=[(c@2, c@2)], filter=0@0 + 0 > 1@1 - 3 AND 0@0 + 0 < 1@1 + 3", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[d, e, c], infinite_source=true, output_ordering=[d@0 ASC]", - " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", + "AggregateExec: mode=Partial, gby=[d@6 as d], aggr=[LastValue(y)], ordering_mode=FullyOrdered", + " PartitionedHashJoinExec: join_type=Inner, on=[(z@2, c@2)], filter=0@0 > 1@1", + " StreamingTableExec: partition_sizes=0, projection=[x, y, z], infinite_source=true, output_ordering=[x@0 ASC]", + " ProjectionExec: expr=[a@3 as a, b@4 as b, c@5 as c, d@0 as d, e@1 as e, c@2 as c]", + " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(d@0, a@0)], filter=0@0 > 1@1", + " MemoryExec: partitions=0, partition_sizes=[], output_ordering=d@0 ASC", + " StreamingTableExec: partition_sizes=0, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC]", ]; assert_original_plan!(expected_input, physical_plan.clone()); assert_join_selection_enforce_sorting!(expected_optimized, physical_plan.clone()); @@ -2552,3 +3923,492 @@ mod order_preserving_join_swap_tests { Ok(()) } } + +#[cfg(test)] +mod sql_fuzzy_tests { + use crate::common::Result; + use crate::physical_plan::displayable; + use crate::physical_plan::{collect, ExecutionPlan}; + use crate::prelude::{CsvReadOptions, SessionContext}; + use arrow::util::pretty::pretty_format_batches; + use arrow_array::RecordBatch; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::expr::Sort; + use datafusion_expr::{col, Expr}; + use itertools::izip; + use std::path::PathBuf; + use std::sync::{Arc, OnceLock}; + + pub fn get_tpch_table_schema(table: &str) -> Schema { + match table { + "customer" => Schema::new(vec![ + Field::new("c_custkey", DataType::Int64, false), + Field::new("c_name", DataType::Utf8, false), + Field::new("c_address", DataType::Utf8, false), + Field::new("c_nationkey", DataType::Int64, false), + Field::new("c_phone", DataType::Utf8, false), + Field::new("c_acctbal", DataType::Decimal128(15, 2), false), + Field::new("c_mktsegment", DataType::Utf8, false), + Field::new("c_comment", DataType::Utf8, false), + ]), + + "orders" => Schema::new(vec![ + Field::new("o_orderkey", DataType::Int64, false), + Field::new("o_custkey", DataType::Int64, false), + Field::new("o_orderstatus", DataType::Utf8, false), + Field::new("o_totalprice", DataType::Decimal128(15, 2), false), + Field::new("o_orderdate", DataType::Date32, false), + Field::new("o_orderpriority", DataType::Utf8, false), + Field::new("o_clerk", DataType::Utf8, false), + Field::new("o_shippriority", DataType::Int32, false), + Field::new("o_comment", DataType::Utf8, false), + ]), + + "lineitem" => Schema::new(vec![ + Field::new("l_orderkey", DataType::Int64, false), + Field::new("l_partkey", DataType::Int64, false), + Field::new("l_suppkey", DataType::Int64, false), + Field::new("l_linenumber", DataType::Int32, false), + Field::new("l_quantity", DataType::Decimal128(15, 2), false), + Field::new("l_extendedprice", DataType::Decimal128(15, 2), false), + Field::new("l_discount", DataType::Decimal128(15, 2), false), + Field::new("l_tax", DataType::Decimal128(15, 2), false), + Field::new("l_returnflag", DataType::Utf8, false), + Field::new("l_linestatus", DataType::Utf8, false), + Field::new("l_shipdate", DataType::Date32, false), + Field::new("l_commitdate", DataType::Date32, false), + Field::new("l_receiptdate", DataType::Date32, false), + Field::new("l_shipinstruct", DataType::Utf8, false), + Field::new("l_shipmode", DataType::Utf8, false), + Field::new("l_comment", DataType::Utf8, false), + ]), + + "nation" => Schema::new(vec![ + Field::new("n_nationkey", DataType::Int64, false), + Field::new("n_name", DataType::Utf8, false), + Field::new("n_regionkey", DataType::Int64, false), + Field::new("n_comment", DataType::Utf8, false), + ]), + "region" => Schema::new(vec![ + Field::new("r_regionkey", DataType::Int64, false), + Field::new("r_name", DataType::Utf8, false), + Field::new("r_comment", DataType::Utf8, false), + ]), + _ => unimplemented!(), + } + } + + fn workspace_dir() -> String { + // e.g. /Software/arrow-datafusion/datafusion/core + let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + + // e.g. /Software/arrow-datafusion/datafusion + dir.parent() + .expect("Can not find parent of datafusion/core") + // e.g. /Software/arrow-datafusion + .parent() + .expect("parent of datafusion") + .to_string_lossy() + .to_string() + } + + fn workspace_root() -> &'static object_store::path::Path { + static WORKSPACE_ROOT_LOCK: OnceLock = OnceLock::new(); + WORKSPACE_ROOT_LOCK.get_or_init(|| { + let workspace_root = workspace_dir(); + + let sanitized_workplace_root = if cfg!(windows) { + // Object store paths are delimited with `/`, e.g. `D:/a/arrow-datafusion/arrow-datafusion/testing/data/csv/aggregate_test_100.csv`. + // The default windows delimiter is `\`, so the workplace path is `D:\a\arrow-datafusion\arrow-datafusion`. + workspace_root + .replace(std::path::MAIN_SEPARATOR, object_store::path::DELIMITER) + } else { + workspace_root.to_string() + }; + + object_store::path::Path::parse(sanitized_workplace_root).unwrap() + }) + } + + fn assert_original_plan(plan: Arc, expected_lines: &[&str]) { + let formatted = displayable(plan.as_ref()).indent(true).to_string(); + let mut formated_strings = formatted + .trim() + .lines() + .map(String::from) + .collect::>(); + let workspace_root: &str = workspace_root().as_ref(); + formated_strings.iter_mut().for_each(|s| { + if s.contains(workspace_root) { + *s = s.replace(workspace_root, "WORKSPACE_ROOT"); + } + }); + let expected_plan_lines: Vec = + expected_lines.iter().map(|s| String::from(*s)).collect(); + + assert_eq!( + expected_plan_lines, formated_strings, + "\n**Original Plan Mismatch\n\nexpected:\n\n{expected_plan_lines:#?}\nactual:\n\n{formated_strings:#?}\n\n" + ); + } + + // Define a common utility to set up the session context and tables + async fn setup_context(mark_infinite: bool) -> Result { + let abs_path = workspace_dir() + "/datafusion/core/tests/tpch-csv/"; + + let config = SessionConfig::new() + .with_target_partitions(1) + .with_repartition_joins(false); + let ctx = SessionContext::with_config(config); + let tables = ["orders", "lineitem", "customer", "nation", "region"]; + let can_be_infinite = [true, true, true, true, false]; + let ordered_columns = [ + "o_orderkey", + "l_orderkey", + "c_custkey", + "n_nationkey", + "r_regionkey", + ]; + + for (table, inf, ordered_col) in izip!(tables, can_be_infinite, ordered_columns) { + ctx.register_csv( + table, + &format!("{}/{}.csv", abs_path, table), + CsvReadOptions::new() + .schema(&get_tpch_table_schema(table)) + .mark_infinite(mark_infinite && inf) + .file_sort_order(vec![vec![Expr::Sort(Sort::new( + Box::new(col(ordered_col)), + true, + false, + ))]]), + ) + .await?; + } + Ok(ctx) + } + + async fn unbounded_execution( + expected_input: &[&str], + sql: &str, + ) -> Result> { + let ctx = setup_context(true).await?; + let dataframe = ctx.sql(sql).await?; + let physical_plan = dataframe.create_physical_plan().await?; + assert_original_plan(physical_plan.clone(), expected_input); + let batches = collect(physical_plan, ctx.task_ctx()).await?; + Ok(batches) + } + + async fn bounded_execution(sql: &str) -> Result> { + let ctx = setup_context(false).await?; + let dataframe = ctx.sql(sql).await?; + let physical_plan = dataframe.create_physical_plan().await?; + let batches = collect(physical_plan, ctx.task_ctx()).await?; + Ok(batches) + } + + async fn experiment(expected_unbounded_plan: &[&str], sql: &str) -> Result<()> { + let first_batches = unbounded_execution(expected_unbounded_plan, sql).await?; + let second_batches = bounded_execution(sql).await?; + compare_batches(&first_batches, &second_batches); + Ok(()) + } + + fn compare_batches(collected_1: &[RecordBatch], collected_2: &[RecordBatch]) { + // compare + let first_formatted = pretty_format_batches(collected_1).unwrap().to_string(); + let second_formatted = pretty_format_batches(collected_2).unwrap().to_string(); + + let mut first_formatted_sorted: Vec<&str> = + first_formatted.trim().lines().collect(); + first_formatted_sorted.sort_unstable(); + + let mut second_formatted_sorted: Vec<&str> = + second_formatted.trim().lines().collect(); + second_formatted_sorted.sort_unstable(); + + for (i, (first_line, second_line)) in first_formatted_sorted + .iter() + .zip(&second_formatted_sorted) + .enumerate() + { + if (i, first_line) != (i, second_line) { + assert_eq!((i, first_line), (i, second_line)); + } + } + } + + #[tokio::test] + async fn test_unbounded_hash_selection1() -> Result<()> { + let sql = "SELECT + o_orderkey, LAST_VALUE(l_suppkey ORDER BY l_orderkey) AS amount_usd + FROM + customer, + nation, + orders, + lineitem + WHERE + c_custkey = o_orderkey + AND n_regionkey = c_nationkey + AND n_nationkey > c_custkey + AND n_nationkey < c_custkey + 20 + AND l_orderkey < o_orderkey - 10 + AND o_orderdate = l_shipdate + AND l_returnflag = 'R' + GROUP BY o_orderkey"; + + let expected_plan = [ + "ProjectionExec: expr=[o_orderkey@0 as o_orderkey, LAST_VALUE(lineitem.l_suppkey) ORDER BY [lineitem.l_orderkey ASC NULLS LAST]@1 as amount_usd]", + " AggregateExec: mode=Single, gby=[o_orderkey@0 as o_orderkey], aggr=[LAST_VALUE(lineitem.l_suppkey)], ordering_mode=FullyOrdered", + " ProjectionExec: expr=[o_orderkey@0 as o_orderkey, l_orderkey@2 as l_orderkey, l_suppkey@3 as l_suppkey]", + " ProjectionExec: expr=[o_orderkey@3 as o_orderkey, o_orderdate@4 as o_orderdate, l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey, l_shipdate@2 as l_shipdate]", + " PartitionedHashJoinExec: join_type=Inner, on=[(l_shipdate@2, o_orderdate@1)], filter=l_orderkey@1 < o_orderkey@0 - 10", + " ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey, l_shipdate@3 as l_shipdate]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: l_returnflag@2 = R", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_orderkey, l_suppkey, l_returnflag, l_shipdate], infinite_source=true, output_ordering=[l_orderkey@0 ASC NULLS LAST], has_header=true", + " ProjectionExec: expr=[o_orderkey@1 as o_orderkey, o_orderdate@2 as o_orderdate]", + " SortMergeJoin: join_type=Inner, on=[(c_custkey@0, o_orderkey@0)]", + " ProjectionExec: expr=[c_custkey@0 as c_custkey]", + " ProjectionExec: expr=[c_custkey@2 as c_custkey, c_nationkey@3 as c_nationkey, n_nationkey@0 as n_nationkey, n_regionkey@1 as n_regionkey]", + " SlidingHashJoinExec: join_type=Inner, on=[(n_regionkey@1, c_nationkey@1)], filter=n_nationkey@1 > c_custkey@0 AND n_nationkey@1 < c_custkey@0 + 20", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/nation.csv]]}, projection=[n_nationkey, n_regionkey], infinite_source=true, output_ordering=[n_nationkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/customer.csv]]}, projection=[c_custkey, c_nationkey], infinite_source=true, output_ordering=[c_custkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/orders.csv]]}, projection=[o_orderkey, o_orderdate], infinite_source=true, output_ordering=[o_orderkey@0 ASC NULLS LAST], has_header=true", + ]; + + experiment(&expected_plan, sql).await?; + Ok(()) + } + + #[tokio::test] + async fn test_unbounded_hash_selection3() -> Result<()> { + let sql = "SELECT + n_nationkey, LAST_VALUE(c_custkey ORDER BY c_custkey) AS amount_usd + FROM + orders, + lineitem, + customer, + nation + WHERE + c_custkey = o_orderkey + AND n_regionkey = c_nationkey + AND n_nationkey > c_custkey + AND l_orderkey < o_orderkey - 10 + AND l_orderkey > o_orderkey + 10 + AND o_orderdate = l_shipdate + GROUP BY n_nationkey"; + + let expected_plan = [ + "ProjectionExec: expr=[n_nationkey@0 as n_nationkey, LAST_VALUE(customer.c_custkey) ORDER BY [customer.c_custkey ASC NULLS LAST]@1 as amount_usd]", + " AggregateExec: mode=Single, gby=[n_nationkey@1 as n_nationkey], aggr=[LAST_VALUE(customer.c_custkey)], ordering_mode=FullyOrdered", + " ProjectionExec: expr=[c_custkey@0 as c_custkey, n_nationkey@2 as n_nationkey]", + " PartitionedHashJoinExec: join_type=Inner, on=[(c_nationkey@1, n_regionkey@1)], filter=n_nationkey@1 > c_custkey@0", + " ProjectionExec: expr=[c_custkey@1 as c_custkey, c_nationkey@2 as c_nationkey]", + " SortMergeJoin: join_type=Inner, on=[(o_orderkey@0, c_custkey@0)]", + " ProjectionExec: expr=[o_orderkey@0 as o_orderkey]", + " ProjectionExec: expr=[o_orderkey@2 as o_orderkey, o_orderdate@3 as o_orderdate, l_orderkey@0 as l_orderkey, l_shipdate@1 as l_shipdate]", + " SlidingHashJoinExec: join_type=Inner, on=[(l_shipdate@1, o_orderdate@1)], filter=l_orderkey@1 < o_orderkey@0 - 10 AND l_orderkey@1 > o_orderkey@0 + 10", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_orderkey, l_shipdate], infinite_source=true, output_ordering=[l_orderkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/orders.csv]]}, projection=[o_orderkey, o_orderdate], infinite_source=true, output_ordering=[o_orderkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/customer.csv]]}, projection=[c_custkey, c_nationkey], infinite_source=true, output_ordering=[c_custkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/nation.csv]]}, projection=[n_nationkey, n_regionkey], infinite_source=true, output_ordering=[n_nationkey@0 ASC NULLS LAST], has_header=true", + ]; + + experiment(&expected_plan, sql).await?; + Ok(()) + } + + #[tokio::test] + async fn test_unbounded_hash_selection4() -> Result<()> { + let sql = "SELECT + sub.n_nationkey, + SUM(sub.amount_usd) + FROM + ( + SELECT + n_nationkey, + LAST_VALUE(c_custkey ORDER BY c_custkey) AS amount_usd + FROM + orders, + lineitem, + customer, + nation + WHERE + c_custkey = o_orderkey + AND n_regionkey = c_nationkey + AND n_nationkey > c_custkey + AND l_orderkey < o_orderkey - 10 + AND l_orderkey > o_orderkey + 10 + AND o_orderdate = l_shipdate + GROUP BY n_nationkey + ) AS sub + GROUP BY sub.n_nationkey"; + + let expected_plan = [ + "AggregateExec: mode=Single, gby=[n_nationkey@0 as n_nationkey], aggr=[SUM(sub.amount_usd)], ordering_mode=FullyOrdered", + " ProjectionExec: expr=[n_nationkey@0 as n_nationkey, LAST_VALUE(customer.c_custkey) ORDER BY [customer.c_custkey ASC NULLS LAST]@1 as amount_usd]", + " AggregateExec: mode=Single, gby=[n_nationkey@1 as n_nationkey], aggr=[LAST_VALUE(customer.c_custkey)], ordering_mode=FullyOrdered", + " ProjectionExec: expr=[c_custkey@0 as c_custkey, n_nationkey@2 as n_nationkey]", + " PartitionedHashJoinExec: join_type=Inner, on=[(c_nationkey@1, n_regionkey@1)], filter=n_nationkey@1 > c_custkey@0", + " ProjectionExec: expr=[c_custkey@1 as c_custkey, c_nationkey@2 as c_nationkey]", + " SortMergeJoin: join_type=Inner, on=[(o_orderkey@0, c_custkey@0)]", + " ProjectionExec: expr=[o_orderkey@0 as o_orderkey]", + " ProjectionExec: expr=[o_orderkey@2 as o_orderkey, o_orderdate@3 as o_orderdate, l_orderkey@0 as l_orderkey, l_shipdate@1 as l_shipdate]", + " SlidingHashJoinExec: join_type=Inner, on=[(l_shipdate@1, o_orderdate@1)], filter=l_orderkey@1 < o_orderkey@0 - 10 AND l_orderkey@1 > o_orderkey@0 + 10", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_orderkey, l_shipdate], infinite_source=true, output_ordering=[l_orderkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/orders.csv]]}, projection=[o_orderkey, o_orderdate], infinite_source=true, output_ordering=[o_orderkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/customer.csv]]}, projection=[c_custkey, c_nationkey], infinite_source=true, output_ordering=[c_custkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/nation.csv]]}, projection=[n_nationkey, n_regionkey], infinite_source=true, output_ordering=[n_nationkey@0 ASC NULLS LAST], has_header=true", + ]; + + experiment(&expected_plan, sql).await?; + Ok(()) + } + + #[tokio::test] + async fn test_partitioned_swap() -> Result<()> { + let sql = "SELECT + o_orderkey, LAST_VALUE(l_suppkey ORDER BY l_orderkey) AS amount_usd + FROM + orders, + lineitem + WHERE + o_orderdate = l_shipdate + AND l_orderkey < o_orderkey - 10 + AND l_returnflag = 'R' + GROUP BY o_orderkey"; + + let expected_plan = [ + "ProjectionExec: expr=[o_orderkey@0 as o_orderkey, LAST_VALUE(lineitem.l_suppkey) ORDER BY [lineitem.l_orderkey ASC NULLS LAST]@1 as amount_usd]", + " AggregateExec: mode=Single, gby=[o_orderkey@0 as o_orderkey], aggr=[LAST_VALUE(lineitem.l_suppkey)], ordering_mode=FullyOrdered", + " ProjectionExec: expr=[o_orderkey@0 as o_orderkey, l_orderkey@2 as l_orderkey, l_suppkey@3 as l_suppkey]", + " ProjectionExec: expr=[o_orderkey@3 as o_orderkey, o_orderdate@4 as o_orderdate, l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey, l_shipdate@2 as l_shipdate]", + " PartitionedHashJoinExec: join_type=Inner, on=[(l_shipdate@2, o_orderdate@1)], filter=l_orderkey@1 < o_orderkey@0 - 10", + " ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey, l_shipdate@3 as l_shipdate]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: l_returnflag@2 = R", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_orderkey, l_suppkey, l_returnflag, l_shipdate], infinite_source=true, output_ordering=[l_orderkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/orders.csv]]}, projection=[o_orderkey, o_orderdate], infinite_source=true, output_ordering=[o_orderkey@0 ASC NULLS LAST], has_header=true", + ]; + + experiment(&expected_plan, sql).await?; + Ok(()) + } + + #[tokio::test] + async fn test_usual_swap() -> Result<()> { + let sql = "SELECT + o_orderkey, LAST_VALUE(l_suppkey ORDER BY l_orderkey) AS amount_usd + FROM + orders, + lineitem + WHERE + o_orderdate = l_shipdate + AND l_orderkey < o_orderkey - 10 + AND l_orderkey > o_orderkey + 10 + AND l_returnflag = 'R' + GROUP BY o_orderkey"; + + let expected_plan = [ + "ProjectionExec: expr=[o_orderkey@0 as o_orderkey, LAST_VALUE(lineitem.l_suppkey) ORDER BY [lineitem.l_orderkey ASC NULLS LAST]@1 as amount_usd]", + " AggregateExec: mode=Single, gby=[o_orderkey@0 as o_orderkey], aggr=[LAST_VALUE(lineitem.l_suppkey)], ordering_mode=FullyOrdered", + " ProjectionExec: expr=[o_orderkey@0 as o_orderkey, l_orderkey@2 as l_orderkey, l_suppkey@3 as l_suppkey]", + " ProjectionExec: expr=[o_orderkey@3 as o_orderkey, o_orderdate@4 as o_orderdate, l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey, l_shipdate@2 as l_shipdate]", + " SlidingHashJoinExec: join_type=Inner, on=[(l_shipdate@2, o_orderdate@1)], filter=l_orderkey@1 < o_orderkey@0 - 10 AND l_orderkey@1 > o_orderkey@0 + 10", + " ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey, l_shipdate@3 as l_shipdate]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: l_returnflag@2 = R", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_orderkey, l_suppkey, l_returnflag, l_shipdate], infinite_source=true, output_ordering=[l_orderkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/orders.csv]]}, projection=[o_orderkey, o_orderdate], infinite_source=true, output_ordering=[o_orderkey@0 ASC NULLS LAST], has_header=true", + ]; + + experiment(&expected_plan, sql).await?; + Ok(()) + } + + #[tokio::test] + async fn test_usual_swap_2() -> Result<()> { + let sql = "SELECT + o_orderkey, AVG(l_suppkey) AS amount_usd + FROM orders + LEFT JOIN lineitem + ON + o_orderdate = l_shipdate + AND l_orderkey < o_orderkey - 10 + AND l_orderkey > o_orderkey + 10 + AND l_returnflag = 'R' + GROUP BY o_orderkey + ORDER BY o_orderkey"; + + let expected_plan = [ + "ProjectionExec: expr=[o_orderkey@0 as o_orderkey, AVG(lineitem.l_suppkey)@1 as amount_usd]", + " AggregateExec: mode=Single, gby=[o_orderkey@0 as o_orderkey], aggr=[AVG(lineitem.l_suppkey)], ordering_mode=FullyOrdered", + " ProjectionExec: expr=[o_orderkey@0 as o_orderkey, l_suppkey@3 as l_suppkey]", + " ProjectionExec: expr=[o_orderkey@3 as o_orderkey, o_orderdate@4 as o_orderdate, l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey, l_shipdate@2 as l_shipdate]", + " SlidingHashJoinExec: join_type=Right, on=[(l_shipdate@2, o_orderdate@1)], filter=l_orderkey@1 < o_orderkey@0 - 10 AND l_orderkey@1 > o_orderkey@0 + 10", + " ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey, l_shipdate@3 as l_shipdate]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: l_returnflag@2 = R", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_orderkey, l_suppkey, l_returnflag, l_shipdate], infinite_source=true, output_ordering=[l_orderkey@0 ASC NULLS LAST], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/orders.csv]]}, projection=[o_orderkey, o_orderdate], infinite_source=true, output_ordering=[o_orderkey@0 ASC NULLS LAST], has_header=true", + ]; + + experiment(&expected_plan, sql).await?; + Ok(()) + } + + #[tokio::test] + async fn test_unified_approach() -> Result<()> { + let sql = "SELECT + o_orderkey + FROM orders + JOIN region + ON + r_comment = o_comment + GROUP BY o_orderkey + ORDER BY o_orderkey"; + + let expected_plan = [ + "AggregateExec: mode=Single, gby=[o_orderkey@0 as o_orderkey], aggr=[], ordering_mode=FullyOrdered", + " ProjectionExec: expr=[o_orderkey@0 as o_orderkey]", + " ProjectionExec: expr=[o_orderkey@1 as o_orderkey, o_comment@2 as o_comment, r_comment@0 as r_comment]", + " CoalesceBatchesExec: target_batch_size=8192", + " HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(r_comment@0, o_comment@1)]", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/region.csv]]}, projection=[r_comment], has_header=true", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/orders.csv]]}, projection=[o_orderkey, o_comment], infinite_source=true, output_ordering=[o_orderkey@0 ASC NULLS LAST], has_header=true", + ]; + + experiment(&expected_plan, sql).await?; + Ok(()) + } + + #[tokio::test] + async fn test_unified_approach_no_order_req() -> Result<()> { + let sql = "SELECT + o_orderkey, l_suppkey + FROM orders + LEFT JOIN lineitem + ON + o_orderdate = l_shipdate + AND l_orderkey < o_orderkey - 10 + AND l_orderkey > o_orderkey + 10 + AND l_returnflag = 'R'"; + + // TODO; Which one to use? SymmetricHashJoin or SlidingHashJoin? + let expected_plan = [ + "ProjectionExec: expr=[o_orderkey@0 as o_orderkey, l_suppkey@3 as l_suppkey]", + " SlidingHashJoinExec: join_type=Left, on=[(o_orderdate@1, l_shipdate@2)], filter=l_orderkey@1 < o_orderkey@0 - 10 AND l_orderkey@1 > o_orderkey@0 + 10", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/orders.csv]]}, projection=[o_orderkey, o_orderdate], infinite_source=true, output_ordering=[o_orderkey@0 ASC NULLS LAST], has_header=true", + " ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey, l_shipdate@3 as l_shipdate]", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: l_returnflag@2 = R", + " CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_orderkey, l_suppkey, l_returnflag, l_shipdate], infinite_source=true, output_ordering=[l_orderkey@0 ASC NULLS LAST], has_header=true", + ]; + + experiment(&expected_plan, sql).await?; + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 8ba14f601c7c..b5206ae602c8 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -25,24 +25,23 @@ use std::sync::Arc; +use super::join_pipeline_selection::select_joins_to_preserve_pipeline; + use crate::config::ConfigOptions; +use crate::datasource::physical_plan::is_plan_streaming; use crate::error::Result; -use crate::physical_optimizer::pipeline_checker::PipelineStatePropagator; +use crate::physical_optimizer::join_pipeline_selection::{cost_of_the_plan, PlanState}; use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::joins::{ - CrossJoinExec, HashJoinExec, PartitionMode, StreamJoinPartitionMode, - SymmetricHashJoinExec, -}; +use crate::physical_plan::joins::utils::swap_filter; +use crate::physical_plan::joins::{CrossJoinExec, HashJoinExec, PartitionMode}; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::ExecutionPlan; -use datafusion_common::internal_err; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{DataFusionError, JoinType}; -use datafusion_physical_plan::joins::prunability::separate_columns_of_filter_expression; -use datafusion_physical_plan::joins::utils::{ - swap_join_filter, swap_join_type, swap_reverting_projection, -}; +use arrow_schema::Schema; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::{internal_err, DataFusionError, JoinType}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::PhysicalExpr; /// The [`JoinSelection`] rule tries to modify a given plan so that it can /// accommodate infinite sources and optimize joins in the plan according to @@ -120,6 +119,21 @@ fn supports_swap(join_type: JoinType) -> bool { ) } +/// This function returns the new join type we get after swapping the given +/// join's inputs. +pub(crate) fn swap_join_type(join_type: JoinType) -> JoinType { + match join_type { + JoinType::Inner => JoinType::Inner, + JoinType::Full => JoinType::Full, + JoinType::Left => JoinType::Right, + JoinType::Right => JoinType::Left, + JoinType::LeftSemi => JoinType::RightSemi, + JoinType::RightSemi => JoinType::LeftSemi, + JoinType::LeftAnti => JoinType::RightAnti, + JoinType::RightAnti => JoinType::LeftAnti, + } +} + /// This function swaps the inputs of the given join operator. fn swap_hash_join( hash_join: &HashJoinExec, @@ -128,14 +142,14 @@ fn swap_hash_join( let left = hash_join.left(); let right = hash_join.right(); let new_join = HashJoinExec::try_new( - Arc::clone(right), - Arc::clone(left), + right.clone(), + left.clone(), hash_join .on() .iter() .map(|(l, r)| (r.clone(), l.clone())) .collect(), - swap_join_filter(hash_join.filter()), + hash_join.filter().map(swap_filter), &swap_join_type(*hash_join.join_type()), partition_mode, hash_join.null_equals_null(), @@ -158,44 +172,50 @@ fn swap_hash_join( } } +/// When the order of the join is changed by the optimizer, the columns in +/// the output should not be impacted. This function creates the expressions +/// that will allow to swap back the values from the original left as the first +/// columns and those on the right next. +pub(crate) fn swap_reverting_projection( + left_schema: &Schema, + right_schema: &Schema, +) -> Vec<(Arc, String)> { + let right_cols = right_schema + .fields() + .iter() + .enumerate() + .map(|(i, f)| (Arc::new(Column::new(f.name(), i)) as _, f.name().to_owned())); + let right_len = right_cols.len(); + let left_cols = left_schema.fields().iter().enumerate().map(|(i, f)| { + ( + Arc::new(Column::new(f.name(), right_len + i)) as _, + f.name().to_owned(), + ) + }); + + left_cols.chain(right_cols).collect() +} + impl PhysicalOptimizerRule for JoinSelection { fn optimize( &self, plan: Arc, config: &ConfigOptions, ) -> Result> { - // First, we inspect all joins in the plan and make necessary modifications - // to preserve output orderings when necessary: - let plan_with_hash_joins = crate::physical_optimizer::join_pipeline_selection::PlanWithCorrespondingHashJoin::new(plan); - let state = plan_with_hash_joins - .transform_up(&|p| crate::physical_optimizer::join_pipeline_selection::select_joins_to_preserve_order_subrule(p, config))?; - let pipeline = PipelineStatePropagator::new(state.plan); - // Next, we make pipeline-fixing modifications to joins so as to accommodate - // unbounded inputs. Each pipeline-fixing subrule, which is a function - // of type `PipelineFixerSubrule`, takes a single [`PipelineStatePropagator`] - // argument storing state variables that indicate the unboundedness status - // of the current [`ExecutionPlan`] as we traverse the plan tree. - let subrules: Vec> = vec![ - Box::new(hash_join_convert_symmetric_subrule), - Box::new(hash_join_swap_subrule), - ]; - let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules, config))?; - // Next, we apply another subrule that tries to optimize joins using any - // statistics their inputs might have. - // - For a hash join with partition mode [`PartitionMode::Auto`], we will - // make a cost-based decision to select which `PartitionMode` mode - // (`Partitioned`/`CollectLeft`) is optimal. If the statistics information - // is not available, we will fall back to [`PartitionMode::Partitioned`]. - // - We optimize/swap join sides so that the left (build) side of the join - // is the small side. If the statistics information is not available, we - // do not modify join sides. - // - We will also swap left and right sides for cross joins so that the left - // side is the small side. - let config = &config.optimizer; - let collect_left_threshold = config.hash_join_single_partition_threshold; - state.plan.transform_up(&|plan| { - statistical_join_selection_subrule(plan, collect_left_threshold) - }) + let plan_global = PlanState::new(plan.clone()); + let new_state = plan_global + .transform_up(&|p| select_joins_to_preserve_pipeline(p, config))?; + let lowest_plan = if let Some(optimized_plan) = new_state + .plans + .into_iter() + .filter(|plan| is_plan_streaming(plan).is_ok()) + .min_by(|a, b| cost_of_the_plan(a).cmp(&cost_of_the_plan(b))) + { + optimized_plan + } else { + plan + }; + Ok(lowest_plan) } fn name(&self) -> &str { @@ -305,179 +325,61 @@ fn partitioned_hash_join(hash_join: &HashJoinExec) -> Result, +/// - For a hash join with partition mode [`PartitionMode::Auto`], we will +/// make a cost-based decision to select which `PartitionMode` mode +/// (`Partitioned`/`CollectLeft`) is optimal. If the statistics information +/// is not available, we will fall back to [`PartitionMode::Partitioned`]. +/// - We optimize/swap join sides so that the left (build) side of the join +/// is the small side. If the statistics information is not available, we +/// do not modify join sides. +/// - We will also swap left and right sides for cross joins so that the left +/// side is the small side. +pub(crate) fn statistical_join_selection_hash_join( + hash_join: &HashJoinExec, collect_left_threshold: usize, -) -> Result>> { - let transformed = if let Some(hash_join) = - plan.as_any().downcast_ref::() - { - match hash_join.partition_mode() { - PartitionMode::Auto => { - try_collect_left(hash_join, Some(collect_left_threshold))?.map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), - )? - } - PartitionMode::CollectLeft => try_collect_left(hash_join, None)? - .map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), - )?, - PartitionMode::Partitioned => { - let left = hash_join.left(); - let right = hash_join.right(); - if should_swap_join_order(&**left, &**right)? - && supports_swap(*hash_join.join_type()) - { - swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? - } else { - None - } +) -> Result>> { + let new_plan = match hash_join.partition_mode() { + PartitionMode::Auto => try_collect_left(hash_join, Some(collect_left_threshold))? + .map_or_else( + || partitioned_hash_join(hash_join).map(Some), + |v| Ok(Some(v)), + )?, + PartitionMode::CollectLeft => try_collect_left(hash_join, None)?.map_or_else( + || partitioned_hash_join(hash_join).map(Some), + |v| Ok(Some(v)), + )?, + PartitionMode::Partitioned => { + let left = hash_join.left(); + let right = hash_join.right(); + if should_swap_join_order(&**left, &**right) + && supports_swap(*hash_join.join_type()) + { + swap_hash_join(hash_join, PartitionMode::Partitioned).map(Some)? + } else { + None } } - } else if let Some(cross_join) = plan.as_any().downcast_ref::() { - let left = cross_join.left(); - let right = cross_join.right(); - if should_swap_join_order(&**left, &**right)? { - let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); - // TODO avoid adding ProjectionExec again and again, only adding Final Projection - let proj: Arc = Arc::new(ProjectionExec::try_new( - swap_reverting_projection(&left.schema(), &right.schema()), - Arc::new(new_join), - )?); - Some(proj) - } else { - None - } - } else { - None }; - - Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) - } else { - Transformed::No(plan) - }) + Ok(new_plan) } -/// Pipeline-fixing join selection subrule. -pub type PipelineFixerSubrule = dyn Fn( - PipelineStatePropagator, - &ConfigOptions, -) -> Option>; - -/// This subrule checks if we can replace a hash join with a symmetric hash -/// join when we are dealing with infinite inputs on both sides. This change -/// avoids pipeline breaking and preserves query runnability. If possible, -/// this subrule makes this replacement; otherwise, it has no effect. -fn hash_join_convert_symmetric_subrule( - mut input: PipelineStatePropagator, - config_options: &ConfigOptions, -) -> Option> { - if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = input.children_unbounded(); - let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); - input.unbounded = left_unbounded || right_unbounded; - let result = if left_unbounded && right_unbounded { - let mode = if config_options.optimizer.repartition_joins { - StreamJoinPartitionMode::Partitioned - } else { - StreamJoinPartitionMode::SinglePartition - }; - SymmetricHashJoinExec::try_new( - hash_join.left().clone(), - hash_join.right().clone(), - hash_join.on().to_vec(), - hash_join - .filter() - .map(|filter| separate_columns_of_filter_expression(filter.clone())), - hash_join.join_type(), - hash_join.null_equals_null(), - mode, - ) - .map(|exec| { - input.plan = Arc::new(exec) as _; - input - }) - } else { - Ok(input) - }; - Some(result) - } else { - None - } -} - -/// This subrule will swap build/probe sides of a hash join depending on whether -/// one of its inputs may produce an infinite stream of records. The rule ensures -/// that the left (build) side of the hash join always operates on an input stream -/// that will produce a finite set of records. If the left side can not be chosen -/// to be "finite", the join sides stay the same as the original query. -/// ```text -/// For example, this rule makes the following transformation: -/// -/// -/// -/// +--------------+ +--------------+ -/// | | unbounded | | -/// Left | Infinite | true | Hash |\true -/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ -/// | | | | \ | | | | -/// +--------------+ +--------------+ - | Hash Join |-------| Projection | -/// - | | | | -/// +--------------+ +--------------+ / +--------------+ +--------------+ -/// | | unbounded | | / -/// Right | Finite | false | Hash |/false -/// | Data Source |--------------| Repartition | -/// | | | | -/// +--------------+ +--------------+ -/// -/// -/// -/// +--------------+ +--------------+ -/// | | unbounded | | -/// Left | Finite | false | Hash |\false -/// | Data source |--------------| Repartition | \ +--------------+ +--------------+ -/// | | | | \ | | true | | true -/// +--------------+ +--------------+ - | Hash Join |-------| Projection |----- -/// - | | | | -/// +--------------+ +--------------+ / +--------------+ +--------------+ -/// | | unbounded | | / -/// Right | Infinite | true | Hash |/true -/// | Data Source |--------------| Repartition | -/// | | | | -/// +--------------+ +--------------+ -/// -/// ``` -fn hash_join_swap_subrule( - mut input: PipelineStatePropagator, - _config_options: &ConfigOptions, -) -> Option> { - if let Some(hash_join) = input.plan.as_any().downcast_ref::() { - let ub_flags = input.children_unbounded(); - let (left_unbounded, right_unbounded) = (ub_flags[0], ub_flags[1]); - input.unbounded = left_unbounded || right_unbounded; - let result = if left_unbounded - && !right_unbounded - && matches!( - *hash_join.join_type(), - JoinType::Inner - | JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - ) { - swap_join_according_to_unboundedness(hash_join).map(|plan| { - input.plan = plan; - input - }) - } else { - Ok(input) - }; - Some(result) +pub(crate) fn statistical_join_selection_cross_join( + cross_join: &CrossJoinExec, +) -> Result>> { + let left = cross_join.left(); + let right = cross_join.right(); + let new_plan = if should_swap_join_order(&**left, &**right) { + let new_join = CrossJoinExec::new(Arc::clone(right), Arc::clone(left)); + // TODO avoid adding ProjectionExec again and again, only adding Final Projection + let proj: Arc = Arc::new(ProjectionExec::try_new( + swap_reverting_projection(&left.schema(), &right.schema()), + Arc::new(new_join), + )?); + Some(proj) } else { None - } + }; + Ok(new_plan) } /// This function swaps sides of a hash join to make it runnable even if one of @@ -485,7 +387,7 @@ fn hash_join_swap_subrule( /// [`JoinType::Full`], [`JoinType::Right`], [`JoinType::RightAnti`] and /// [`JoinType::RightSemi`] can not run with an unbounded left side, even if /// we swap join sides. Therefore, we do not consider them here. -fn swap_join_according_to_unboundedness( +pub(crate) fn swap_join_according_to_unboundedness( hash_join: &HashJoinExec, ) -> Result> { let partition_mode = hash_join.partition_mode(); @@ -507,32 +409,6 @@ fn swap_join_according_to_unboundedness( } } -/// Apply given `PipelineFixerSubrule`s to a given plan. This plan, along with -/// auxiliary boundedness information, is in the `PipelineStatePropagator` object. -fn apply_subrules( - mut input: PipelineStatePropagator, - subrules: &Vec>, - config_options: &ConfigOptions, -) -> Result> { - for subrule in subrules { - if let Some(value) = subrule(input.clone(), config_options).transpose()? { - input = value; - } - } - let is_unbounded = input - .plan - .unbounded_output(&input.children_unbounded()) - // Treat the case where an operator can not run on unbounded data as - // if it can and it outputs unbounded data. Do not raise an error yet. - // Such operators may be fixed, adjusted or replaced later on during - // optimization passes -- sorts may be removed, windows may be adjusted - // etc. If this doesn't happen, the final `PipelineChecker` rule will - // catch this and raise an error anyway. - .unwrap_or(true); - input.unbounded = is_unbounded; - Ok(Transformed::Yes(input)) -} - #[cfg(test)] mod tests_statistical { use std::sync::Arc; @@ -1187,6 +1063,7 @@ mod util_tests { #[cfg(test)] mod hash_join_tests { use super::*; + use crate::physical_optimizer::join_selection::swap_join_type; use crate::physical_optimizer::test_utils::SourceType; use crate::physical_plan::expressions::Column; use crate::physical_plan::joins::PartitionMode; @@ -1564,28 +1441,9 @@ mod hash_join_tests { false, )?; - let children = vec![ - PipelineStatePropagator { - plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), - unbounded: left_unbounded, - children: vec![], - }, - PipelineStatePropagator { - plan: Arc::new(EmptyExec::new(Arc::new(Schema::empty()))), - unbounded: right_unbounded, - children: vec![], - }, - ]; - let initial_hash_join_state = PipelineStatePropagator { - plan: Arc::new(join), - unbounded: false, - children, - }; - - let optimized_hash_join = - hash_join_swap_subrule(initial_hash_join_state, &ConfigOptions::new()) - .unwrap()?; - let optimized_join_plan = optimized_hash_join.plan; + let optimized_join_plan = JoinSelection::new() + .optimize(Arc::new(join), &ConfigOptions::new()) + .unwrap(); // If swap did happen let projection_added = optimized_join_plan.as_any().is::(); diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index f2e53905821a..c9c54d81c239 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -397,6 +397,25 @@ pub fn prunable_filter(left_index: ColumnIndex, right_index: ColumnIndex) -> Joi ); JoinFilter::new(filter_expr, column_indices, intermediate_schema) } +pub fn partial_prunable_filter( + left_index: ColumnIndex, + right_index: ColumnIndex, +) -> JoinFilter { + // Filter columns, ensure first batches will have matching rows. + let intermediate_schema = Schema::new(vec![ + Field::new("0", DataType::Int32, true), + Field::new("1", DataType::Int32, true), + ]); + let column_indices = vec![left_index, right_index]; + let filter_expr = Arc::new(BinaryExpr::new( + col("0", &intermediate_schema).unwrap(), + Operator::Gt, + col("1", &intermediate_schema).unwrap(), + )); + + JoinFilter::new(filter_expr, column_indices, intermediate_schema) +} + pub fn not_prunable_filter( left_index: ColumnIndex, right_index: ColumnIndex, diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 255c5053f597..db4efece9287 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -21,8 +21,10 @@ use std::fmt; use std::fmt::Formatter; use std::sync::Arc; +use crate::error::Result; +use crate::physical_plan::aggregates::AggregateExec; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; -use crate::physical_plan::joins::{HashJoinExec, NestedLoopJoinExec}; +use crate::physical_plan::joins::{CrossJoinExec, HashJoinExec, NestedLoopJoinExec}; use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; @@ -164,3 +166,13 @@ pub fn is_hash_join(plan: &Arc) -> bool { pub fn is_nested_loop_join(plan: &Arc) -> bool { plan.as_any().is::() } + +/// Checks whether the given operator is a [`CrossJoinExec`]. +pub fn is_cross_join(plan: &Arc) -> bool { + plan.as_any().is::() +} + +/// Checks whether the given operator is an [`AggregateExec`]. +pub fn is_aggregate(plan: &Arc) -> bool { + plan.as_any().is::() +} diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 528bde632355..9daf544e0337 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -125,13 +125,13 @@ async fn join_change_in_planner() -> Result<()> { let formatted = displayable(physical_plan.as_ref()).indent(true).to_string(); let expected = { [ - "SymmetricHashJoinExec: mode=Partitioned, join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", + "SlidingHashJoinExec: join_type=Full, on=[(a2@1, a2@1)], filter=CAST(a1@0 AS Int64) > CAST(a1@1 AS Int64) + 3 AND CAST(a1@0 AS Int64) < CAST(a1@1 AS Int64) + 10", " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", // " CsvExec: file_groups={1 group: [[tempdir/left.csv]]}, projection=[a1, a2], has_header=false", " CoalesceBatchesExec: target_batch_size=8192", - " RepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", + " SortPreservingRepartitionExec: partitioning=Hash([a2@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", // " CsvExec: file_groups={1 group: [[tempdir/right.csv]]}, projection=[a1, a2], has_header=false" ] diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index a7697cb2d1c5..007f9ea21913 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -20,6 +20,7 @@ pub use cross_join::CrossJoinExec; pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; +pub use partitioned_hash_join::PartitionedHashJoinExec; pub use sliding_hash_join::{swap_sliding_hash_join, SlidingHashJoinExec}; pub use sliding_nested_loop_join::{ swap_sliding_nested_loop_join, SlidingNestedLoopJoinExec, @@ -33,6 +34,7 @@ pub mod utils; mod cross_join; mod hash_join; mod nested_loop_join; +mod partitioned_hash_join; mod sliding_hash_join; mod sliding_nested_loop_join; mod sliding_window_join_utils; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 5b7cbad1664b..766abaada9b9 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -191,6 +191,22 @@ impl ExecutionPlan for NestedLoopJoinExec { distribution_from_join_type(&self.join_type) } + fn unbounded_output(&self, children: &[bool]) -> Result { + let (left, right) = (children[0], children[1]); + if left || right { + plan_err!( + "Join Error: The join with cannot be executed with unbounded inputs. {}", + if left && right { + "Currently, we do not support unbounded inputs on both sides." + } else { + "Please consider a different type of join." + } + ) + } else { + Ok(false) + } + } + fn equivalence_properties(&self) -> EquivalenceProperties { join_equivalence_properties( self.left.equivalence_properties(), @@ -744,21 +760,15 @@ mod tests { use std::sync::Arc; use super::*; - use crate::{ - common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, - test::build_table_i32, - }; + use crate::joins::test_utils::partitioned_nested_join_with_filter; + use crate::joins::utils::JoinSide; + use crate::{expressions::Column, memory::MemoryExec, repartition::RepartitionExec, test::build_table_i32}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; - - use crate::physical_plan::joins::test_utils::partitioned_nested_join_with_filter; - use crate::joins::utils::JoinSide; - use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; - use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::PhysicalExpr; fn build_table( @@ -831,62 +841,13 @@ mod tests { JoinFilter::new(filter_expression, column_indices, intermediate_schema) } - async fn multi_partitioned_join_collect( - left: Arc, - right: Arc, - join_type: &JoinType, - join_filter: Option, - context: Arc, - ) -> Result<(Vec, Vec)> { - let partition_count = 4; - let mut output_partition = 1; - let distribution = distribution_from_join_type(join_type); - // left - let left = if matches!(distribution[0], Distribution::SinglePartition) { - left - } else { - output_partition = partition_count; - Arc::new(RepartitionExec::try_new( - left, - Partitioning::RoundRobinBatch(partition_count), - )?) - } as Arc; - - let right = if matches!(distribution[1], Distribution::SinglePartition) { - right - } else { - output_partition = partition_count; - Arc::new(RepartitionExec::try_new( - right, - Partitioning::RoundRobinBatch(partition_count), - )?) - } as Arc; - - // Use the required distribution for nested loop join to test partition data - let nested_loop_join = - NestedLoopJoinExec::try_new(left, right, join_filter, join_type)?; - let columns = columns(&nested_loop_join.schema()); - let mut batches = vec![]; - for i in 0..output_partition { - let stream = nested_loop_join.execute(i, context.clone())?; - let more_batches = common::collect(stream).await?; - batches.extend( - more_batches - .into_iter() - .filter(|b| b.num_rows() > 0) - .collect::>(), - ); - } - Ok((columns, batches)) - } - #[tokio::test] async fn join_inner_with_filter() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let left = build_left_table(); let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches) = partitioned_nested_join_with_filter( left, right, &JoinType::Inner, @@ -915,7 +876,7 @@ mod tests { let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches) = partitioned_nested_join_with_filter( left, right, &JoinType::Left, @@ -946,7 +907,7 @@ mod tests { let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches) = partitioned_nested_join_with_filter( left, right, &JoinType::Right, @@ -977,7 +938,7 @@ mod tests { let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches) = partitioned_nested_join_with_filter( left, right, &JoinType::Full, @@ -1010,7 +971,7 @@ mod tests { let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches) = partitioned_nested_join_with_filter( left, right, &JoinType::LeftSemi, @@ -1039,7 +1000,7 @@ mod tests { let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches) = partitioned_nested_join_with_filter( left, right, &JoinType::LeftAnti, @@ -1069,7 +1030,7 @@ mod tests { let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches) = partitioned_nested_join_with_filter( left, right, &JoinType::RightSemi, @@ -1098,7 +1059,7 @@ mod tests { let right = build_right_table(); let filter = prepare_join_filter(); - let (columns, batches) = multi_partitioned_join_collect( + let (columns, batches) = partitioned_nested_join_with_filter( left, right, &JoinType::RightAnti, @@ -1152,7 +1113,7 @@ mod tests { let task_ctx = TaskContext::default().with_runtime(runtime); let task_ctx = Arc::new(task_ctx); - let err = multi_partitioned_join_collect( + let err = partitioned_nested_join_with_filter( left.clone(), right.clone(), &join_type, @@ -1171,9 +1132,4 @@ mod tests { Ok(()) } - - /// Returns the column names on the schema - fn columns(schema: &Schema) -> Vec { - schema.fields().iter().map(|f| f.name().clone()).collect() - } } diff --git a/datafusion/physical-plan/src/joins/partitioned_hash_join.rs b/datafusion/physical-plan/src/joins/partitioned_hash_join.rs new file mode 100644 index 000000000000..c199883e0c9a --- /dev/null +++ b/datafusion/physical-plan/src/joins/partitioned_hash_join.rs @@ -0,0 +1,2083 @@ +// Copyright (C) Synnada, Inc. - All Rights Reserved. +// This file does not contain any Apache Software Foundation copyrighted code. + +use std::any::Any; +use std::fmt::Formatter; +use std::sync::Arc; +use std::task::Poll; +use std::{fmt, mem}; + +use crate::common::SharedMemoryReservation; +use crate::joins::hash_join::equal_rows_arr; +use crate::joins::hash_join_utils::SortedFilterExpr; +use crate::joins::sliding_window_join_utils::{ + calculate_the_necessary_build_side_range, check_if_sliding_window_condition_is_met, + get_probe_batch, is_batch_suitable_interval_calculation, +}; +use crate::joins::utils::{ + apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, + calculate_join_output_ordering, check_join_is_valid, + combine_join_equivalence_properties, combine_join_ordering_equivalence_properties, + get_filter_representation_of_build_side, partitioned_join_output_partitioning, + prepare_sorted_exprs, ColumnIndex, JoinFilter, JoinOn, JoinSide, +}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::{ + metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, +}; + +use arrow::compute::concat_batches; +use arrow_array::builder::{UInt32Builder, UInt64Builder}; +use arrow_array::{ArrayRef, RecordBatch, UInt32Array, UInt64Array}; +use arrow_schema::{Field, Schema, SchemaRef}; +use datafusion_common::utils::{ + get_record_batch_at_indices, get_row_at_idx, linear_search, +}; +use datafusion_common::{ + internal_err, DataFusionError, JoinType, Result, ScalarValue, Statistics, +}; +use datafusion_execution::memory_pool::MemoryConsumer; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::hash_utils::create_hashes; +use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval}; +use datafusion_physical_expr::window::PartitionKey; +use datafusion_physical_expr::{ + EquivalenceProperties, OrderingEquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + PhysicalSortRequirement, +}; + +use ahash::RandomState; +use futures::{ready, Stream, StreamExt}; +use hashbrown::raw::RawTable; +use parking_lot::Mutex; + +/// This `enum` encapsulates the different states that a join stream might be +/// in throughout its execution. Depending on its current state, the join +/// operation will perform different actions such as pulling data from the build +/// side or the probe side, or performing the join itself. +pub enum PartitionedHashJoinStreamState { + /// The action is to pull data from the probe side (right stream). + /// This state continues to pull data until the probe batches are suitable + /// for interval calculations, or the probe stream is exhausted. + PullProbe, + /// The action is to pull data from the build side (left stream) within a + /// given interval. + /// This state continues to pull data until a suitable range of batches is + /// found, or the build stream is exhausted. + PullBuild { + interval: Vec<(PhysicalSortExpr, Interval)>, + }, + /// The probe side is completely processed. In this state, the build side + /// will be ready and its results will be processed until the build stream + /// is also exhausted. + ProbeExhausted, + /// The join operation is actively processing data from both sides to produce + /// the result. It also contains build side intervals to correctly prune the partitioned + /// buffers. + Join { + interval: Vec<(PhysicalSortExpr, Interval)>, + }, +} + +/// Represents a partitioned hash join execution plan. +/// +/// The `PartitionedHashJoinExec` struct facilitates the execution of hash join operations in +/// parallel across multiple partitions of data. It takes two input streams (`left` and `right`), +/// a set of common columns to join on (`on`), and applies a join filter to find matching rows. +/// The type of the join (e.g., inner, left, right) is determined by `join_type`. +/// +/// A hash join operation builds a hash table on the "build" side (the left side in this implementation) +/// using a `BuildBuffer` to segment and hash rows. The hash table is then probed with rows from the "probe" side +/// (the right side) to find matches based on the common columns and join filter. +/// +/// The resulting schema after the join is represented by `schema`. +/// +/// The struct also maintains several other properties and metrics for efficient execution and monitoring +/// of the join operation. +#[derive(Debug)] +pub struct PartitionedHashJoinExec { + /// Left side stream + pub(crate) left: Arc, + /// Right side stream + pub(crate) right: Arc, + /// Set of common columns used to join on + pub(crate) on: Vec<(Column, Column)>, + /// Filters applied when finding matching rows + pub(crate) filter: JoinFilter, + /// How the join is performed + pub(crate) join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, + /// Shares the `RandomState` for the hashing algorithm + random_state: RandomState, + /// Information of index and left / right placement of columns + column_indices: Vec, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// If null_equals_null is true, null == null else null != null + pub(crate) null_equals_null: bool, + /// The left SortExpr + left_sort_exprs: Vec, + /// The right SortExpr + right_sort_exprs: Vec, + /// The output ordering + output_ordering: Option>, + /// Fetch per key + fetch_per_key: usize, +} + +/// This object encapsulates metrics pertaining to a single input (i.e. side) +/// of the operator `PartitionedHashJoinExec`. +#[derive(Debug)] +struct PartitionedHashJoinSideMetrics { + /// Number of batches consumed by this operator + input_batches: metrics::Count, + /// Number of rows consumed by this operator + input_rows: metrics::Count, +} + +/// Metrics for operator `PartitionedHashJoinExec`. +#[derive(Debug)] +struct PartitionedHashJoinMetrics { + /// Number of build batches/rows consumed by this operator + build: PartitionedHashJoinSideMetrics, + /// Number of probe batches/rows consumed by this operator + probe: PartitionedHashJoinSideMetrics, + /// Memory used by sides in bytes + pub(crate) stream_memory_usage: metrics::Gauge, + /// Number of batches produced by this operator + output_batches: metrics::Count, + /// Number of rows produced by this operator + output_rows: metrics::Count, +} + +impl PartitionedHashJoinMetrics { + // Creates a new `PartitionedHashJoinMetrics` object according to the + // given number of partitions and the metrics set. + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let build = PartitionedHashJoinSideMetrics { + input_batches, + input_rows, + }; + + let input_batches = + MetricBuilder::new(metrics).counter("input_batches", partition); + let input_rows = MetricBuilder::new(metrics).counter("input_rows", partition); + let probe = PartitionedHashJoinSideMetrics { + input_batches, + input_rows, + }; + + let stream_memory_usage = + MetricBuilder::new(metrics).gauge("stream_memory_usage", partition); + + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + + let output_rows = MetricBuilder::new(metrics).output_rows(partition); + + Self { + build, + probe, + output_batches, + stream_memory_usage, + output_rows, + } + } +} + +/// State for each unique partition determined according to key column(s) +#[derive(Debug)] +pub struct PartitionBatchState { + /// The record_batch belonging to current partition + pub record_batch: RecordBatch, + /// Matched indices count + pub matched_indices: usize, +} + +impl PartitionedHashJoinExec { + /// Attempts to create a new `PartitionedHashJoinExec` instance. + /// + /// * `left`: Left side stream. + /// * `right`: Right side stream. + /// * `on`: Set of common columns used to join on. + /// * `filter`: Filters applied when finding matching rows. + /// * `join_type`: How the join is performed. + /// * `null_equals_null`: If true, null == null; otherwise, null != null. + /// * `left_sort_exprs`: The left SortExpr. + /// * `right_sort_exprs`: The right SortExpr. + /// * `fetch_per_key`: Fetch per key. + /// + /// Returns a result containing the created `PartitionedHashJoinExec` instance or an error. + #[allow(clippy::too_many_arguments)] + pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: &JoinType, + null_equals_null: bool, + left_sort_exprs: Vec, + right_sort_exprs: Vec, + fetch_per_key: usize, + ) -> Result { + let left_fields: Result> = left + .schema() + .fields() + .iter() + .map(|field| { + let mut metadata = field.metadata().clone(); + let mut new_field = Field::new( + field.name(), + field.data_type().clone(), + field.is_nullable(), + ); + metadata + .insert("PartitionedHashJoinExec".into(), "JoinSide::Left".into()); + new_field.set_metadata(metadata); + Ok(new_field) + }) + .collect(); + let left_schema = Arc::new(Schema::new_with_metadata( + left_fields?, + left.schema().metadata().clone(), + )); + + let right_fields: Result> = right + .schema() + .fields() + .iter() + .map(|field| { + let mut metadata = field.metadata().clone(); + let mut new_field = Field::new( + field.name(), + field.data_type().clone(), + field.is_nullable(), + ); + metadata + .insert("PartitionedHashJoinExec".into(), "JoinSide::Right".into()); + new_field.set_metadata(metadata); + Ok(new_field) + }) + .collect(); + let right_schema = Arc::new(Schema::new_with_metadata( + right_fields?, + right.schema().metadata().clone(), + )); + + if on.is_empty() { + return Err(DataFusionError::Plan( + "On constraints in PartitionedHashJoinExec should be non-empty" + .to_string(), + )); + } + + if matches!( + join_type, + JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::Full + | JoinType::Left + | JoinType::RightSemi + | JoinType::RightAnti + ) { + return Err(DataFusionError::NotImplemented(format!( + "PartitionedHashJoinExec does not support {}", + join_type + ))); + } + + check_join_is_valid(&left_schema, &right_schema, &on)?; + + // Initialize the random state for the join operation: + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + let (schema, column_indices) = + build_join_schema(&left_schema, &right_schema, join_type); + + let output_ordering = calculate_join_output_ordering( + &left_sort_exprs, + &right_sort_exprs, + *join_type, + &on, + left_schema.fields.len(), + &Self::maintains_input_order(*join_type), + Some(JoinSide::Right), + )?; + + Ok(Self { + left, + right, + on, + filter, + join_type: *join_type, + schema: Arc::new(schema), + random_state, + column_indices, + metrics: ExecutionPlanMetricsSet::new(), + null_equals_null, + left_sort_exprs, + right_sort_exprs, + output_ordering, + fetch_per_key, + }) + } + + /// Get probe side information for the hash join. + pub fn probe_side() -> JoinSide { + // In current implementation right side is always probe side. + JoinSide::Right + } + + /// Calculate order preservation flags for this join. + fn maintains_input_order(join_type: JoinType) -> Vec { + vec![ + false, + matches!(join_type, JoinType::Inner | JoinType::Right), + ] + } + + /// left (build) side which gets hashed + pub fn left(&self) -> &Arc { + &self.left + } + + /// right (probe) side which are filtered by the hash table + pub fn right(&self) -> &Arc { + &self.right + } + + /// Set of common columns used to join on + pub fn on(&self) -> &[(Column, Column)] { + &self.on + } + + /// Filters applied before join output + pub fn filter(&self) -> &JoinFilter { + &self.filter + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } + + /// Get null_equals_null + pub fn null_equals_null(&self) -> bool { + self.null_equals_null + } + + /// Get left_sort_exprs + pub fn left_sort_exprs(&self) -> &Vec { + &self.left_sort_exprs + } + + /// Get right_sort_exprs + pub fn right_sort_exprs(&self) -> &Vec { + &self.right_sort_exprs + } +} + +fn dyn_eq_with_null_support( + lhs: &ScalarValue, + rhs: &ScalarValue, + null_equals_null: bool, +) -> bool { + match (lhs.is_null(), rhs.is_null()) { + (false, false) => lhs.eq(rhs), + (true, true) => null_equals_null, + _ => false, + } +} +/// Represents a buffer used in the "build" phase of a partitioned hash join operation. +/// +/// During the execution of a hash join, the `BuildBuffer` is responsible for segmenting and hashing rows from +/// the "build" side (the left side in the context of a partitioned hash join). It uses hash maps +/// to store unique partitions of the data based on common columns (used for joining), which facilitates +/// efficient lookups during the "probe" phase. +/// +/// The buffer maintains two primary hash maps: +/// - `row_map_batch` maps a hash value of a row to a unique partition ID. This map is used to quickly find +/// which partition a row belongs to based on its hash value. +/// - `join_hash_map` stores the actual data of the partitions, where each entry contains a key, its hash value, +/// and a batch of data corresponding to that key. +/// +/// The `BuildBuffer` also includes several utility methods to evaluate and prune partitions based on various +/// criteria, such as filters and join conditions. +struct BuildBuffer { + /// We use this [`RawTable`] to calculate unique partitions for each new + /// RecordBatch. First entry in the tuple is the hash value, the second + /// entry is the unique ID for each partition (increments from 0 to n). + row_map_batch: RawTable<(u64, usize)>, + /// We use this [`RawTable`] to hold partitions for each key. + join_hash_map: RawTable<(PartitionKey, u64, PartitionBatchState)>, + /// Used for interval calculations + latest_batch: RecordBatch, + /// Set of common columns used to join on + pub(crate) on: Vec, +} + +impl BuildBuffer { + pub fn new(schema: SchemaRef, on: Vec) -> Self { + Self { + latest_batch: RecordBatch::new_empty(schema), + on, + row_map_batch: RawTable::with_capacity(0), + join_hash_map: RawTable::with_capacity(0), + } + } + + pub fn size(&self) -> usize { + let mut size = 0; + size += mem::size_of_val(self); + size += self.row_map_batch.allocation_info().1.size(); + size += self.join_hash_map.allocation_info().1.size(); + size += mem::size_of_val(&self.on); + size += self.latest_batch.get_array_memory_size(); + size + } + + /// Determines per-partition indices based on the given columns and record batch. + /// + /// This function first computes hash values for each row in the batch based on the given columns. + /// It then maps these hash values to partition keys and groups row indices by partition. + /// This helps in grouping rows that belong to the same partition. + /// + /// # Arguments + /// * `random_state`: State to maintain reproducible randomization for hashing. + /// * `columns`: Arrays representing the columns that define partitions. + /// * `batch`: Record batch containing the rows to be partitioned. + /// * `null_equals_null`: Determines whether null values should be treated as equal. + /// + /// # Returns + /// * A vector containing tuples with partition keys, associated hash values, + /// and row indices for rows in each partition. + fn get_per_partition_indices( + &mut self, + random_state: &RandomState, + columns: &[ArrayRef], + batch: &RecordBatch, + null_equals_null: bool, + ) -> Result)>> { + let mut batch_hashes = vec![0; batch.num_rows()]; + create_hashes(columns, random_state, &mut batch_hashes)?; + // reset row_map for new calculation + self.row_map_batch.clear(); + // res stores PartitionKey and row indices (indices where these partition occurs in the `batch`) for each partition. + let mut result: Vec<(PartitionKey, u64, Vec)> = vec![]; + for (hash, row_idx) in batch_hashes.into_iter().zip(0u32..) { + let entry = self.row_map_batch.get_mut(hash, |(_, group_idx)| { + // We can safely get the first index of the partition indices + // since partition indices has one element during initialization. + let row = get_row_at_idx(columns, row_idx as usize).unwrap(); + // Handle hash collusions with an equality check: + row.eq(&result[*group_idx].0) + }); + if let Some((_, group_idx)) = entry { + result[*group_idx].2.push(row_idx) + } else { + self.row_map_batch + .insert(hash, (hash, result.len()), |(hash, _)| *hash); + let row = get_row_at_idx(columns, row_idx as usize)?; + // If null_equals_null is true, we do not stop adding the rows. + // If null_equals_null is false, we ensure that row does not contains a null value + // since it is not joinable to anything. + if null_equals_null || row.iter().all(|s| !s.is_null()) { + // This is a new partition its only index is row_idx for now. + result.push((row, hash, vec![row_idx])); + } + } + } + Ok(result) + } + + /// Evaluates partitions within a build batch. + /// + /// This function calculates the partitioned indices for the build batch rows and + /// constructs new record batches using these indices. These new record batches represent + /// partitioned subsets of the original build batch. + /// + /// # Arguments + /// * `build_batch`: The probe record batch to be partitioned. + /// * `random_state`: State to maintain reproducible randomization for hashing. + /// * `null_equals_null`: Determines whether null values should be treated as equal. + /// + /// # Returns + /// * A vector containing tuples with partition keys, associated hash values, + /// and the partitioned record batches. + fn evaluate_partition_batches( + &mut self, + build_batch: &RecordBatch, + random_state: &RandomState, + null_equals_null: bool, + ) -> Result> { + let columns = self + .on + .iter() + .map(|c| Ok(c.evaluate(build_batch)?.into_array(build_batch.num_rows()))) + .collect::>>()?; + // Calculate indices for each partition and construct a new record + // batch from the rows at these indices for each partition: + self.get_per_partition_indices( + random_state, + &columns, + build_batch, + null_equals_null, + )? + .into_iter() + .map(|(row, hash, indices)| { + let mut new_indices = UInt32Builder::with_capacity(indices.len()); + new_indices.append_slice(&indices); + let indices = new_indices.finish(); + Ok(( + row, + hash, + get_record_batch_at_indices(build_batch, &indices)?, + )) + }) + .collect() + } + + /// Updates the latest batch and associated partition buffers with a new build record batch. + /// + /// If the new record batch contains rows, it evaluates the partition batches for + /// these rows and updates the `join_hash_map` with the resulting partitioned record batches. + /// + /// # Arguments + /// * `record_batch`: New record batch to update the current state. + /// * `random_state`: State to maintain reproducible randomization for hashing. + /// * `null_equals_null`: Determines whether null values should be treated as equal. + /// + /// # Returns + /// * A `Result` indicating the success or failure of the update operation. + fn update_partition_batch( + &mut self, + build_batch: &RecordBatch, + random_state: &RandomState, + null_equals_null: bool, + ) -> Result<()> { + if build_batch.num_rows() > 0 { + let partition_batches = self.evaluate_partition_batches( + build_batch, + random_state, + null_equals_null, + )?; + for (partition_row, partition_hash, partition_batch) in partition_batches { + let item = self + .join_hash_map + .get_mut(partition_hash, |(_, hash, _)| *hash == partition_hash); + if let Some((_, _, partition_batch_state)) = item { + partition_batch_state.record_batch = concat_batches( + &partition_batch.schema(), + [&partition_batch_state.record_batch, &partition_batch], + )?; + } else { + self.join_hash_map.insert( + partition_hash, + // store the value + 1 as 0 value reserved for end of list + ( + partition_row, + partition_hash, + PartitionBatchState { + record_batch: partition_batch, + matched_indices: 0, + }, + ), + |(_, hash, _)| *hash, + ); + } + } + self.latest_batch = build_batch.clone(); + } + Ok(()) + } + + /// Prunes the record batches within the join hash map based on the specified filter and build expressions. + /// + /// This function leverages a pruning strategy, which aims to reduce the number of rows processed by the join + /// by filtering out rows that are determined to be irrelevant based on the given `JoinFilter` and + /// `build_shrunk_exprs`. + /// + /// ```plaintext + /// + /// Partition + /// Batch + /// +----------+ Probe Batch + /// | | + /// | | +---------+ + /// | Prunable | | | + /// | Area | | | + /// | | | | + /// | | ----+| | + /// | | | | | + /// | | | +---------+ + /// |--------- |----+ + /// | | + /// | | + /// | | + /// +----------+ + /// + /// ``` + /// We make sure pruning is made from the safe area. + /// + /// + /// # Arguments + /// * `filter` - The join filter which helps determine the rows to prune. + /// * `build_shrunk_exprs` - A vector of expressions paired with their respective intervals, + /// which are used to evaluate the filter on the build side. + /// * `fetch_size` - The number of rows to fetch from the join hash map. + /// + /// # Returns + /// * A `Result` indicating the success or failure of the pruning operation. + fn prune( + &mut self, + filter: &JoinFilter, + build_shrunk_exprs: Vec<(PhysicalSortExpr, Interval)>, + fetch_size: usize, + ) -> Result<()> { + unsafe { + self.join_hash_map + .iter() + .map(|bucket| bucket.as_mut()) + .try_for_each(|(_, _, partition_state)| { + let matched_indices_len = partition_state.matched_indices; + let buffer_len = partition_state.record_batch.num_rows(); + let prune_length = if matched_indices_len > fetch_size { + // matched_indices is reset since if the corresponding key is not come from the probe side, + // we will still be able to prune it by interval calculations. + partition_state.matched_indices = 0; + matched_indices_len - fetch_size + } else { + let intermediate_batch = get_filter_representation_of_build_side( + filter.schema(), + &partition_state.record_batch, + filter.column_indices(), + JoinSide::Left, + )?; + let prune_lengths = build_shrunk_exprs + .iter() + .map(|(sort_expr, interval)| { + let options = sort_expr.options; + + // Get the lower or upper interval based on the sort direction: + let target = if options.descending { + &interval.lower.value + } else { + &interval.upper.value + } + .clone(); + + // Evaluate the build side filter expression and convert it into an array: + let batch_arr = sort_expr + .expr + .evaluate(&intermediate_batch)? + .into_array(intermediate_batch.num_rows()); + + // Perform binary search on the array to determine the length of + // the record batch to prune: + linear_search::(&[batch_arr], &[target], &[options]) + }) + .collect::>>()?; + let upper_slice_index = + prune_lengths.into_iter().min().unwrap_or(0); + + if upper_slice_index > fetch_size { + upper_slice_index - fetch_size + } else { + 0 + } + }; + partition_state.record_batch = partition_state + .record_batch + .slice(prune_length, buffer_len - prune_length); + Ok(()) + }) + } + } +} + +impl DisplayAs for PartitionedHashJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let display_filter = format!(", filter={}", self.filter.expression()); + let on = self + .on + .iter() + .map(|(c1, c2)| format!("({}, {})", c1, c2)) + .collect::>() + .join(", "); + write!( + f, + "PartitionedHashJoinExec: join_type={:?}, on=[{}]{}", + self.join_type, on, display_filter + ) + } + } + } +} + +impl ExecutionPlan for PartitionedHashJoinExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn output_partitioning(&self) -> Partitioning { + let left_columns_len = self.left.schema().fields.len(); + partitioned_join_output_partitioning( + self.join_type, + self.left.output_partitioning(), + self.right.output_partitioning(), + left_columns_len, + ) + } + + fn unbounded_output(&self, children: &[bool]) -> Result { + Ok(children.iter().any(|u| *u)) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + self.output_ordering.as_deref() + } + + fn required_input_distribution(&self) -> Vec { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } + + fn required_input_ordering(&self) -> Vec>> { + vec![ + Some(PhysicalSortRequirement::from_sort_exprs( + &self.left_sort_exprs, + )), + Some(PhysicalSortRequirement::from_sort_exprs( + &self.right_sort_exprs, + )), + ] + } + + fn maintains_input_order(&self) -> Vec { + Self::maintains_input_order(self.join_type) + } + + fn benefits_from_input_partitioning(&self) -> Vec { + vec![false; 2] + } + + fn equivalence_properties(&self) -> EquivalenceProperties { + let left_columns_len = self.left.schema().fields.len(); + combine_join_equivalence_properties( + self.join_type, + self.left.equivalence_properties(), + self.right.equivalence_properties(), + left_columns_len, + &self.on, + self.schema(), + ) + } + + fn ordering_equivalence_properties(&self) -> OrderingEquivalenceProperties { + combine_join_ordering_equivalence_properties( + &self.join_type, + &self.left, + &self.right, + self.schema(), + &self.maintains_input_order(), + Some(Self::probe_side()), + self.equivalence_properties(), + ) + .unwrap() + } + + fn children(&self) -> Vec> { + vec![self.left.clone(), self.right.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match &children[..] { + [left, right] => Ok(Arc::new(PartitionedHashJoinExec::try_new( + left.clone(), + right.clone(), + self.on.clone(), + self.filter.clone(), + &self.join_type, + self.null_equals_null, + self.left_sort_exprs.clone(), + self.right_sort_exprs.clone(), + self.fetch_per_key, + )?)), + _ => Err(DataFusionError::Internal( + "PartitionedHashJoinExec wrong number of children".to_string(), + )), + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let left_partitions = self.left.output_partitioning().partition_count(); + let right_partitions = self.right.output_partitioning().partition_count(); + if left_partitions != right_partitions { + return internal_err!( + "Invalid PartitionedHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ + consider using RepartitionExec" + ); + } + + let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = if let Some(( + left_sorted_filter_expr, + right_sorted_filter_expr, + graph, + )) = + prepare_sorted_exprs( + &self.filter, + &self.left, + &self.right, + &self.left_sort_exprs, + &self.right_sort_exprs, + )? { + (left_sorted_filter_expr, right_sorted_filter_expr, graph) + } else { + return internal_err!("PartitionedHashJoinExec can not operate unless both sides are pruning tables."); + }; + + let (on_left, on_right) = self.on.iter().cloned().unzip(); + + let left_stream = self.left.execute(partition, context.clone())?; + + let right_stream = self.right.execute(partition, context.clone())?; + + let metrics = PartitionedHashJoinMetrics::new(partition, &self.metrics); + let reservation = Arc::new(Mutex::new( + MemoryConsumer::new(format!("PartitionedHashJoinStream[{partition}]")) + .register(context.memory_pool()), + )); + reservation.lock().try_grow(graph.size())?; + + Ok(Box::pin(PartitionedHashJoinStream { + left_stream, + right_stream, + probe_buffer: ProbeBuffer::new(self.right.schema(), on_right), + build_buffer: BuildBuffer::new(self.left.schema(), on_left), + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + column_indices: self.column_indices.clone(), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + reservation, + state: PartitionedHashJoinStreamState::PullProbe, + fetch_per_key: self.fetch_per_key, + metrics, + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Statistics { + // TODO stats: it is not possible in general to know the output size of joins + Statistics::default() + } +} + +/// We use this buffer to keep track of the probe side pulling. +struct ProbeBuffer { + /// The batch used for join operations + current_batch: RecordBatch, + /// The batches buffered in ProbePull state. + candidate_buffer: Vec, + /// The probe side keys + on: Vec, +} + +impl ProbeBuffer { + pub fn new(schema: SchemaRef, on: Vec) -> Self { + Self { + current_batch: RecordBatch::new_empty(schema), + candidate_buffer: vec![], + on, + } + } + pub fn size(&self) -> usize { + let mut size = 0; + size += self.current_batch.get_array_memory_size(); + size += self + .candidate_buffer + .iter() + .map(|batch| batch.get_array_memory_size()) + .sum::(); + size += mem::size_of_val(&self.on); + size + } +} + +/// A specialized stream designed to handle the output batches resulting from the execution of a `PartitionedHashJoinExec`. +/// +/// The `PartitionedHashJoinStream` manages the flow of record batches from both left and right input streams +/// during the hash join operation. For each batch of records from the right ("probe") side, it checks for matching rows +/// in the hash table constructed from the left ("build") side. +/// +/// The stream leverages sorted filter expressions for both left and right inputs to optimize range calculations +/// and potentially prune unnecessary data. It maintains buffers for currently processed batches and uses a given +/// schema, join filter, and join type to construct the resultant batches of the join operation. +struct PartitionedHashJoinStream { + /// Left stream + left_stream: SendableRecordBatchStream, + /// Right stream + right_stream: SendableRecordBatchStream, + /// Left globally sorted filter expression. + /// This expression is used to range calculations from the left stream. + left_sorted_filter_expr: Vec, + /// Right globally sorted filter expression. + /// This expression is used to range calculations from the right stream. + right_sorted_filter_expr: Vec, + /// Hash joiner for the right side. It is responsible for creating a hash map + /// from the right side data, which can be used to quickly look up matches when + /// joining with left side data. + build_buffer: BuildBuffer, + /// Buffer for the left side data. It keeps track of the current batch of data + /// from the left stream that we're working with. + probe_buffer: ProbeBuffer, + /// Schema of the input data. This defines the structure of the data in both + /// the left and right streams. + schema: Arc, + /// The join filter expression. This is a boolean expression that determines + /// whether a pair of rows, one from the left side and one from the right side, + /// should be included in the output of the join. + filter: JoinFilter, + /// The type of the join operation. This can be one of: inner, left, right, full, + /// semi, or anti join. + join_type: JoinType, + /// Information about the index and placement of columns. This is used when + /// constructing the output record batch, to know where to get data for each column. + column_indices: Vec, + /// Expression graph for range pruning. This graph describes the dependencies + /// between different columns in terms of range bounds, which can be used for + /// advanced optimizations, such as range calculations and pruning. + graph: ExprIntervalGraph, + /// Random state used for initializing the hash function in the hash joiner. + random_state: RandomState, + /// If true, null values are considered equal to other null values. If false, + /// null values are considered distinct from everything, including other null values. + null_equals_null: bool, + /// Memory reservation for this join operation. + reservation: SharedMemoryReservation, + /// Current state of the stream. This state machine tracks what the stream is + /// currently doing or should do next, e.g., pulling data from the probe side, + /// pulling data from the build side, performing the join, etc. + state: PartitionedHashJoinStreamState, + /// We limit the build side per key to achieve bounded memory for unbounded inputs + fetch_per_key: usize, + /// Metrics + metrics: PartitionedHashJoinMetrics, +} + +fn build_join_indices( + right_row_index: usize, + right_batch: &RecordBatch, + left_batch: &RecordBatch, + filter: &JoinFilter, +) -> Result<(UInt64Array, UInt32Array)> { + let left_row_count = left_batch.num_rows(); + let build_indices = UInt64Array::from_iter_values(0..(left_row_count as u64)); + let probe_indices = UInt32Array::from(vec![right_row_index as u32; left_row_count]); + apply_join_filter_to_indices( + left_batch, + right_batch, + build_indices, + probe_indices, + filter, + JoinSide::Left, + ) +} + +impl RecordBatchStream for PartitionedHashJoinStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +impl Stream for PartitionedHashJoinStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} + +fn adjust_probe_row_indice_by_join_type( + build_indices: UInt64Array, + probe_indices: UInt32Array, + row_probe_batch: u32, + join_type: JoinType, +) -> Result<(UInt64Array, UInt32Array)> { + match join_type { + JoinType::Inner => { + // Unmatched rows for the left join will be produces in pruning phase. + Ok((build_indices, probe_indices)) + } + JoinType::Right => { + if probe_indices.is_empty() { + let build = (0..1).map(|_| None).collect::(); + Ok((build, UInt32Array::from_value(row_probe_batch, 1))) + } else { + Ok((build_indices, probe_indices)) + } + } + JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::Full + | JoinType::Left + | JoinType::RightSemi + | JoinType::RightAnti => { + // These join types are not supported. + unreachable!() + } + } +} + +impl PartitionedHashJoinStream { + /// Returns the total memory size of the stream. It's the sum of memory size of each field. + fn size(&self) -> usize { + let mut size = 0; + size += mem::size_of_val(&self.left_stream); + size += mem::size_of_val(&self.right_stream); + size += mem::size_of_val(&self.schema); + size += mem::size_of_val(&self.filter); + size += mem::size_of_val(&self.join_type); + size += self.build_buffer.size(); + size += self.probe_buffer.size(); + size += mem::size_of_val(&self.column_indices); + size += self.graph.size(); + size += mem::size_of_val(&self.left_sorted_filter_expr); + size += mem::size_of_val(&self.right_sorted_filter_expr); + size += mem::size_of_val(&self.metrics); + size + } + + #[allow(clippy::too_many_arguments)] + pub fn build_equal_condition_join_indices(&mut self) -> Result> { + let probe_batch = &self.probe_buffer.current_batch; + let probe_on = &self.probe_buffer.on; + let random_state = &self.random_state; + let filter = &self.filter; + let keys_values = probe_on + .iter() + .map(|c| Ok(c.evaluate(probe_batch)?.into_array(probe_batch.num_rows()))) + .collect::>>()?; + let mut hashes_buffer = vec![0_u64; probe_batch.num_rows()]; + let hash_values = create_hashes(&keys_values, random_state, &mut hashes_buffer)?; + let mut result = vec![]; + // Visit all of the probe rows + for (row, hash_value) in hash_values.iter().enumerate() { + // Get the hash and find it in the build index + // For every item on the build and probe we check if it matches + // This possibly contains rows with hash collisions, + // So we have to check here whether rows are equal or not + if let Some((_, _, partition_state)) = self + .build_buffer + .join_hash_map + .get_mut(*hash_value, |(key, _, _)| { + let partition_key = get_row_at_idx(&keys_values, row).unwrap(); + partition_key.iter().zip(key.iter()).all(|(lhs, rhs)| { + dyn_eq_with_null_support(lhs, rhs, self.null_equals_null) + }) + }) + { + let build_batch = &partition_state.record_batch; + let build_join_values = self + .build_buffer + .on + .iter() + .map(|c| { + Ok(c.evaluate(build_batch)?.into_array(build_batch.num_rows())) + }) + .collect::>>()?; + let (build_indices, probe_indices) = + build_join_indices(row, probe_batch, build_batch, filter)?; + + let (build_indices, probe_indices) = equal_rows_arr( + &build_indices, + &probe_indices, + &build_join_values, + &keys_values, + self.null_equals_null, + )?; + partition_state.matched_indices = build_indices.len(); + // adjust the two side indices base on the join type + // Adjusts indices according to the type of join + let (build_indices, probe_indices) = + adjust_probe_row_indice_by_join_type( + build_indices, + probe_indices, + row as u32, + self.join_type, + )?; + let batch = build_batch_from_indices( + &self.schema, + build_batch, + probe_batch, + &build_indices, + &probe_indices, + &self.column_indices, + JoinSide::Left, + )?; + if batch.num_rows() > 0 { + result.push(batch) + } + } else { + let mut build_indices = UInt64Builder::with_capacity(0); + let build_indices = build_indices.finish(); + let mut probe_indices = UInt32Builder::with_capacity(0); + let probe_indices = probe_indices.finish(); + let (build_indices, probe_indices) = + adjust_probe_row_indice_by_join_type( + build_indices, + probe_indices, + row as u32, + self.join_type, + )?; + let batch = build_batch_from_indices( + &self.schema, + // It will be none + &self.build_buffer.latest_batch, + probe_batch, + &build_indices, + &probe_indices, + &self.column_indices, + JoinSide::Left, + )?; + if batch.num_rows() > 0 { + result.push(batch) + } + } + } + + Ok(result) + } + + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + match &mut self.state { + // When the state is "PullProbe", poll the right (probe) stream + PartitionedHashJoinStreamState::PullProbe => { + loop { + match ready!(self.right_stream.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + // Update metrics for polled batch: + self.metrics.probe.input_batches.add(1); + self.metrics.probe.input_rows.add(batch.num_rows()); + + // Check if batch meets interval calculation criteria: + let stop_polling = + is_batch_suitable_interval_calculation( + &self.filter, + &self.right_sorted_filter_expr, + &batch, + JoinSide::Right, + )?; + // Add the batch into candidate buffer: + self.probe_buffer.candidate_buffer.push(batch); + if stop_polling { + break; + } + } + None => break, + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + } + } + if self.probe_buffer.candidate_buffer.is_empty() { + // If no batches were collected, change state to "ProbeExhausted" + self.state = PartitionedHashJoinStreamState::ProbeExhausted; + continue; + } + // Get probe batch by joining all the collected batches + self.probe_buffer.current_batch = get_probe_batch(mem::take( + &mut self.probe_buffer.candidate_buffer, + ))?; + + if self.probe_buffer.current_batch.num_rows() == 0 { + continue; + } + // Since we only use schema information of the build side record batch, + // keeping only first batch + // Update the probe side with the new probe batch: + let calculated_build_side_interval = + calculate_the_necessary_build_side_range( + &self.filter, + &self.build_buffer.latest_batch, + &mut self.graph, + &mut self.left_sorted_filter_expr, + &mut self.right_sorted_filter_expr, + &self.probe_buffer.current_batch, + )?; + // Update state to "PullBuild" with calculated interval + self.state = PartitionedHashJoinStreamState::PullBuild { + interval: calculated_build_side_interval, + }; + } + PartitionedHashJoinStreamState::PullBuild { interval } => { + let build_interval = interval.clone(); + // Keep pulling data from the left stream until a suitable + // range on batches is found: + loop { + match ready!(self.left_stream.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + self.metrics.build.input_batches.add(1); + if batch.num_rows() == 0 { + continue; + } + self.metrics.build.input_batches.add(1); + self.metrics.build.input_rows.add(batch.num_rows()); + + self.build_buffer.update_partition_batch( + &batch, + &self.random_state, + self.null_equals_null, + )?; + + if check_if_sliding_window_condition_is_met( + &self.filter, + &batch, + &build_interval, + )? { + self.state = PartitionedHashJoinStreamState::Join { + interval: build_interval, + }; + break; + } + } + // If the poll doesn't return any data, check if there are any batches. If so, + // combine them into one and update the build buffer's internal state. + None => { + self.state = PartitionedHashJoinStreamState::Join { + interval: build_interval, + }; + break; + } + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + } + } + } + PartitionedHashJoinStreamState::Join { interval } => { + let build_interval = interval.clone(); + // Calculate the equality results + let result_batches = self.build_equal_condition_join_indices()?; + + // Prune the buffers to drain until 'fetch' number of hashable rows remain. + self.build_buffer.prune( + &self.filter, + build_interval, + self.fetch_per_key, + )?; + + // Combine join result into a single batch. + let result = concat_batches(&self.schema, &result_batches)?; + + // Update the state to PullProbe, so the next iteration will pull from the probe side. + self.state = PartitionedHashJoinStreamState::PullProbe; + + // Calculate the current memory usage of the stream. + let capacity = self.size(); + self.metrics.stream_memory_usage.set(capacity); + + // Update memory pool + self.reservation.lock().try_resize(capacity)?; + + if result.num_rows() > 0 { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(result.num_rows()); + return Poll::Ready(Some(Ok(result))); + } + } + PartitionedHashJoinStreamState::ProbeExhausted => { + // After probe is exhausted first there will be no more new addition into any + // group key, since probe fetches the necessary build side first. If probe side + // is exhausted before the build side, the previous probe batch saw all necessary + // data. + return Poll::Ready(None); + } + } + } + } +} + +#[cfg(test)] +mod fuzzy_tests { + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + use super::*; + + use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; + use crate::common; + use crate::joins::test_utils::{ + build_sides_record_batches, compare_batches, create_memory_table, + split_record_batches, + }; + use crate::joins::utils::JoinSide; + use crate::joins::{HashJoinExec, PartitionMode}; + use crate::repartition::RepartitionExec; + use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; + + use arrow::datatypes::{DataType, Field}; + use arrow_schema::{SortOptions, TimeUnit}; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{binary, col, BinaryExpr, Literal}; + use datafusion_physical_expr::{ + add_offset_to_lex_ordering, expressions, AggregateExpr, + }; + + use once_cell::sync::Lazy; + use rstest::*; + + const TABLE_SIZE: i32 = 100; + type TableKey = (i32, i32, usize); // (cardinality.0, cardinality.1, batch_size) + type TableValue = (Vec, Vec); // (left, right) + + // Cache for storing tables + static TABLE_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + + fn get_or_create_table( + cardinality: (i32, i32), + batch_size: usize, + ) -> Result { + { + let cache = TABLE_CACHE.lock().unwrap(); + if let Some(table) = cache.get(&(cardinality.0, cardinality.1, batch_size)) { + return Ok(table.clone()); + } + } + + // If not, create the table + let (left_batch, right_batch) = + build_sides_record_batches(TABLE_SIZE, cardinality)?; + + let (left_partition, right_partition) = ( + split_record_batches(&left_batch, batch_size)?, + split_record_batches(&right_batch, batch_size)?, + ); + + // Lock the cache again and store the table + let mut cache = TABLE_CACHE.lock().unwrap(); + + // Store the table in the cache + cache.insert( + (cardinality.0, cardinality.1, batch_size), + (left_partition.clone(), right_partition.clone()), + ); + + Ok((left_partition, right_partition)) + } + + /// This test function generates a conjunctive statement with two numeric + /// terms with the following form: + /// left_col (op_1) a >/>= right_col (op_2) + fn gen_conjunctive_numerical_expr_single_side_prunable( + left_col: Arc, + right_col: Arc, + op: (Operator, Operator), + a: ScalarValue, + b: ScalarValue, + comparison_op: Operator, + ) -> Arc { + let (op_1, op_2) = op; + let left_and_1 = Arc::new(BinaryExpr::new( + left_col.clone(), + op_1, + Arc::new(Literal::new(a)), + )); + let left_and_2 = Arc::new(BinaryExpr::new( + right_col.clone(), + op_2, + Arc::new(Literal::new(b)), + )); + Arc::new(BinaryExpr::new(left_and_1, comparison_op, left_and_2)) + } + /// This test function generates a conjunctive statement with + /// two scalar values with the following form: + /// left_col (op_1) a > right_col (op_2) + #[allow(clippy::too_many_arguments)] + fn gen_conjunctive_temporal_expr_single_side( + left_col: Arc, + right_col: Arc, + op_1: Operator, + op_2: Operator, + a: ScalarValue, + b: ScalarValue, + schema: &Schema, + comparison_op: Operator, + ) -> Result, DataFusionError> { + let left_and_1 = + binary(left_col.clone(), op_1, Arc::new(Literal::new(a)), schema)?; + let left_and_2 = + binary(right_col.clone(), op_2, Arc::new(Literal::new(b)), schema)?; + Ok(Arc::new(BinaryExpr::new( + left_and_1, + comparison_op, + left_and_2, + ))) + } + + async fn partitioned_hash_join_with_filter_and_group_by( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, + ) -> Result> { + let partition_count = 24; + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + + let right_sort_expr = right + .output_ordering() + .map(|order| order.to_vec()) + .ok_or(DataFusionError::Internal("Test fail.".to_owned())) + .unwrap(); + + let adjusted_right_order = + add_offset_to_lex_ordering(&right_sort_expr, left.schema().fields().len())?; + + let join = Arc::new(HashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + Some(filter), + join_type, + PartitionMode::Partitioned, + null_equals_null, + )?); + + let join_schema = join.schema(); + + let agg = Arc::new(expressions::LastValue::new( + Arc::new(Column::new_with_schema("la1", &join_schema)?), + "array_agg(la1)".to_string(), + join_schema + .field_with_name("la1") + .unwrap() + .data_type() + .clone(), + vec![], + vec![], + )); + + let aggregates: Vec> = vec![agg]; + + let groups: Vec<(Arc, String)> = vec![( + Arc::new(Column::new_with_schema("ra1", &join_schema)?), + "ra1".to_string(), + )]; + + let final_grouping_set = PhysicalGroupBy::new_single(groups); + + let merged_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + final_grouping_set, + aggregates, + vec![None], + vec![None], + Arc::new(SortPreservingMergeExec::new(adjusted_right_order, join)), + join_schema, + )?); + + let stream = merged_aggregate.execute(0, context.clone())?; + let batches = common::collect(stream).await?; + + Ok(batches) + } + + #[allow(clippy::too_many_arguments)] + async fn partitioned_partial_hash_join_with_filter_group_by( + left: Arc, + right: Arc, + on: JoinOn, + filter: JoinFilter, + join_type: &JoinType, + null_equals_null: bool, + context: Arc, + fetch_per_key: usize, + ) -> Result> { + let partition_count = 1; + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _)) + .unzip(); + let left_sort_expr = left + .output_ordering() + .map(|order| order.to_vec()) + .ok_or(DataFusionError::Internal( + "PartitionedHashJoinExec needs left and right side ordered.".to_owned(), + )) + .unwrap(); + let right_sort_expr = right + .output_ordering() + .map(|order| order.to_vec()) + .ok_or(DataFusionError::Internal( + "PartitionedHashJoinExec needs left and right side ordered.".to_owned(), + )) + .unwrap(); + + let adjusted_right_order = + add_offset_to_lex_ordering(&right_sort_expr, left.schema().fields().len())?; + + let join = Arc::new(PartitionedHashJoinExec::try_new( + Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), + Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), + on, + filter, + join_type, + null_equals_null, + left_sort_expr, + right_sort_expr, + fetch_per_key, + )?); + + let join_schema = join.schema(); + + let agg = Arc::new(expressions::LastValue::new( + Arc::new(Column::new_with_schema("la1", &join_schema)?), + "array_agg(la1)".to_string(), + join_schema + .field_with_name("la1") + .unwrap() + .data_type() + .clone(), + vec![], + vec![], + )); + + let aggregates: Vec> = vec![agg]; + + let groups: Vec<(Arc, String)> = vec![( + Arc::new(Column::new_with_schema("ra1", &join_schema)?), + "ra1".to_string(), + )]; + + let final_grouping_set = PhysicalGroupBy::new_single(groups); + + let merged_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + final_grouping_set, + aggregates, + vec![None], + vec![None], + Arc::new(SortPreservingMergeExec::new(adjusted_right_order, join)), + join_schema, + )?); + + let stream = merged_aggregate.execute(0, context.clone())?; + let batches = common::collect(stream).await?; + + Ok(batches) + } + + async fn experiment_with_group_by( + left: Arc, + right: Arc, + filter: JoinFilter, + join_type: JoinType, + on: JoinOn, + task_ctx: Arc, + fetch_per_key: usize, + ) -> Result<()> { + let first_batches = partitioned_partial_hash_join_with_filter_group_by( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + &join_type, + false, + task_ctx.clone(), + fetch_per_key, + ) + .await?; + let second_batches = partitioned_hash_join_with_filter_and_group_by( + left.clone(), + right.clone(), + on.clone(), + filter.clone(), + &join_type, + false, + task_ctx.clone(), + ) + .await?; + compare_batches(&first_batches, &second_batches); + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn join_all_one_ascending_numeric( + #[values(JoinType::Inner, JoinType::Right)] join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (21, 13), + (99, 12), + )] + cardinality: (i32, i32), + #[values(5, 200, 131, 1, 2, 40)] batch_size: usize, + #[values(1, 3, 30, 100)] fetch_per_key: usize, + #[values( + ("l_random_ordered", "r_random_ordered"), + ("la1", "ra1") + )] + sorted_cols: (&str, &str), + ) -> Result<()> { + let (l_sorted_col, r_sorted_col) = sorted_cols; + let task_ctx = Arc::new(TaskContext::default()); + let (left_partition, right_partition) = + get_or_create_table(cardinality, batch_size)?; + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col(l_sorted_col, left_schema)?, + options: SortOptions::default(), + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col(r_sorted_col, right_schema)?, + options: SortOptions::default(), + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + + let filter_expr = gen_conjunctive_numerical_expr_single_side_prunable( + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + (Operator::Plus, Operator::Minus), + ScalarValue::Int32(Some(10)), + ScalarValue::Int32(Some(3)), + Operator::Lt, + ); + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let column_indices = vec![ + ColumnIndex { + index: left_schema.index_of(l_sorted_col).unwrap(), + side: JoinSide::Left, + }, + ColumnIndex { + index: right_schema.index_of(r_sorted_col).unwrap(), + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + experiment_with_group_by( + left, + right, + filter, + join_type, + on, + task_ctx, + fetch_per_key, + ) + .await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + #[ignore] + async fn testing_with_temporal_columns( + #[values(JoinType::Inner, JoinType::Right)] join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (21, 13), + (99, 12), + )] + cardinality: (i32, i32), + #[values(5, 200, 131, 1, 2, 40)] batch_size: usize, + #[values(1, 3, 30, 100)] fetch_per_key: usize, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let (left_partition, right_partition) = + get_or_create_table(cardinality, batch_size)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + let left_sorted = vec![PhysicalSortExpr { + expr: col("lt1", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("rt1", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + let intermediate_schema = Schema::new(vec![ + Field::new( + "left", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new( + "right", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + ]); + + let filter_expr = gen_conjunctive_temporal_expr_single_side( + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + Operator::Minus, + Operator::Minus, + ScalarValue::new_interval_dt(0, 100), // 100 ms + ScalarValue::new_interval_dt(0, 200), + &intermediate_schema, + Operator::LtEq, + )?; + + let column_indices = vec![ + ColumnIndex { + index: left_schema.index_of("lt1").unwrap(), + side: JoinSide::Left, + }, + ColumnIndex { + index: right_schema.index_of("rt1").unwrap(), + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + experiment_with_group_by( + left, + right, + filter, + join_type, + on, + task_ctx, + fetch_per_key, + ) + .await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_first_descending( + #[values(JoinType::Inner, JoinType::Right)] join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (21, 13), + (99, 12), + )] + cardinality: (i32, i32), + #[values(5, 200, 131, 1, 2, 40)] batch_size: usize, + #[values(1, 3, 30, 100)] fetch_per_key: usize, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let (left_partition, right_partition) = + get_or_create_table(cardinality, batch_size)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_desc_null_first", left_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_desc_null_first", right_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = gen_conjunctive_numerical_expr_single_side_prunable( + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + (Operator::Plus, Operator::Minus), + ScalarValue::Int32(Some(10)), + ScalarValue::Int32(Some(3)), + Operator::Gt, + ); + let column_indices = vec![ + ColumnIndex { + index: left_schema.index_of("l_desc_null_first").unwrap(), + side: JoinSide::Left, + }, + ColumnIndex { + index: right_schema.index_of("r_desc_null_first").unwrap(), + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment_with_group_by( + left, + right, + filter, + join_type, + on, + task_ctx, + fetch_per_key, + ) + .await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn build_null_columns_last( + #[values(JoinType::Inner, JoinType::Right)] join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (21, 13), + (99, 12), + )] + cardinality: (i32, i32), + #[values(5, 200, 131, 1, 2, 40)] batch_size: usize, + #[values(1, 3, 30, 100)] fetch_per_key: usize, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let (left_partition, right_partition) = + get_or_create_table(cardinality, batch_size)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("l_asc_null_last", left_schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("r_asc_null_last", right_schema)?, + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = gen_conjunctive_numerical_expr_single_side_prunable( + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + (Operator::Plus, Operator::Minus), + ScalarValue::Int32(Some(10)), + ScalarValue::Int32(Some(3)), + Operator::Lt, + ); + let column_indices = vec![ + ColumnIndex { + index: left_schema.index_of("l_asc_null_last").unwrap(), + side: JoinSide::Left, + }, + ColumnIndex { + index: right_schema.index_of("r_asc_null_last").unwrap(), + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment_with_group_by( + left, + right, + filter, + join_type, + on, + task_ctx, + fetch_per_key, + ) + .await?; + Ok(()) + } + + #[rstest] + #[tokio::test(flavor = "multi_thread")] + async fn join_all_one_descending_numeric_particular( + #[values(JoinType::Inner, JoinType::Right)] join_type: JoinType, + #[values( + (4, 5), + (11, 21), + (21, 13), + (99, 12), + )] + cardinality: (i32, i32), + #[values(5, 200, 131, 1, 2, 40)] batch_size: usize, + #[values(1, 3, 30, 100)] fetch_per_key: usize, + ) -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let (left_partition, right_partition) = + get_or_create_table(cardinality, batch_size)?; + + let left_schema = &left_partition[0].schema(); + let right_schema = &right_partition[0].schema(); + let left_sorted = vec![PhysicalSortExpr { + expr: col("la1_des", left_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let right_sorted = vec![PhysicalSortExpr { + expr: col("ra1_des", right_schema)?, + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let (left, right) = create_memory_table( + left_partition, + right_partition, + vec![left_sorted], + vec![right_sorted], + )?; + + let on = vec![( + Column::new_with_schema("lc1", left_schema)?, + Column::new_with_schema("rc1", right_schema)?, + )]; + + let intermediate_schema = Schema::new(vec![ + Field::new("left", DataType::Int32, true), + Field::new("right", DataType::Int32, true), + ]); + let filter_expr = gen_conjunctive_numerical_expr_single_side_prunable( + col("left", &intermediate_schema)?, + col("right", &intermediate_schema)?, + (Operator::Plus, Operator::Minus), + ScalarValue::Int32(Some(10)), + ScalarValue::Int32(Some(3)), + Operator::Gt, + ); + let column_indices = vec![ + ColumnIndex { + index: left_schema.index_of("la1_des").unwrap(), + side: JoinSide::Left, + }, + ColumnIndex { + index: right_schema.index_of("ra1_des").unwrap(), + side: JoinSide::Right, + }, + ]; + let filter = JoinFilter::new(filter_expr, column_indices, intermediate_schema); + + experiment_with_group_by( + left, + right, + filter, + join_type, + on, + task_ctx, + fetch_per_key, + ) + .await?; + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/sliding_hash_join.rs b/datafusion/physical-plan/src/joins/sliding_hash_join.rs index a69f359c94b7..8a8e3b4fb9d8 100644 --- a/datafusion/physical-plan/src/joins/sliding_hash_join.rs +++ b/datafusion/physical-plan/src/joins/sliding_hash_join.rs @@ -635,7 +635,7 @@ impl ProbeBuffer { /// /// # Returns /// - /// A new `BuildSideBuffer`. + /// A new `BuildSideBuffer`. pub fn new(schema: SchemaRef, on: Vec) -> Self { Self { current_batch: RecordBatch::new_empty(schema), @@ -648,7 +648,7 @@ impl ProbeBuffer { /// /// # Returns /// - /// The size of this `ProbeBuffer` in bytes. + /// The size of this `ProbeBuffer` in bytes. pub fn size(&self) -> usize { let mut size = 0; size += self.current_batch.get_array_memory_size(); diff --git a/datafusion/physical-plan/src/joins/sliding_nested_loop_join.rs b/datafusion/physical-plan/src/joins/sliding_nested_loop_join.rs index 3566cb99de10..57073e39ed79 100644 --- a/datafusion/physical-plan/src/joins/sliding_nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/sliding_nested_loop_join.rs @@ -31,6 +31,7 @@ use crate::joins::utils::{ JoinSide, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::projection::ProjectionExec; use crate::stream::RecordBatchBroadcastStreamsBuilder; use crate::{ metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, @@ -41,8 +42,7 @@ use arrow::array::{UInt32Array, UInt32Builder, UInt64Array, UInt64Builder}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, DataFusionError, Result, Statistics}; -use datafusion_execution::memory_pool::MemoryConsumer; -use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::JoinType; use datafusion_physical_expr::intervals::ExprIntervalGraph; @@ -51,7 +51,6 @@ use datafusion_physical_expr::{ PhysicalSortRequirement, }; -use crate::projection::ProjectionExec; use futures::{ready, Stream, StreamExt}; use hashbrown::HashSet; use parking_lot::Mutex; @@ -230,6 +229,20 @@ impl SlidingNestedLoopJoinExec { }) } + /// Calculate order preservation flags for this join. + fn maintains_input_order(join_type: JoinType) -> Vec { + vec![ + false, + matches!( + join_type, + JoinType::Inner + | JoinType::Right + | JoinType::RightAnti + | JoinType::RightSemi + ), + ] + } + /// left (build) side which gets hashed pub fn left(&self) -> &Arc { &self.left @@ -260,20 +273,6 @@ impl SlidingNestedLoopJoinExec { &self.right_sort_exprs } - /// Calculate order preservation flags for this join. - fn maintains_input_order(join_type: JoinType) -> Vec { - vec![ - false, - matches!( - join_type, - JoinType::Inner - | JoinType::Right - | JoinType::RightAnti - | JoinType::RightSemi - ), - ] - } - /// In this section, we are employing the strategy of broadcasting /// single partition sides. This approach mirrors how we distribute /// `OnceFut` in `NestedLoopJoinStream`(s). Each partition diff --git a/datafusion/physical-plan/src/joins/sliding_window_join_utils.rs b/datafusion/physical-plan/src/joins/sliding_window_join_utils.rs index 21ef86de5856..ab7bfb31aee1 100644 --- a/datafusion/physical-plan/src/joins/sliding_window_join_utils.rs +++ b/datafusion/physical-plan/src/joins/sliding_window_join_utils.rs @@ -7,7 +7,9 @@ use crate::joins::{ }, utils::{ append_right_indices, get_anti_indices, get_build_side_pruned_exprs, - get_filter_representation_of_build_side, get_semi_indices, JoinFilter, JoinSide, + get_filter_representation_of_build_side, + get_filter_representation_schema_of_build_side, get_semi_indices, JoinFilter, + JoinSide, }, }; @@ -25,8 +27,22 @@ use datafusion_physical_expr::{ use hashbrown::HashSet; -/// This function checks if the batch offers a reference value that enables us -/// to tell whether is falls in the viable sliding window via interval analysis. +/// Determines if the given batch is suitable for interval calculations based on the join +/// filter and sorted filter expressions. +/// +/// The function evaluates the latest row of the batch for each sorted filter expression. +/// It is considered suitable if the evaluated value for all sorted filter expressions are non-null. +/// Empty batches are deemed unsuitable by default. +/// +/// # Arguments +/// * `filter`: The `JoinFilter` used to determine the suitability of the batch. +/// * `probe_sorted_filter_exprs`: A slice of sorted filter expressions used to evaluate the suitability of the batch. +/// * `batch`: The `RecordBatch` to evaluate. +/// * `build_side`: The side of the join operation (either `JoinSide::Left` or `JoinSide::Right`). +/// +/// # Returns +/// * A `Result` containing a boolean value. Returns `true` if the batch is suitable for interval calculation, `false` otherwise. +/// pub fn is_batch_suitable_interval_calculation( filter: &JoinFilter, probe_sorted_filter_exprs: &[SortedFilterExpr], @@ -62,9 +78,26 @@ pub fn is_batch_suitable_interval_calculation( Ok(result.into_iter().all(|b| b)) } -/// This function takes a batch of data from the probe side, calculates the -/// interval for the build side filter expression, and updates the probe -/// buffer and the stream state accordingly. +/// Calculates the necessary build-side range for join pruning. +/// +/// Given a join filter, build inner buffer, and the current state of the expression graph, +/// this function computes the interval range for the build side filter expression and then +/// updates the expression graph with the calculated interval range. This aids in optimizing +/// the join operation by pruning unnecessary rows from the build side and fetching just enough +/// batch. +/// +/// # Arguments +/// * `filter`: The join filter which dictates the join condition. +/// * `build_inner_buffer`: The record batch representing the build side of the join. +/// * `graph`: The current state of the expression interval graph to be updated. +/// * `build_sorted_filter_exprs`: Sorted filter expressions related to the build side. +/// * `probe_sorted_filter_exprs`: Sorted filter expressions related to the probe side. +/// * `probe_batch`: The probe record batch. +/// +/// # Returns +/// * A vector of tuples containing the physical sort expression and its associated interval +/// for the build side. These tuples represent the range in which join pruning can occur +/// for each expression. pub fn calculate_the_necessary_build_side_range( filter: &JoinFilter, build_inner_buffer: &RecordBatch, @@ -78,7 +111,6 @@ pub fn calculate_the_necessary_build_side_range( filter, build_inner_buffer, build_sorted_filter_exprs, - JoinSide::Left, probe_batch, probe_sorted_filter_exprs, JoinSide::Right, @@ -98,24 +130,34 @@ pub fn calculate_the_necessary_build_side_range( // Update the physical expression graph using the join filter intervals: graph.update_ranges(&mut filter_intervals)?; - let intermediate_batch = get_filter_representation_of_build_side( + let intermediate_schema = get_filter_representation_schema_of_build_side( filter.schema(), - build_inner_buffer, filter.column_indices(), JoinSide::Left, )?; - let intermediate_schema = intermediate_batch.schema(); - // Filter expressions that can shrink. let shrunk_exprs = graph.get_deepest_pruning_exprs()?; // Get only build side filter expressions get_build_side_pruned_exprs(shrunk_exprs, intermediate_schema, filter, JoinSide::Left) } -/// Checks whether the given reference value (i.e. `latest_value`) falls within -/// the viable sliding window specified by `interval` according to the sort -/// options of the join in question (i.e. `sort_options`). +/// Checks if the sliding window condition is met for the join operation. +/// +/// This function evaluates the incoming build batch against a set of intervals +/// to determine whether the sliding window condition has been satisfied. It assesses +/// that the current window has captured all the relevant rows for the join. +/// +/// # Arguments +/// * `filter`: The join filter defining the join condition. +/// * `incoming_build_batch`: The incoming record batch from the build side. +/// * `intervals`: A set of intervals representing the build side's boundaries +/// against which the incoming batch is evaluated. +/// +/// # Returns +/// * A boolean value indicating if the sliding window condition is met: +/// * `true` if all rows necessary from the build side for this window have been processed. +/// * `false` otherwise. pub fn check_if_sliding_window_condition_is_met( filter: &JoinFilter, incoming_build_batch: &RecordBatch, @@ -137,6 +179,9 @@ pub fn check_if_sliding_window_condition_is_met( .evaluate(&latest_build_intermediate_batch)? .into_array(1); let latest_value = ScalarValue::try_from_array(&array, 0)?; + if latest_value.is_null() { + return Ok(false); + } Ok(if sorted_shrunk_expr.options.descending { // Data is sorted in descending order, so check if latest value is less // than the lower bound of the interval. If it is, we must have processed @@ -153,10 +198,18 @@ pub fn check_if_sliding_window_condition_is_met( Ok(results.iter().all(|e| *e)) } -/// This function combines the given batches into a probe batch. -pub fn get_probe_batch( - mut batches: Vec, -) -> datafusion_common::Result { +/// Constructs a single `RecordBatch` from a vector of `RecordBatch`es. +/// +/// If there's only one batch in the vector, it's directly returned. Otherwise, +/// all the batches are concatenated to produce a single `RecordBatch`. +/// +/// # Arguments +/// * `batches`: A vector of `RecordBatch`es to be combined into a single batch. +/// +/// # Returns +/// * A `Result` containing a single `RecordBatch` or an error if the concatenation fails. +/// +pub fn get_probe_batch(mut batches: Vec) -> Result { let probe_batch = if batches.len() == 1 { batches.remove(0) } else { @@ -432,22 +485,14 @@ pub(crate) fn update_filter_expr_bounds( filter: &JoinFilter, build_inner_buffer: &RecordBatch, build_sorted_filter_exprs: &mut [SortedFilterExpr], - build_side: JoinSide, probe_batch: &RecordBatch, probe_sorted_filter_exprs: &mut [SortedFilterExpr], probe_side: JoinSide, ) -> Result<()> { - let build_intermediate_batch = get_filter_representation_of_build_side( - filter.schema(), - &build_inner_buffer.slice(0, 0), - filter.column_indices(), - build_side, - )?; // Evaluate the build side order expression to get datatype: let build_order_datatype = build_sorted_filter_exprs[0] .intermediate_batch_filter_expr() - .evaluate(&build_intermediate_batch)? - .data_type(); + .data_type(&build_inner_buffer.schema())?; // Create a null scalar value with the obtained datatype: let null_scalar = ScalarValue::try_from(build_order_datatype)?; diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 60a345884078..1476029e277a 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -39,6 +39,7 @@ use arrow_array::{ use arrow_schema::{DataType, Schema}; use datafusion_common::ScalarValue; use datafusion_common::{Result, DataFusionError, ScalarValue}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_execution::TaskContext; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{binary, cast, col, lit}; @@ -411,17 +412,40 @@ macro_rules! join_expr_tests { join_expr_tests!(join_expr_tests_fixture_i32, i32, Int32); join_expr_tests!(join_expr_tests_fixture_f64, f64, Float64); +fn generate_ordered_array(size: i32, duplicate_ratio: f32) -> Arc { + let mut rng = StdRng::seed_from_u64(42); + let unique_count = (size as f32 * (1.0 - duplicate_ratio)) as i32; + + // Generate unique random values + let mut values: Vec = (0..unique_count) + .map(|_| rng.gen_range(1..500)) // Modify as per your range + .collect(); + + // Duplicate the values according to the duplicate ratio + for _ in 0..(size - unique_count) { + let index = rng.gen_range(0..unique_count); + values.push(values[index as usize]); + } + + // Sort the values to ensure they are ordered + values.sort(); + + Arc::new(Int32Array::from_iter(values)) +} + pub fn build_sides_record_batches( table_size: i32, key_cardinality: (i32, i32), ) -> Result<(RecordBatch, RecordBatch)> { let null_ratio: f64 = 0.4; + let duplicate_ratio = 0.4; let initial_range = 0..table_size; let index = (table_size as f64 * null_ratio).round() as i32; let rest_of = index..table_size; let ordered: ArrayRef = Arc::new(Int32Array::from_iter( initial_range.clone().collect::>(), )); + let random_ordered = generate_ordered_array(table_size, duplicate_ratio); let ordered_des = Arc::new(Int32Array::from_iter( initial_range.clone().rev().collect::>(), )); @@ -490,6 +514,7 @@ pub fn build_sides_record_batches( ("l_desc_null_first", ordered_desc_null_first.clone()), ("li1", interval_time.clone()), ("l_float", float_asc.clone()), + ("l_random_ordered", random_ordered.clone()), ])?; let right = RecordBatch::try_from_iter(vec![ ("ra1", ordered.clone()), @@ -503,6 +528,7 @@ pub fn build_sides_record_batches( ("r_desc_null_first", ordered_desc_null_first), ("ri1", interval_time), ("r_float", float_asc), + ("r_random_ordered", random_ordered), ])?; Ok((left, right)) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 95ea5f3cddf5..af769d527f4c 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1443,7 +1443,7 @@ pub fn swap_join_on(on: &JoinOn) -> JoinOn { } /// Swaps join sides for filter column indices and produces new JoinFilter -pub(crate) fn swap_filter(filter: &JoinFilter) -> JoinFilter { +pub fn swap_filter(filter: &JoinFilter) -> JoinFilter { let column_indices = filter .column_indices() .iter() @@ -1459,11 +1459,6 @@ pub(crate) fn swap_filter(filter: &JoinFilter) -> JoinFilter { ) } -/// Swaps join sides for filter column indices and produces new `JoinFilter` (if exists). -pub fn swap_join_filter(filter: Option<&JoinFilter>) -> Option { - filter.map(swap_filter) -} - /// This function returns the new join type we get after swapping the given /// join's inputs. pub fn swap_join_type(join_type: JoinType) -> JoinType { diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index c0db111bc60d..c46e7d9e8198 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -15,21 +15,20 @@ // specific language governing permissions and limitations // under the License. -use arrow::util::display::ArrayFormatter; -use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; -use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::DFField; -use datafusion_common::DataFusionError; use std::path::PathBuf; use std::sync::OnceLock; -use crate::engines::output::DFColumnType; - use super::super::conversion::*; use super::error::{DFSqlLogicTestError, Result}; +use crate::engines::output::DFColumnType; + +use arrow::util::display::ArrayFormatter; +use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; +use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::{DFField, DataFusionError}; /// Converts `batches` to a result as expected by sqllogicteset. -pub(crate) fn convert_batches(batches: Vec) -> Result>> { +pub fn convert_batches(batches: Vec) -> Result>> { if batches.is_empty() { Ok(vec![]) } else { diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt index 4835ebb7816d..2a6fc6914a97 100644 --- a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt +++ b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt @@ -418,7 +418,7 @@ Projection: subquery.c_custkey, subquery.c_nationkey, subquery.price_rank, natio ----TableScan: nation projection=[n_nationkey, n_name] physical_plan ProjectionExec: expr=[c_custkey@0 as c_custkey, c_nationkey@1 as c_nationkey, price_rank@3 as price_rank, n_name@5 as n_name] ---SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(c_address@2, n_name@1)], filter=n_nationkey@2 > c_custkey@0 AND price_rank@1 > n_nationkey@2 +--SlidingHashJoinExec: join_type=Inner, on=[(c_address@2, n_name@1)], filter=n_nationkey@2 > c_custkey@0 AND price_rank@1 > n_nationkey@2 ----ProjectionExec: expr=[c_custkey@0 as c_custkey, c_nationkey@2 as c_nationkey, c_address@1 as c_address, CAST(ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 AS Int64) as price_rank] ------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/customer.csv]]}, projection=[c_custkey, c_address, c_nationkey], infinite_source=true, output_ordering=[c_custkey@0 ASC NULLS LAST], has_header=false @@ -525,6 +525,44 @@ LIMIT 10 1 2 1 2 +# Rewrite of filter predicate enables the pruning for nested loop join +query TT +EXPLAIN SELECT + subquery.c_custkey, + subquery.c_nationkey, + subquery.price_rank, + nation.n_name +FROM + ( + SELECT + customer.c_custkey, + customer.c_nationkey, + customer.c_address, + CAST(ROW_NUMBER() OVER(ORDER BY customer.c_custkey) as BIGINT) as price_rank + FROM + customer + ) as subquery +INNER JOIN + nation + ON nation.n_nationkey - subquery.c_custkey > 0 + AND subquery.price_rank > nation.n_nationkey +---- +logical_plan +Projection: subquery.c_custkey, subquery.c_nationkey, subquery.price_rank, nation.n_name +--Inner Join: Filter: nation.n_nationkey - subquery.c_custkey > Int64(0) AND subquery.price_rank > nation.n_nationkey +----SubqueryAlias: subquery +------Projection: customer.c_custkey, customer.c_nationkey, CAST(ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS Int64) AS price_rank +--------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +----------TableScan: customer projection=[c_custkey, c_nationkey] +----TableScan: nation projection=[n_nationkey, n_name] +physical_plan +ProjectionExec: expr=[c_custkey@0 as c_custkey, c_nationkey@1 as c_nationkey, price_rank@2 as price_rank, n_name@4 as n_name] +--SlidingNestedLoopJoinExec: join_type=Inner, filter=n_nationkey@2 > c_custkey@0 AND price_rank@1 > n_nationkey@2 +----ProjectionExec: expr=[c_custkey@0 as c_custkey, c_nationkey@1 as c_nationkey, CAST(ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 AS Int64) as price_rank] +------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] +--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/customer.csv]]}, projection=[c_custkey, c_nationkey], infinite_source=true, output_ordering=[c_custkey@0 ASC NULLS LAST], has_header=false +----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/nation.csv]]}, projection=[n_nationkey, n_name], infinite_source=true, output_ordering=[n_nationkey@0 ASC NULLS LAST], has_header=false + # Rewrite of filter predicate enables the pruning. query TT EXPLAIN SELECT @@ -556,7 +594,7 @@ Projection: subquery.c_custkey, subquery.c_nationkey, subquery.price_rank, natio ----TableScan: nation projection=[n_nationkey, n_name] physical_plan ProjectionExec: expr=[c_custkey@0 as c_custkey, c_nationkey@1 as c_nationkey, price_rank@3 as price_rank, n_name@5 as n_name] ---SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(c_address@2, n_name@1)], filter=n_nationkey@2 > c_custkey@0 AND n_nationkey@2 < price_rank@1 +--SlidingHashJoinExec: join_type=Inner, on=[(c_address@2, n_name@1)], filter=n_nationkey@2 > c_custkey@0 AND n_nationkey@2 < price_rank@1 ----ProjectionExec: expr=[c_custkey@0 as c_custkey, c_nationkey@2 as c_nationkey, c_address@1 as c_address, CAST(ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 AS Int64) as price_rank] ------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [customer.c_custkey ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }], mode=[Sorted] --------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/customer.csv]]}, projection=[c_custkey, c_address, c_nationkey], infinite_source=true, output_ordering=[c_custkey@0 ASC NULLS LAST], has_header=false diff --git a/datafusion/sqllogictest/test_files/stream.slt b/datafusion/sqllogictest/test_files/stream.slt index dea32c879700..cee248f6293c 100644 --- a/datafusion/sqllogictest/test_files/stream.slt +++ b/datafusion/sqllogictest/test_files/stream.slt @@ -234,27 +234,84 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te # GROUP BY s.sn # ORDER BY s.sn -# s.ts cannot be resolved -# SELECT s.*, s.amount * LAST_VALUE(e.rate ORDER BY e.sn) AS amount_usd -# INTO sales_global_converted -# FROM sales_global AS s -# JOIN exchange_rates AS e -# ON s.currency = e.currency_from AND -# e.currency_to = 'USD' AND -# s.ts >= e.ts -# GROUP BY s.sn -# ORDER BY s.sn +# A temporary table until Order Equivalence problems resolved. The ProjectionExec does not +# yield the output ordering correct in usual tables. (sn-ts order determines which one is under order equivalence.) + +statement ok +CREATE UNBOUNDED EXTERNAL TABLE exchange_rates_temp ( + "ts" TIMESTAMP, + "sn2" INTEGER, + "currency_from" VARCHAR NOT NULL, + "currency_to" VARCHAR NOT NULL, + "rate" FLOAT +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (sn2 ASC) +WITH ORDER (ts ASC) +LOCATION '../core/tests/data/exchange_rates.csv'; + +statement ok +CREATE UNBOUNDED EXTERNAL TABLE sales_global_temp ( + "ts" TIMESTAMP, + "sn" INTEGER, + "amount" INTEGER, + "currency" VARCHAR NOT NULL +) +STORED AS CSV +WITH HEADER ROW +WITH ORDER (sn ASC) +WITH ORDER (ts ASC) +LOCATION '../core/tests/data/sales_global.csv'; + # s.ts cannot be resolved -# SELECT s.*, s.amount * LAST_VALUE(e.rate ORDER BY e.sn) AS amount_usd -# INTO sales_global_converted -# FROM sales_global AS s, -# exchange_rates AS e -# WHERE s.currency = e.currency_from AND -# e.currency_to = 'USD' AND -# s.ts >= e.ts -# GROUP BY s.sn -# ORDER BY s.sn +query TT +EXPLAIN SELECT LAST_VALUE(e.rate ORDER BY e.sn2) AS amount_usd +FROM sales_global_temp AS s, + exchange_rates_temp AS e +WHERE s.currency = e.currency_from AND + e.currency_to = 'USD' AND + s.ts >= e.ts +GROUP BY s.sn +ORDER BY s.sn +---- +logical_plan +Projection: amount_usd +--Sort: s.sn ASC NULLS LAST +----Projection: LAST_VALUE(e.rate) ORDER BY [e.sn2 ASC NULLS LAST] AS amount_usd, s.sn +------Aggregate: groupBy=[[s.sn]], aggr=[[LAST_VALUE(e.rate) ORDER BY [e.sn2 ASC NULLS LAST]]] +--------Projection: s.sn, e.sn2, e.rate +----------Inner Join: s.currency = e.currency_from Filter: s.ts >= e.ts +------------SubqueryAlias: s +--------------TableScan: sales_global_temp projection=[ts, sn, currency] +------------SubqueryAlias: e +--------------Projection: exchange_rates_temp.ts, exchange_rates_temp.sn2, exchange_rates_temp.currency_from, exchange_rates_temp.rate +----------------Filter: exchange_rates_temp.currency_to = Utf8("USD") +------------------TableScan: exchange_rates_temp projection=[ts, sn2, currency_from, currency_to, rate], partial_filters=[exchange_rates_temp.currency_to = Utf8("USD")] +physical_plan +ProjectionExec: expr=[amount_usd@0 as amount_usd] +--SortPreservingMergeExec: [sn@1 ASC NULLS LAST] +----ProjectionExec: expr=[LAST_VALUE(e.rate) ORDER BY [e.sn2 ASC NULLS LAST]@1 as amount_usd, sn@0 as sn] +------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn], aggr=[LAST_VALUE(e.rate)], ordering_mode=FullyOrdered +--------CoalesceBatchesExec: target_batch_size=8192 +----------SortPreservingRepartitionExec: partitioning=Hash([sn@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[sn@0 as sn], aggr=[LAST_VALUE(e.rate)], ordering_mode=FullyOrdered +--------------ProjectionExec: expr=[sn@1 as sn, sn2@4 as sn2, rate@6 as rate] +----------------ProjectionExec: expr=[ts@4 as ts, sn@5 as sn, currency@6 as currency, ts@0 as ts, sn2@1 as sn2, currency_from@2 as currency_from, rate@3 as rate] +------------------PartitionedHashJoinExec: join_type=Inner, on=[(currency_from@2, currency@2)], filter=ts@0 >= ts@1 +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------SortPreservingRepartitionExec: partitioning=Hash([currency_from@2], 4), input_partitions=4 +------------------------ProjectionExec: expr=[ts@0 as ts, sn2@1 as sn2, currency_from@2 as currency_from, rate@4 as rate] +--------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------FilterExec: currency_to@3 = USD +------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/exchange_rates.csv]]}, projection=[ts, sn2, currency_from, currency_to, rate], infinite_source=true, output_ordering=[sn2@1 ASC NULLS LAST], has_header=true +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------SortPreservingRepartitionExec: partitioning=Hash([currency@2], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/sales_global.csv]]}, projection=[ts, sn, currency], infinite_source=true, output_ordering=[sn@1 ASC NULLS LAST], has_header=true + # Scalar Subquery not supported # SELECT s.*, @@ -301,31 +358,83 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te # Outer Joins -# s.ts cannot be resolved -# SELECT s.*, s.amount * AVG(e.rate) AS amount_usd -# INTO sales_global_converted -# FROM sales_global AS s -# LEFT JOIN exchange_rates AS e -# ON s.currency = e.currency_from AND -# e.currency_to = 'USD' AND -# s.ts >= e.ts AND -# e.ts >= s.ts - INTERVAL '10' MINUTE -# GROUP BY s.sn -# ORDER BY s.sn +query TT +EXPLAIN SELECT AVG(e.rate) AS amount_usd +FROM sales_global AS s +LEFT JOIN exchange_rates AS e +ON s.currency = e.currency_from AND + e.currency_to = 'USD' AND + s.ts >= e.ts AND + e.ts >= s.ts - INTERVAL '10' MINUTE +GROUP BY s.sn +ORDER BY s.sn +---- +logical_plan +Projection: amount_usd +--Sort: s.sn ASC NULLS LAST +----Projection: AVG(e.rate) AS amount_usd, s.sn +------Aggregate: groupBy=[[s.sn]], aggr=[[AVG(CAST(e.rate AS Float64))]] +--------Projection: s.sn, e.rate +----------Left Join: s.currency = e.currency_from Filter: s.ts >= e.ts AND e.ts >= s.ts - IntervalMonthDayNano("600000000000") +------------SubqueryAlias: s +--------------TableScan: sales_global projection=[ts, sn, currency] +------------SubqueryAlias: e +--------------Projection: exchange_rates.ts, exchange_rates.currency_from, exchange_rates.rate +----------------Filter: exchange_rates.currency_to = Utf8("USD") +------------------TableScan: exchange_rates projection=[ts, currency_from, currency_to, rate], partial_filters=[exchange_rates.currency_to = Utf8("USD")] +physical_plan +ProjectionExec: expr=[amount_usd@0 as amount_usd] +--SortPreservingMergeExec: [sn@1 ASC NULLS LAST] +----ProjectionExec: expr=[AVG(e.rate)@1 as amount_usd, sn@0 as sn] +------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn], aggr=[AVG(e.rate)], ordering_mode=FullyOrdered +--------CoalesceBatchesExec: target_batch_size=8192 +----------SortPreservingRepartitionExec: partitioning=Hash([sn@0], 4), input_partitions=4 +------------AggregateExec: mode=Partial, gby=[sn@0 as sn], aggr=[AVG(e.rate)], ordering_mode=FullyOrdered +--------------ProjectionExec: expr=[sn@1 as sn, rate@5 as rate] +----------------ProjectionExec: expr=[ts@3 as ts, sn@4 as sn, currency@5 as currency, ts@0 as ts, currency_from@1 as currency_from, rate@2 as rate] +------------------SlidingHashJoinExec: join_type=Right, on=[(currency_from@1, currency@2)], filter=ts@0 >= ts@1 AND ts@1 >= ts@0 - 600000000000 +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------SortPreservingRepartitionExec: partitioning=Hash([currency_from@1], 4), input_partitions=4 +------------------------ProjectionExec: expr=[ts@0 as ts, currency_from@1 as currency_from, rate@3 as rate] +--------------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------------FilterExec: currency_to@2 = USD +------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/exchange_rates.csv]]}, projection=[ts, currency_from, currency_to, rate], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true +--------------------CoalesceBatchesExec: target_batch_size=8192 +----------------------SortPreservingRepartitionExec: partitioning=Hash([currency@2], 4), input_partitions=4 +------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/sales_global.csv]]}, projection=[ts, sn, currency], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true # Explicit Windows in Joins # Sort is not removed! -# SELECT e.*, -# AVG(rate) OVER sliding_window AS avg_rate -# INTO avg_exchange_rates -# FROM exchange_rates AS e -# WINDOW sliding_window AS ( -# PARTITION BY currency_from, currency_to -# ORDER BY ts RANGE INTERVAL '10' MINUTE PRECEDING -# ) -# ORDER BY sn + +query TT +EXPLAIN SELECT e.*, + AVG(rate) OVER sliding_window AS avg_rate +FROM exchange_rates AS e +WINDOW sliding_window AS ( + PARTITION BY currency_from, currency_to + ORDER BY ts RANGE INTERVAL '10' MINUTE PRECEDING +) +ORDER BY sn +---- +logical_plan +Sort: e.sn ASC NULLS LAST +--Projection: e.ts, e.sn, e.currency_from, e.currency_to, e.rate, AVG(e.rate) PARTITION BY [e.currency_from, e.currency_to] ORDER BY [e.ts ASC NULLS LAST] RANGE BETWEEN 10 MINUTE PRECEDING AND CURRENT ROW AS avg_rate +----WindowAggr: windowExpr=[[AVG(CAST(e.rate AS Float64)) PARTITION BY [e.currency_from, e.currency_to] ORDER BY [e.ts ASC NULLS LAST] RANGE BETWEEN 600000000000 PRECEDING AND CURRENT ROW AS AVG(e.rate) PARTITION BY [e.currency_from, e.currency_to] ORDER BY [e.ts ASC NULLS LAST] RANGE BETWEEN 10 MINUTE PRECEDING AND CURRENT ROW]] +------SubqueryAlias: e +--------TableScan: exchange_rates projection=[ts, sn, currency_from, currency_to, rate] +physical_plan +SortPreservingMergeExec: [ts@0 ASC NULLS LAST] +--ProjectionExec: expr=[ts@0 as ts, sn@1 as sn, currency_from@2 as currency_from, currency_to@3 as currency_to, rate@4 as rate, AVG(e.rate) PARTITION BY [e.currency_from, e.currency_to] ORDER BY [e.ts ASC NULLS LAST] RANGE BETWEEN 10 MINUTE PRECEDING AND CURRENT ROW@5 as avg_rate] +----BoundedWindowAggExec: wdw=[AVG(e.rate) PARTITION BY [e.currency_from, e.currency_to] ORDER BY [e.ts ASC NULLS LAST] RANGE BETWEEN 10 MINUTE PRECEDING AND CURRENT ROW: Ok(Field { name: "AVG(e.rate) PARTITION BY [e.currency_from, e.currency_to] ORDER BY [e.ts ASC NULLS LAST] RANGE BETWEEN 10 MINUTE PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(IntervalMonthDayNano("600000000000")), end_bound: CurrentRow }], mode=[Linear] +------CoalesceBatchesExec: target_batch_size=8192 +--------SortPreservingRepartitionExec: partitioning=Hash([currency_from@2, currency_to@3], 4), input_partitions=4 +----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/exchange_rates.csv]]}, projection=[ts, sn, currency_from, currency_to, rate], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST], has_header=true + # Machine Learning via User-Defined Functions (UDFs) @@ -533,7 +642,7 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te # FROM agent_contexts; -# Example: Cybersecurity +################## Example: Cybersecurity ################## # SELECT e.*, # NOVELTY_DETECTION_MODEL('model-name', ts, e.vector) OVER running_window AS model @@ -557,4 +666,4 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te # PARTITION BY contract_id # ORDER BY ts RANGE INTERVAL '1' DAY PRECEDING # ) -# ORDER BY sn \ No newline at end of file +# ORDER BY sn diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index bedab8f398e6..19682d0f8267 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -146,6 +146,7 @@ datafusion/physical-plan/src/joins/sliding_window_join_utils.rs datafusion/physical-plan/src/joins/sliding_nested_loop_join.rs datafusion/physical-plan/src/joins/sliding_hash_join.rs datafusion/physical-plan/src/joins/utils.rs +datafusion/physical-plan/src/joins/partitioned_hash_join.rs force_push_main.sh generate_synnada_commits.sh datafusion/sqllogictest/test_files/stream.slt