Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove logical cross join in planning #12985

Merged
merged 23 commits into from
Oct 18, 2024
22 changes: 13 additions & 9 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ use datafusion_expr::expr::{
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{
DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, SortExpr,
DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr,
StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
Expand Down Expand Up @@ -1045,14 +1045,18 @@ impl DefaultPhysicalPlanner {
session_state.config_options().optimizer.prefer_hash_join;

let join: Arc<dyn ExecutionPlan> = if join_on.is_empty() {
// there is no equal join condition, use the nested loop join
// TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins`
Arc::new(NestedLoopJoinExec::try_new(
physical_left,
physical_right,
join_filter,
join_type,
)?)
if join_filter.is_none() && matches!(join_type, JoinType::Inner) {
// use a cross join if there are no join conditions and no join filter is set
Arc::new(CrossJoinExec::new(physical_left, physical_right))
} else {
// there is no equal join condition, use the nested loop join
Arc::new(NestedLoopJoinExec::try_new(
physical_left,
physical_right,
join_filter,
join_type,
)?)
}
} else if session_state.config().target_partitions() > 1
&& session_state.config().repartition_joins()
&& !prefer_hash_join
Expand Down
11 changes: 8 additions & 3 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use crate::expr_rewriter::{
rewrite_sort_cols_by_aggs,
};
use crate::logical_plan::{
Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter,
Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join,
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
Window,
};
Expand Down Expand Up @@ -950,9 +950,14 @@ impl LogicalPlanBuilder {
pub fn cross_join(self, right: LogicalPlan) -> Result<Self> {
let join_schema =
build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?;
Ok(Self::new(LogicalPlan::CrossJoin(CrossJoin {
Ok(Self::new(LogicalPlan::Join(Join {
left: self.plan,
right: Arc::new(right),
on: vec![],
filter: None,
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
null_equals_null: false,
schema: DFSchemaRef::new(join_schema),
})))
}
Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ pub enum LogicalPlan {
Join(Join),
/// Apply Cross Join to two logical plans.
/// This is used to implement SQL `CROSS JOIN`
/// Deprecated: use [LogicalPlan::Join] instead with empty `on` / no filter
CrossJoin(CrossJoin),
/// Repartitions the input based on a partitioning scheme. This is
/// used to add parallelism and is sometimes referred to as an
Expand Down Expand Up @@ -1873,6 +1874,11 @@ impl LogicalPlan {
.as_ref()
.map(|expr| format!(" Filter: {expr}"))
.unwrap_or_else(|| "".to_string());
let join_type = if filter.is_none() && keys.is_empty() && matches!(join_type, JoinType::Inner) {
"Cross".to_string()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we extend the JoinType enum to support Cross?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it’s better not to; otherwise it would be similar to having LogicalPlan::CrossJoin (i.e., unnecessary).

} else {
join_type.to_string()
};
match join_constraint {
JoinConstraint::On => {
write!(
Expand Down
25 changes: 15 additions & 10 deletions datafusion/optimizer/src/eliminate_cross_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{internal_err, Result};
use datafusion_expr::expr::{BinaryExpr, Expr};
use datafusion_expr::logical_plan::{
CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection,
};
use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
use datafusion_expr::{build_join_schema, ExprSchemable, Operator};
Expand All @@ -51,7 +51,7 @@ impl EliminateCrossJoin {
/// Looks like this:
/// ```text
/// Filter(a.x = b.y AND b.xx = 100)
/// CrossJoin
/// Cross Join
/// TableScan a
/// TableScan b
/// ```
Expand Down Expand Up @@ -351,10 +351,15 @@ fn find_inner_join(
&JoinType::Inner,
)?);

Ok(LogicalPlan::CrossJoin(CrossJoin {
Ok(LogicalPlan::Join(Join {
left: Arc::new(left_input),
right: Arc::new(right),
schema: join_schema,
on: vec![],
filter: None,
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
null_equals_null: false,
}))
}

Expand Down Expand Up @@ -513,7 +518,7 @@ mod tests {

let expected = vec![
"Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand Down Expand Up @@ -601,7 +606,7 @@ mod tests {

let expected = vec![
"Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand All @@ -627,7 +632,7 @@ mod tests {

let expected = vec![
"Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand Down Expand Up @@ -843,7 +848,7 @@ mod tests {

let expected = vec![
"Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
Expand Down Expand Up @@ -924,7 +929,7 @@ mod tests {
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
" Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand Down Expand Up @@ -999,7 +1004,7 @@ mod tests {
"Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
" Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
Expand Down Expand Up @@ -1238,7 +1243,7 @@ mod tests {

let expected = vec![
"Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]",
" TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]",
];
Expand Down
26 changes: 1 addition & 25 deletions datafusion/optimizer/src/eliminate_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use datafusion_common::{Result, ScalarValue};
use datafusion_expr::JoinType::Inner;
use datafusion_expr::{
logical_plan::{EmptyRelation, LogicalPlan},
CrossJoin, Expr,
Expr,
};

/// Eliminates joins when join condition is false.
Expand Down Expand Up @@ -54,13 +54,6 @@ impl OptimizerRule for EliminateJoin {
match plan {
LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => {
match join.filter {
Some(Expr::Literal(ScalarValue::Boolean(Some(true)))) => {
Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin {
left: join.left,
right: join.right,
schema: join.schema,
})))
}
Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok(
Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
Expand Down Expand Up @@ -105,21 +98,4 @@ mod tests {
let expected = "EmptyRelation";
assert_optimized_plan_equal(plan, expected)
}

#[test]
fn join_on_true() -> Result<()> {
let plan = LogicalPlanBuilder::empty(false)
.join_on(
LogicalPlanBuilder::empty(false).build()?,
Inner,
Some(lit(true)),
)?
.build()?;

let expected = "\
CrossJoin:\
\n EmptyRelation\
\n EmptyRelation";
assert_optimized_plan_equal(plan, expected)
}
}
4 changes: 2 additions & 2 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1727,7 +1727,7 @@ mod tests {
.build()?;

let expected = "Projection: test.a, test1.d\
\n CrossJoin:\
\n Cross Join: \
\n Projection: test.a, test.b, test.c\
\n TableScan: test, full_filters=[test.a = Int32(1)]\
\n Projection: test1.d, test1.e, test1.f\
Expand All @@ -1754,7 +1754,7 @@ mod tests {
.build()?;

let expected = "Projection: test.a, test1.a\
\n CrossJoin:\
\n Cross Join: \
\n Projection: test.a, test.b, test.c\
\n TableScan: test, full_filters=[test.a = Int32(1)]\
\n Projection: test1.a, test1.b, test1.c\
Expand Down
7 changes: 3 additions & 4 deletions datafusion/optimizer/src/push_down_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,9 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {

let (left_limit, right_limit) = if is_no_join_condition(&join) {
match join.join_type {
Left | Right | Full => (Some(limit), Some(limit)),
Left | Right | Full | Inner => (Some(limit), Some(limit)),
LeftAnti | LeftSemi => (Some(limit), None),
RightAnti | RightSemi => (None, Some(limit)),
Inner => (None, None),
}
} else {
match join.join_type {
Expand Down Expand Up @@ -1116,7 +1115,7 @@ mod test {
.build()?;

let expected = "Limit: skip=0, fetch=1000\
\n CrossJoin:\
\n Cross Join: \
\n Limit: skip=0, fetch=1000\
\n TableScan: test, fetch=1000\
\n Limit: skip=0, fetch=1000\
Expand All @@ -1136,7 +1135,7 @@ mod test {
.build()?;

let expected = "Limit: skip=1000, fetch=1000\
\n CrossJoin:\
\n Cross Join: \
\n Limit: skip=0, fetch=2000\
\n TableScan: test, fetch=2000\
\n Limit: skip=0, fetch=2000\
Expand Down
4 changes: 3 additions & 1 deletion datafusion/sql/src/relation/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.build()
}
}
JoinConstraint::None => not_impl_err!("NONE constraint is not supported"),
JoinConstraint::None => LogicalPlanBuilder::from(left)
.join_on(right, join_type, [])?
.build(),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ fn roundtrip_crossjoin() -> Result<()> {
.unwrap();

let expected = "Projection: j1.j1_id, j2.j2_string\
\n Inner Join: Filter: Boolean(true)\
\n Cross Join: \
\n TableScan: j1\
\n TableScan: j2";

Expand Down
Loading