-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support type coercion for equijoin #4666
Changes from all commits
0b6d258
1558b1d
4c9129d
0460ebc
e73faab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -1448,11 +1448,11 @@ async fn hash_join_with_decimal() -> Result<()> { | |||||
let state = ctx.state(); | ||||||
let plan = state.optimize(&plan)?; | ||||||
let expected = vec![ | ||||||
"Explain [plan_type:Utf8, plan:Utf8]", | ||||||
" Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", | ||||||
" Right Join: t1.c3 = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", | ||||||
" TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]", | ||||||
" TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", | ||||||
"Explain [plan_type:Utf8, plan:Utf8]", | ||||||
" Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", | ||||||
" Right Join: CAST(t1.c3 AS Decimal128(10, 2)) = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", | ||||||
" TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]", | ||||||
" TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]", | ||||||
]; | ||||||
let formatted = plan.display_indent_schema().to_string(); | ||||||
let actual: Vec<&str> = formatted.trim().lines().collect(); | ||||||
|
@@ -1982,19 +1982,22 @@ async fn sort_merge_join_on_decimal() -> Result<()> { | |||||
let state = ctx.state(); | ||||||
let logical_plan = state.optimize(&plan)?; | ||||||
let physical_plan = state.create_physical_plan(&logical_plan).await?; | ||||||
|
||||||
let expected = vec![ | ||||||
"ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@4 as c1, c2@5 as c2, c3@6 as c3, c4@7 as c4]", | ||||||
" SortMergeJoin: join_type=Right, on=[(Column { name: \"c3\", index: 2 }, Column { name: \"c3\", index: 2 })]", | ||||||
" SortExec: [c3@2 ASC]", | ||||||
" CoalesceBatchesExec: target_batch_size=4096", | ||||||
" RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)", | ||||||
" RepartitionExec: partitioning=RoundRobinBatch(2)", | ||||||
" MemoryExec: partitions=1, partition_sizes=[1]", | ||||||
" SortExec: [c3@2 ASC]", | ||||||
" CoalesceBatchesExec: target_batch_size=4096", | ||||||
" RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)", | ||||||
" RepartitionExec: partitioning=RoundRobinBatch(2)", | ||||||
" MemoryExec: partitions=1, partition_sizes=[1]", | ||||||
" ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4]", | ||||||
" SortMergeJoin: join_type=Right, on=[(Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }, Column { name: \"c3\", index: 2 })]", | ||||||
" SortExec: [CAST(t1.c3 AS Decimal128(10, 2))@4 ASC]", | ||||||
" CoalesceBatchesExec: target_batch_size=4096", | ||||||
" RepartitionExec: partitioning=Hash([Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }], 2)", | ||||||
" ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))]", | ||||||
" RepartitionExec: partitioning=RoundRobinBatch(2)", | ||||||
" MemoryExec: partitions=1, partition_sizes=[1]", | ||||||
" SortExec: [c3@2 ASC]", | ||||||
" CoalesceBatchesExec: target_batch_size=4096", | ||||||
" RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)", | ||||||
" RepartitionExec: partitioning=RoundRobinBatch(2)", | ||||||
" MemoryExec: partitions=1, partition_sizes=[1]", | ||||||
]; | ||||||
let formatted = displayable(physical_plan.as_ref()).indent().to_string(); | ||||||
let actual: Vec<&str> = formatted.trim().lines().collect(); | ||||||
|
@@ -2776,3 +2779,140 @@ async fn select_wildcard_with_expr_key_inner_join() -> Result<()> { | |||||
|
||||||
Ok(()) | ||||||
} | ||||||
|
||||||
#[tokio::test] | ||||||
async fn join_with_type_coercion_for_equi_expr() -> Result<()> { | ||||||
let ctx = create_join_context("t1_id", "t2_id", false)?; | ||||||
|
||||||
let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id + 11 = t2.t2_id"; | ||||||
|
||||||
// assert logical plan | ||||||
let msg = format!("Creating logical plan for '{}'", sql); | ||||||
let plan = ctx | ||||||
.create_logical_plan(&("explain ".to_owned() + sql)) | ||||||
.expect(&msg); | ||||||
let state = ctx.state(); | ||||||
let plan = state.optimize(&plan)?; | ||||||
|
||||||
let expected = vec![ | ||||||
"Explain [plan_type:Utf8, plan:Utf8]", | ||||||
" Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", | ||||||
" Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think eventually it would be great to have these casts unwrapped too, like
Suggested change
To avoid the runtime casting. I am not quite sure why https://github.com/apache/arrow-datafusion/blob/master/datafusion/optimizer/src/unwrap_cast_in_comparison.rs is not doing so. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This rule can only be applied to the pattern
cc @alamb We can file a new issue to discuss this. I must point out a problem about overflow for |
||||||
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", | ||||||
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", | ||||||
]; | ||||||
|
||||||
let formatted = plan.display_indent_schema().to_string(); | ||||||
let actual: Vec<&str> = formatted.trim().lines().collect(); | ||||||
assert_eq!( | ||||||
expected, actual, | ||||||
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", | ||||||
expected, actual | ||||||
); | ||||||
|
||||||
let expected = vec![ | ||||||
"+-------+---------+-------+", | ||||||
"| t1_id | t1_name | t2_id |", | ||||||
"+-------+---------+-------+", | ||||||
"| 11 | a | 22 |", | ||||||
"| 33 | c | 44 |", | ||||||
"| 44 | d | 55 |", | ||||||
"+-------+---------+-------+", | ||||||
]; | ||||||
|
||||||
let results = execute_to_batches(&ctx, sql).await; | ||||||
assert_batches_sorted_eq!(expected, &results); | ||||||
|
||||||
Ok(()) | ||||||
} | ||||||
|
||||||
#[tokio::test] | ||||||
async fn join_only_with_filter() -> Result<()> { | ||||||
let ctx = create_join_context("t1_id", "t2_id", false)?; | ||||||
|
||||||
let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id * 4 < t2.t2_id"; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. After #4562 is merged, the plan will be converted to an NLJ. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Updated in e73faab. |
||||||
|
||||||
// assert logical plan | ||||||
let msg = format!("Creating logical plan for '{}'", sql); | ||||||
let plan = ctx | ||||||
.create_logical_plan(&("explain ".to_owned() + sql)) | ||||||
.expect(&msg); | ||||||
let state = ctx.state(); | ||||||
let plan = state.optimize(&plan)?; | ||||||
|
||||||
let expected = vec![ | ||||||
"Explain [plan_type:Utf8, plan:Utf8]", | ||||||
" Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", | ||||||
" Inner Join: Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", | ||||||
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", | ||||||
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", | ||||||
]; | ||||||
|
||||||
let formatted = plan.display_indent_schema().to_string(); | ||||||
let actual: Vec<&str> = formatted.trim().lines().collect(); | ||||||
assert_eq!( | ||||||
expected, actual, | ||||||
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", | ||||||
expected, actual | ||||||
); | ||||||
|
||||||
let expected = vec![ | ||||||
"+-------+---------+-------+", | ||||||
"| t1_id | t1_name | t2_id |", | ||||||
"+-------+---------+-------+", | ||||||
"| 11 | a | 55 |", | ||||||
"+-------+---------+-------+", | ||||||
]; | ||||||
|
||||||
let results = execute_to_batches(&ctx, sql).await; | ||||||
assert_batches_sorted_eq!(expected, &results); | ||||||
|
||||||
Ok(()) | ||||||
} | ||||||
|
||||||
#[tokio::test] | ||||||
async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> { | ||||||
let ctx = create_join_context("t1_id", "t2_id", false)?; | ||||||
|
||||||
let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \ | ||||||
from t1 \ | ||||||
inner join t2 \ | ||||||
on t1.t1_id * 5 = t2.t2_id and t1.t1_id * 4 < t2.t2_id"; | ||||||
|
||||||
// assert logical plan | ||||||
let msg = format!("Creating logical plan for '{}'", sql); | ||||||
let plan = ctx | ||||||
.create_logical_plan(&("explain ".to_owned() + sql)) | ||||||
.expect(&msg); | ||||||
let state = ctx.state(); | ||||||
let plan = state.optimize(&plan)?; | ||||||
|
||||||
let expected = vec![ | ||||||
"Explain [plan_type:Utf8, plan:Utf8]", | ||||||
" Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", | ||||||
" Inner Join: CAST(t1.t1_id AS Int64) * Int64(5) = CAST(t2.t2_id AS Int64) Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||||||
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]", | ||||||
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", | ||||||
]; | ||||||
|
||||||
let formatted = plan.display_indent_schema().to_string(); | ||||||
let actual: Vec<&str> = formatted.trim().lines().collect(); | ||||||
assert_eq!( | ||||||
expected, actual, | ||||||
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", | ||||||
expected, actual | ||||||
); | ||||||
|
||||||
let expected = vec![ | ||||||
"+-------+---------+-------+", | ||||||
"| t1_id | t1_name | t2_id |", | ||||||
"+-------+---------+-------+", | ||||||
"| 11 | a | 55 |", | ||||||
"+-------+---------+-------+", | ||||||
]; | ||||||
|
||||||
let results = execute_to_batches(&ctx, sql).await; | ||||||
assert_batches_sorted_eq!(expected, &results); | ||||||
|
||||||
Ok(()) | ||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -253,9 +253,12 @@ impl LogicalPlan { | |
aggr_expr, | ||
.. | ||
}) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(), | ||
// A join has two kinds of expressions: equijoin (on) and non-equijoin (filter). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
// 1. The first part is `on.len()` equijoin expressions, and the structure of each expr is `left-on = right-on`. | ||
// 2. The second part is the non-equijoin (filter) expression. | ||
LogicalPlan::Join(Join { on, filter, .. }) => on | ||
.iter() | ||
.flat_map(|(l, r)| vec![l.clone(), r.clone()]) | ||
.map(|(l, r)| Expr::eq(l.clone(), r.clone())) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the fix, right? It then exposes the Very nice 👍 |
||
.chain( | ||
filter | ||
.as_ref() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,8 @@ use crate::logical_plan::{ | |
SubqueryAlias, Union, Values, Window, | ||
}; | ||
use crate::{ | ||
Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, TableScan, TryCast, | ||
BinaryExpr, Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, Operator, | ||
TableScan, TryCast, | ||
}; | ||
use arrow::datatypes::{DataType, TimeUnit}; | ||
use datafusion_common::{ | ||
|
@@ -567,20 +568,35 @@ pub fn from_plan( | |
}) => { | ||
let schema = | ||
build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?; | ||
|
||
let equi_expr_count = on.len(); | ||
assert!(expr.len() >= equi_expr_count); | ||
|
||
// The preceding part of expr is equi-exprs, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
// and the struct of each equi-expr is like `left-expr = right-expr`. | ||
let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we should error here if https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.take There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense, I added a check -- |
||
if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = equi_expr { | ||
assert!(op == &Operator::Eq); | ||
Ok(((**left).clone(), (**right).clone())) | ||
} else { | ||
Err(DataFusionError::Internal(format!( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
"The front part expressions should be an binary expression, actual:{}", | ||
equi_expr | ||
))) | ||
} | ||
}).collect::<Result<Vec<(Expr, Expr)>>>()?; | ||
|
||
// Assume that the last expr, if any, | ||
// is the filter_expr (non equality predicate from ON clause) | ||
let filter_expr = if on.len() * 2 == expr.len() { | ||
None | ||
} else { | ||
Some(expr[expr.len() - 1].clone()) | ||
}; | ||
let filter_expr = | ||
(expr.len() > equi_expr_count).then(|| expr[expr.len() - 1].clone()); | ||
|
||
Ok(LogicalPlan::Join(Join { | ||
left: Arc::new(inputs[0].clone()), | ||
right: Arc::new(inputs[1].clone()), | ||
join_type: *join_type, | ||
join_constraint: *join_constraint, | ||
on: on.clone(), | ||
on: new_on, | ||
filter: filter_expr, | ||
schema: DFSchemaRef::new(schema), | ||
null_equals_null: *null_equals_null, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happened previously with this plan? Would it error at runtime?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it succeeded without type coercion, but I think that was a coincidence.
The reason is that for the decimal `eq` operation we only compare the values; we do not check that the `precision` and `scale` are the same.
In this test, `c3` has the same value in each of the two matched rows (the third row is `789000` and the fourth is `-12312`), so the comparison succeeded.