Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 88 additions & 16 deletions datafusion/core/src/physical_optimizer/join_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
//! pipeline-friendly ones. To achieve the second goal, it selects the proper
//! `PartitionMode` and the build side using the available statistics for hash joins.

use std::sync::Arc;

use crate::config::ConfigOptions;
use crate::error::Result;
use crate::physical_optimizer::PhysicalOptimizerRule;
Expand All @@ -35,6 +33,7 @@ use crate::physical_plan::joins::{
};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use std::sync::Arc;

use arrow_schema::Schema;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
Expand Down Expand Up @@ -140,20 +139,32 @@ fn swap_join_projection(
left_schema_len: usize,
right_schema_len: usize,
projection: Option<&Vec<usize>>,
join_type: &JoinType,
) -> Option<Vec<usize>> {
projection.map(|p| {
p.iter()
.map(|i| {
// If the index is less than the left schema length, it is from the left schema, so we add the right schema length to it.
// Otherwise, it is from the right schema, so we subtract the left schema length from it.
if *i < left_schema_len {
*i + right_schema_len
} else {
*i - left_schema_len
}
})
.collect()
})
match join_type {
// For Anti/Semi join types, the projection should remain unmodified,
// since the output schema of these joins remains the same after the swap
JoinType::LeftAnti
| JoinType::LeftSemi
| JoinType::RightAnti
| JoinType::RightSemi => projection.cloned(),

_ => projection.map(|p| {
p.iter()
.map(|i| {
// If the index is less than the left schema length, it is from
// the left schema, so we add the right schema length to it.
// Otherwise, it is from the right schema, so we subtract the left
// schema length from it.
if *i < left_schema_len {
*i + right_schema_len
} else {
*i - left_schema_len
}
})
.collect()
}),
}
}

/// This function swaps the inputs of the given join operator.
Expand All @@ -179,6 +190,7 @@ pub fn swap_hash_join(
left.schema().fields().len(),
right.schema().fields().len(),
hash_join.projection.as_ref(),
hash_join.join_type(),
),
partition_mode,
hash_join.null_equals_null(),
Expand All @@ -189,7 +201,8 @@ pub fn swap_hash_join(
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
) {
) || hash_join.projection.is_some()
{
Ok(Arc::new(new_join))
} else {
// TODO avoid adding ProjectionExec again and again, only adding Final Projection
Expand Down Expand Up @@ -1158,6 +1171,65 @@ mod tests_statistical {
);
}

#[rstest(
join_type, projection, small_on_right,
case::inner(JoinType::Inner, vec![1], true),
case::left(JoinType::Left, vec![1], true),
case::right(JoinType::Right, vec![1], true),
case::full(JoinType::Full, vec![1], true),
case::left_anti(JoinType::LeftAnti, vec![0], false),
case::left_semi(JoinType::LeftSemi, vec![0], false),
case::right_anti(JoinType::RightAnti, vec![0], true),
case::right_semi(JoinType::RightSemi, vec![0], true),
)]
#[tokio::test]
async fn test_hash_join_swap_on_joins_with_projections(
join_type: JoinType,
projection: Vec<usize>,
small_on_right: bool,
) -> Result<()> {
let (big, small) = create_big_and_small();

let left = if small_on_right { &big } else { &small };
let right = if small_on_right { &small } else { &big };

let left_on = if small_on_right {
"big_col"
} else {
"small_col"
};
let right_on = if small_on_right {
"small_col"
} else {
"big_col"
};

let join = Arc::new(HashJoinExec::try_new(
Arc::clone(left),
Arc::clone(right),
vec![(
Arc::new(Column::new_with_schema(left_on, &left.schema())?),
Arc::new(Column::new_with_schema(right_on, &right.schema())?),
)],
None,
&join_type,
Some(projection),
PartitionMode::Partitioned,
false,
)?);

let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned)
.expect("swap_hash_join must support joins with projections");
let swapped_join = swapped.as_any().downcast_ref::<HashJoinExec>().expect(
"ProjectionExec won't be added above if HashJoinExec contains embedded projection",
);

assert_eq!(swapped_join.projection, Some(vec![0_usize]));
assert_eq!(swapped.schema().fields.len(), 1);
assert_eq!(swapped.schema().fields[0].name(), "small_col");
Ok(())
}

#[rstest(
join_type,
case::inner(JoinType::Inner),
Expand Down