Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support type coercion for equijoin #4666

Merged
merged 5 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 156 additions & 16 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1448,11 +1448,11 @@ async fn hash_join_with_decimal() -> Result<()> {
let state = ctx.state();
let plan = state.optimize(&plan)?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
" Right Join: t1.c3 = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happened previously with this plan? Would it error at runtime?

Copy link
Contributor Author

@ygf11 ygf11 Dec 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it succeeded without type coercion, but I think it is a coincidence.

The reason is for decimal eq operation, we only check value, but do not check the precision and scale are same.

In this test, each c3 of the two matched rows have the same value, the third row is 789000 and the fourth is -12312, then it succeeded.

    "+------------+------------+---------+-----+------------+------------+-----------+---------+",
    "| c1         | c2         | c3      | c4  | c1         | c2         | c3        | c4      |",
    "+------------+------------+---------+-----+------------+------------+-----------+---------+",
    "|            |            |         |     |            |            | 100000.00 | abcdefg |",
    "|            |            |         |     |            | 1970-01-04 | 0.00      | qwerty  |",
    "|            | 1970-01-04 | 789.00  | ghi | 1970-01-04 |            | 789.00    |         |",
    "| 1970-01-04 |            | -123.12 | jkl | 1970-01-02 | 1970-01-02 | -123.12   | abc     |",
    "+------------+------------+---------+-----+------------+------------+-----------+---------+",

" TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
" TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.c1, t1.c2, t1.c3, t1.c4, t2.c1, t2.c2, t2.c3, t2.c4 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
" Right Join: CAST(t1.c3 AS Decimal128(10, 2)) = t2.c3 [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N, c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
" TableScan: t1 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(5, 2);N, c4:Dictionary(Int32, Utf8);N]",
" TableScan: t2 projection=[c1, c2, c3, c4] [c1:Date32;N, c2:Date64;N, c3:Decimal128(10, 2);N, c4:Dictionary(Int32, Utf8);N]",
];
let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down Expand Up @@ -1982,19 +1982,22 @@ async fn sort_merge_join_on_decimal() -> Result<()> {
let state = ctx.state();
let logical_plan = state.optimize(&plan)?;
let physical_plan = state.create_physical_plan(&logical_plan).await?;

let expected = vec![
"ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@4 as c1, c2@5 as c2, c3@6 as c3, c4@7 as c4]",
" SortMergeJoin: join_type=Right, on=[(Column { name: \"c3\", index: 2 }, Column { name: \"c3\", index: 2 })]",
" SortExec: [c3@2 ASC]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" MemoryExec: partitions=1, partition_sizes=[1]",
" SortExec: [c3@2 ASC]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" MemoryExec: partitions=1, partition_sizes=[1]",
" ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4]",
" SortMergeJoin: join_type=Right, on=[(Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }, Column { name: \"c3\", index: 2 })]",
" SortExec: [CAST(t1.c3 AS Decimal128(10, 2))@4 ASC]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"CAST(t1.c3 AS Decimal128(10, 2))\", index: 4 }], 2)",
" ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))]",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" MemoryExec: partitions=1, partition_sizes=[1]",
" SortExec: [c3@2 ASC]",
" CoalesceBatchesExec: target_batch_size=4096",
" RepartitionExec: partitioning=Hash([Column { name: \"c3\", index: 2 }], 2)",
" RepartitionExec: partitioning=RoundRobinBatch(2)",
" MemoryExec: partitions=1, partition_sizes=[1]",
];
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
Expand Down Expand Up @@ -2776,3 +2779,140 @@ async fn select_wildcard_with_expr_key_inner_join() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn join_with_type_coercion_for_equi_expr() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id + 11 = t2.t2_id";

// assert logical plan
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think eventually it would be great to have these casts unwrapped too, like

Suggested change
" Inner Join: CAST(t1.t1_id AS Int64) + Int64(11) = CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" Inner Join: t1.t1_id + Int32(11) = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",

To avoid the runtime casting

I am not quite sure why https://github.com/apache/arrow-datafusion/blob/master/datafusion/optimizer/src/unwrap_cast_in_comparison.rs is not doing so

Copy link
Contributor

@liukun4515 liukun4515 Dec 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rule just can apply to the pattern

expr `op` literal

cc @alamb

We can file a new issue to discuss this.

I must point out a problem about overflow for add operation, for example i32::max + i32::max maybe overflow.

" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let expected = vec![
"+-------+---------+-------+",
"| t1_id | t1_name | t2_id |",
"+-------+---------+-------+",
"| 11 | a | 22 |",
"| 33 | c | 44 |",
"| 44 | d | 55 |",
"+-------+---------+-------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn join_only_with_filter() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t2.t2_id from t1 inner join t2 on t1.t1_id * 4 < t2.t2_id";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

after #4562 merged, the plan will be converted to NLJ

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in e73faab


// assert logical plan
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" Inner Join: Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let expected = vec![
"+-------+---------+-------+",
"| t1_id | t1_name | t2_id |",
"+-------+---------+-------+",
"| 11 | a | 55 |",
"+-------+---------+-------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> {
let ctx = create_join_context("t1_id", "t2_id", false)?;

let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \
from t1 \
inner join t2 \
on t1.t1_id * 5 = t2.t2_id and t1.t1_id * 4 < t2.t2_id";

// assert logical plan
let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
let state = ctx.state();
let plan = state.optimize(&plan)?;

let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
" Inner Join: CAST(t1.t1_id AS Int64) * Int64(5) = CAST(t2.t2_id AS Int64) Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
" TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]",
];

let formatted = plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let expected = vec![
"+-------+---------+-------+",
"| t1_id | t1_name | t2_id |",
"+-------+---------+-------+",
"| 11 | a | 55 |",
"+-------+---------+-------+",
];

let results = execute_to_batches(&ctx, sql).await;
assert_batches_sorted_eq!(expected, &results);

Ok(())
}
5 changes: 4 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,12 @@ impl LogicalPlan {
aggr_expr,
..
}) => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
// There are two part of expression for join, equijoin(on) and non-equijoin(filter).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`.
// 2. the second part is non-equijoin(filter).
LogicalPlan::Join(Join { on, filter, .. }) => on
.iter()
.flat_map(|(l, r)| vec![l.clone(), r.clone()])
.map(|(l, r)| Expr::eq(l.clone(), r.clone()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the fix, right? It then exposes the <l> = <r> expr to the existing type coercion logic ?

Very nice 👍

.chain(
filter
.as_ref()
Expand Down
30 changes: 23 additions & 7 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ use crate::logical_plan::{
SubqueryAlias, Union, Values, Window,
};
use crate::{
Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, TableScan, TryCast,
BinaryExpr, Cast, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder, Operator,
TableScan, TryCast,
};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::{
Expand Down Expand Up @@ -567,20 +568,35 @@ pub fn from_plan(
}) => {
let schema =
build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?;

let equi_expr_count = on.len();
assert!(expr.len() >= equi_expr_count);

// The preceding part of expr is equi-exprs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

// and the struct of each equi-expr is like `left-expr = right-expr`.
let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should error here if expr does not has at least equi_expr_count elements left. Otherwise I think take will silently return fewer than equi_expr_count elements, which might result in quite hard to track down bugs

https://doc.rust-lang.org/std/iter/trait.Iterator.html#method.take

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense, I added a check -- assert!(expr.len() >= equi_expr_count).

if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = equi_expr {
assert!(op == &Operator::Eq);
Ok(((**left).clone(), (**right).clone()))
} else {
Err(DataFusionError::Internal(format!(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"The front part expressions should be an binary expression, actual:{}",
equi_expr
)))
}
}).collect::<Result<Vec<(Expr, Expr)>>>()?;

// Assume that the last expr, if any,
// is the filter_expr (non equality predicate from ON clause)
let filter_expr = if on.len() * 2 == expr.len() {
None
} else {
Some(expr[expr.len() - 1].clone())
};
let filter_expr =
(expr.len() > equi_expr_count).then(|| expr[expr.len() - 1].clone());

Ok(LogicalPlan::Join(Join {
left: Arc::new(inputs[0].clone()),
right: Arc::new(inputs[1].clone()),
join_type: *join_type,
join_constraint: *join_constraint,
on: on.clone(),
on: new_on,
filter: filter_expr,
schema: DFSchemaRef::new(schema),
null_equals_null: *null_equals_null,
Expand Down