Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support input reordering for NestedLoopJoinExec #9676

Merged
merged 8 commits into from
Apr 22, 2024
209 changes: 206 additions & 3 deletions datafusion/core/src/physical_optimizer/join_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::error::Result;
use crate::physical_optimizer::PhysicalOptimizerRule;
use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
use crate::physical_plan::joins::{
CrossJoinExec, HashJoinExec, PartitionMode, StreamJoinPartitionMode,
SymmetricHashJoinExec,
CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
StreamJoinPartitionMode, SymmetricHashJoinExec,
};
use crate::physical_plan::projection::ProjectionExec;
use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties};
Expand Down Expand Up @@ -199,6 +199,38 @@ fn swap_hash_join(
}
}

/// Swaps the inputs of a `NestedLoopJoinExec`, wrapping the result into a
/// `ProjectionExec` if required.
///
/// The join filter and the join type are mirrored with `swap_join_filter` and
/// `swap_join_type`, and the new join is built with the original right input
/// first and the original left input second. For every join type other than
/// Semi/Anti the swapped join is wrapped into a projection produced by
/// `swap_reverting_projection`, so the columns come back in their original order.
///
/// Errors from building the swapped join or the projection are propagated.
fn swap_nl_join(join: &NestedLoopJoinExec) -> Result<Arc<dyn ExecutionPlan>> {
    let (left, right) = (join.left(), join.right());
    let mirrored_filter = swap_join_filter(join.filter());
    let mirrored_join_type = swap_join_type(*join.join_type());

    let swapped_join = Arc::new(NestedLoopJoinExec::try_new(
        Arc::clone(right),
        Arc::clone(left),
        mirrored_filter,
        &mirrored_join_type,
    )?);

    let plan: Arc<dyn ExecutionPlan> = match join.join_type() {
        // Semi/Anti joins produce the same output schema after the swap,
        // so no extra projection is needed for them.
        JoinType::LeftSemi
        | JoinType::RightSemi
        | JoinType::LeftAnti
        | JoinType::RightAnti => swapped_join,
        _ => {
            let reverting_exprs =
                swap_reverting_projection(&left.schema(), &right.schema());
            Arc::new(ProjectionExec::try_new(reverting_exprs, swapped_join)?)
        }
    };

    Ok(plan)
}

/// When the order of the join is changed by the optimizer, the columns in
/// the output should not be impacted. This function creates the expressions
/// that will allow to swap back the values from the original left as the first
Expand Down Expand Up @@ -461,6 +493,14 @@ fn statistical_join_selection_subrule(
} else {
None
}
} else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
let left = nl_join.left();
let right = nl_join.right();
if should_swap_join_order(&**left, &**right)? {
swap_nl_join(nl_join).map(Some)?
} else {
None
}
} else {
None
};
Expand Down Expand Up @@ -697,9 +737,12 @@ mod tests_statistical {

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{stats::Precision, JoinType, ScalarValue};
use datafusion_physical_expr::expressions::Column;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Column};
use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};

use rstest::rstest;

/// Return statistics for empty table
fn empty_statistics() -> Statistics {
Statistics {
Expand Down Expand Up @@ -785,6 +828,35 @@ mod tests_statistical {
}]
}

/// Create join filter for NLJoinExec with expression `big_col > small_col`
/// where both columns are 0-indexed and come from left and right inputs respectively.
///
/// The intermediate schema holds two non-nullable `Int32` fields,
/// `big_col` followed by `small_col`; `column_indices` lists, in the same
/// order, column 0 of the left input and column 0 of the right input.
///
/// Always returns `Some(..)`; the `Option` wrapper lets the result be passed
/// directly as the filter argument of `NestedLoopJoinExec::try_new`.
fn nl_join_filter() -> Option<JoinFilter> {
    let column_indices = vec![
        ColumnIndex {
            index: 0,
            side: JoinSide::Left,
        },
        ColumnIndex {
            index: 0,
            side: JoinSide::Right,
        },
    ];
    let intermediate_schema = Schema::new(vec![
        Field::new("big_col", DataType::Int32, false),
        Field::new("small_col", DataType::Int32, false),
    ]);
    // The right-hand operand must be `small_col`: comparing `big_col` with
    // itself would never touch the right input's column, so the filter would
    // not match the documented expression.
    let expression = Arc::new(BinaryExpr::new(
        Arc::new(Column::new_with_schema("big_col", &intermediate_schema).unwrap()),
        Operator::Gt,
        Arc::new(Column::new_with_schema("small_col", &intermediate_schema).unwrap()),
    )) as _;
    Some(JoinFilter::new(
        expression,
        column_indices,
        intermediate_schema,
    ))
}

/// Returns three plans with statistics of (min, max, distinct_count)
/// * big 100K rows @ (0, 50k, 50k)
/// * medium 10K rows @ (1k, 5k, 1k)
Expand Down Expand Up @@ -1151,6 +1223,137 @@ mod tests_statistical {
crosscheck_plans(join).unwrap();
}

/// For Inner/Left/Right/Full nested loop joins built with the big input on the
/// left, `JoinSelection` is expected to swap the inputs and put a projection on
/// top that restores the original `big_col`, `small_col` column order.
#[rstest(
    join_type,
    case::inner(JoinType::Inner),
    case::left(JoinType::Left),
    case::right(JoinType::Right),
    case::full(JoinType::Full)
)]
#[tokio::test]
async fn test_nl_join_with_swap(join_type: JoinType) {
    let (big_input, small_input) = create_big_and_small();

    let original_join = Arc::new(
        NestedLoopJoinExec::try_new(
            Arc::clone(&big_input),
            Arc::clone(&small_input),
            nl_join_filter(),
            &join_type,
        )
        .unwrap(),
    );

    let config = ConfigOptions::new();
    let optimized = JoinSelection::new()
        .optimize(original_join.clone(), &config)
        .unwrap();

    // The top of the optimized plan must be the reverting projection.
    let swapping_projection = optimized
        .as_any()
        .downcast_ref::<ProjectionExec>()
        .expect("A proj is required to swap columns back to their original order");

    // Each output column keeps its original name while pointing at the
    // swapped position in the join output.
    let projected = swapping_projection.expr();
    let expected_columns = [("big_col", 1), ("small_col", 0)];
    assert_eq!(projected.len(), 2);
    for ((expr, alias), (expected_name, expected_index)) in
        projected.iter().zip(expected_columns)
    {
        assert_eq!(alias, expected_name);
        assert_col_expr(expr, expected_name, expected_index);
    }

    let swapped_join = swapping_projection
        .input()
        .as_any()
        .downcast_ref::<NestedLoopJoinExec>()
        .expect("The type of the plan should not be changed");

    // Assert join side of big_col swapped in filter expression
    let filter = swapped_join.filter().unwrap();
    let big_col_position = filter.schema().index_of("big_col").unwrap();
    let big_col_side = filter
        .column_indices()
        .get(big_col_position)
        .unwrap()
        .side;
    assert_eq!(
        big_col_side,
        JoinSide::Right,
        "Filter column side should be swapped"
    );

    // After the swap the small input is on the left, the big one on the right.
    let left_bytes = swapped_join.left().statistics().unwrap().total_byte_size;
    assert_eq!(left_bytes, Precision::Inexact(8192));
    let right_bytes = swapped_join.right().statistics().unwrap().total_byte_size;
    assert_eq!(right_bytes, Precision::Inexact(2097152));

    crosscheck_plans(original_join.clone()).unwrap();
}

/// For Semi/Anti nested loop joins built with the big input on the left,
/// `JoinSelection` is expected to swap the inputs without adding a projection:
/// the optimized plan root is still a `NestedLoopJoinExec` with an unchanged
/// output schema.
#[rstest(
    join_type,
    case::left_semi(JoinType::LeftSemi),
    case::left_anti(JoinType::LeftAnti),
    case::right_semi(JoinType::RightSemi),
    case::right_anti(JoinType::RightAnti)
)]
#[tokio::test]
async fn test_nl_join_with_swap_no_proj(join_type: JoinType) {
    let (big_input, small_input) = create_big_and_small();

    let original_join = Arc::new(
        NestedLoopJoinExec::try_new(
            Arc::clone(&big_input),
            Arc::clone(&small_input),
            nl_join_filter(),
            &join_type,
        )
        .unwrap(),
    );

    let config = ConfigOptions::new();
    let optimized = JoinSelection::new()
        .optimize(original_join.clone(), &config)
        .unwrap();

    // No projection wrapper: the root itself must be the nested loop join.
    let swapped_join = optimized
        .as_any()
        .downcast_ref::<NestedLoopJoinExec>()
        .expect("The type of the plan should not be changed");

    // Assert before/after schemas are equal
    assert_eq!(
        original_join.schema(),
        swapped_join.schema(),
        "Join schema should not be modified while optimization"
    );

    // Assert join side of big_col swapped in filter expression
    let filter = swapped_join.filter().unwrap();
    let big_col_position = filter.schema().index_of("big_col").unwrap();
    let big_col_side = filter
        .column_indices()
        .get(big_col_position)
        .unwrap()
        .side;
    assert_eq!(
        big_col_side,
        JoinSide::Right,
        "Filter column side should be swapped"
    );

    // After the swap the small input is on the left, the big one on the right.
    let left_bytes = swapped_join.left().statistics().unwrap().total_byte_size;
    assert_eq!(left_bytes, Precision::Inexact(8192));
    let right_bytes = swapped_join.right().statistics().unwrap().total_byte_size;
    assert_eq!(right_bytes, Precision::Inexact(2097152));

    crosscheck_plans(original_join.clone()).unwrap();
}

#[tokio::test]
async fn test_swap_reverting_projection() {
let left_schema = Schema::new(vec![
Expand Down
89 changes: 84 additions & 5 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,19 @@ use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::SortOptions;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use arrow_schema::Schema;
use rand::Rng;

use datafusion::common::JoinSide;
use datafusion::logical_expr::{JoinType, Operator};
use datafusion::physical_expr::expressions::BinaryExpr;
use datafusion::physical_plan::collect;
use datafusion::physical_plan::expressions::Column;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
use datafusion::physical_plan::joins::{
HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec,
};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion_expr::JoinType;

use datafusion::prelude::{SessionConfig, SessionContext};
use test_utils::stagger_batch_with_seed;
Expand Down Expand Up @@ -73,7 +79,7 @@ async fn test_full_join_1k() {
}

#[tokio::test]
async fn test_semi_join_1k() {
async fn test_semi_join_10k() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a nice drive by cleanup

run_join_test(
make_staggered_batches(10000),
make_staggered_batches(10000),
Expand All @@ -83,7 +89,7 @@ async fn test_semi_join_1k() {
}

#[tokio::test]
async fn test_anti_join_1k() {
async fn test_anti_join_10k() {
run_join_test(
make_staggered_batches(10000),
make_staggered_batches(10000),
Expand Down Expand Up @@ -118,6 +124,46 @@ async fn run_join_test(
),
];

// Nested loop join uses filter for joining records
let column_indices = vec![
ColumnIndex {
index: 0,
side: JoinSide::Left,
},
ColumnIndex {
index: 1,
side: JoinSide::Left,
},
ColumnIndex {
index: 0,
side: JoinSide::Right,
},
ColumnIndex {
index: 1,
side: JoinSide::Right,
},
];
let intermediate_schema = Schema::new(vec![
schema1.field_with_name("a").unwrap().to_owned(),
schema1.field_with_name("b").unwrap().to_owned(),
schema2.field_with_name("a").unwrap().to_owned(),
schema2.field_with_name("b").unwrap().to_owned(),
]);

let equal_a = Arc::new(BinaryExpr::new(
Arc::new(Column::new("a", 0)),
Operator::Eq,
Arc::new(Column::new("a", 2)),
)) as _;
let equal_b = Arc::new(BinaryExpr::new(
Arc::new(Column::new("b", 1)),
Operator::Eq,
Arc::new(Column::new("b", 3)),
)) as _;
let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _;

let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema);

// sort-merge join
let left = Arc::new(
MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
Expand Down Expand Up @@ -161,22 +207,55 @@ async fn run_join_test(
);
let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();

// nested loop join
let left = Arc::new(
MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(),
);
let right = Arc::new(
MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(),
);
let nlj = Arc::new(
NestedLoopJoinExec::try_new(left, right, Some(on_filter), &join_type)
.unwrap(),
);
let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();

// compare
let smj_formatted = pretty_format_batches(&smj_collected).unwrap().to_string();
let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string();
let nlj_formatted = pretty_format_batches(&nlj_collected).unwrap().to_string();

let mut smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect();
smj_formatted_sorted.sort_unstable();

let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect();
hj_formatted_sorted.sort_unstable();

let mut nlj_formatted_sorted: Vec<&str> = nlj_formatted.trim().lines().collect();
nlj_formatted_sorted.sort_unstable();

for (i, (smj_line, hj_line)) in smj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!((i, smj_line), (i, hj_line));
assert_eq!(
(i, smj_line),
(i, hj_line),
"SortMergeJoinExec and HashJoinExec produced different results"
);
}

for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!(
(i, nlj_line),
(i, hj_line),
"NestedLoopJoinExec and HashJoinExec produced different results"
);
}
}
}
Expand Down
Loading
Loading