Skip to content

Commit

Permalink
Substrait: Fix incorrect join key fields (indices) when same table is…
Browse files Browse the repository at this point in the history
… being used more than once (#6135)

* Fix incorrect join key fields (indices) when same table is being used more than once

* Addressed comments

Update datafusion/substrait/src/logical_plan/producer.rs

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>

Update datafusion/substrait/src/logical_plan/producer.rs

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>

* Fixed bugs after rebase

---------

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>
  • Loading branch information
nseekhao and waynexia authored Jun 11, 2023
1 parent 57bc5b0 commit e6265c1
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 62 deletions.
2 changes: 1 addition & 1 deletion datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ pub async fn from_substrait_rel(
)),
},
_ => Err(DataFusionError::Internal(
"invalid join condition expresssion".to_string(),
"invalid join condition expression".to_string(),
)),
}
}
Expand Down
191 changes: 130 additions & 61 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;

use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
Expand All @@ -32,7 +32,7 @@ use datafusion::logical_expr::expr::{
BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::{binary_expr, Expr};
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::{
proto::{
Expand Down Expand Up @@ -156,7 +156,7 @@ pub fn to_substrait_rel(
let expressions = p
.expr
.iter()
.map(|e| to_substrait_rex(e, p.input.schema(), extension_info))
.map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Project(Box::new(ProjectRel {
Expand All @@ -172,6 +172,7 @@ pub fn to_substrait_rel(
let filter_expr = to_substrait_rex(
&filter.predicate,
filter.input.schema(),
0,
extension_info,
)?;
Ok(Box::new(Rel {
Expand Down Expand Up @@ -218,7 +219,7 @@ pub fn to_substrait_rel(
let grouping = agg
.group_expr
.iter()
.map(|e| to_substrait_rex(e, agg.input.schema(), extension_info))
.map(|e| to_substrait_rex(e, agg.input.schema(), 0, extension_info))
.collect::<Result<Vec<_>>>()?;
let measures = agg
.aggr_expr
Expand Down Expand Up @@ -281,45 +282,24 @@ pub fn to_substrait_rel(
} else {
Operator::Eq
};
let join_expression = join
.on
.iter()
.map(|(l, r)| binary_expr(l.clone(), eq_op, r.clone()))
.reduce(|acc: Expr, expr: Expr| acc.and(expr));
// join schema from left and right to maintain all nececesary columns from inputs
// note that we cannot simple use join.schema here since we discard some input columns
// when performing semi and anti joins
let join_schema = match join.left.schema().join(join.right.schema()) {
Ok(schema) => Ok(schema),
Err(DataFusionError::SchemaError(
datafusion::common::SchemaError::DuplicateQualifiedField {
qualifier: _,
name: _,
},
)) => Ok(join.schema.as_ref().clone()),
Err(e) => Err(e),
};
if let Some(e) = join_expression {
Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
common: None,
left: Some(left),
right: Some(right),
r#type: join_type as i32,
expression: Some(Box::new(to_substrait_rex(
&e,
&Arc::new(join_schema?),
extension_info,
)?)),
post_join_filter: None,
advanced_extension: None,
}))),
}))
} else {
Err(DataFusionError::NotImplemented(
"Empty join condition".to_string(),
))
}

Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
common: None,
left: Some(left),
right: Some(right),
r#type: join_type as i32,
expression: Some(Box::new(to_substrait_join_expr(
&join.on,
eq_op,
join.left.schema(),
join.right.schema(),
extension_info,
)?)),
post_join_filter: None,
advanced_extension: None,
}))),
}))
}
LogicalPlan::SubqueryAlias(alias) => {
// Do nothing if encounters SubqueryAlias
Expand Down Expand Up @@ -353,6 +333,7 @@ pub fn to_substrait_rel(
window_exprs.push(to_substrait_rex(
expr,
window.input.schema(),
0,
extension_info,
)?);
}
Expand Down Expand Up @@ -403,6 +384,40 @@ pub fn to_substrait_rel(
}
}

fn to_substrait_join_expr(
join_conditions: &Vec<(Expr, Expr)>,
eq_op: Operator,
left_schema: &DFSchemaRef,
right_schema: &DFSchemaRef,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> Result<Expression> {
// Only support AND conjunction for each binary expression in join conditions
let mut exprs: Vec<Expression> = vec![];
for (left, right) in join_conditions {
// Parse left
let l = to_substrait_rex(left, left_schema, 0, extension_info)?;
// Parse right
let r = to_substrait_rex(
right,
right_schema,
left_schema.fields().len(), // offset to return the correct index
extension_info,
)?;
// AND with existing expression
exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info));
}
let join_expr: Expression = exprs
.into_iter()
.reduce(|acc: Expression, e: Expression| {
make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info)
})
.unwrap();
Ok(join_expr)
}

fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
match join_type {
JoinType::Inner => join_rel::JoinType::Inner,
Expand Down Expand Up @@ -459,7 +474,7 @@ pub fn to_substrait_agg_measure(
Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by: _order_by }) => {
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) });
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
}
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
Expand All @@ -478,7 +493,7 @@ pub fn to_substrait_agg_measure(
options: vec![],
}),
filter: match filter {
Some(f) => Some(to_substrait_rex(f, schema, extension_info)?),
Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
None => None
}
})
Expand Down Expand Up @@ -566,10 +581,33 @@ pub fn make_binary_op_scalar_func(
}

/// Convert DataFusion Expr to Substrait Rex
///
/// # Arguments
///
/// * `expr` - DataFusion expression to be parsed into a Substrait expression
/// * `schema` - DataFusion input schema for looking up field qualifiers
/// * `col_ref_offset` - Offset for calculating Substrait field reference indices.
/// This should only be set by a caller with more than one input relation, i.e. Join.
/// Substrait expects one set of indices when joining two relations.
/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right`
/// relation will have column indices from `0` to `n-1`, however, Substrait will expect
/// the `right` indices to be offset by the `left`. This means Substrait will expect to
/// evaluate the join condition expression on indices [0 .. m-1, m .. m+n-1]. For example:
/// ```sql
/// SELECT *
/// FROM t1
/// JOIN t2
/// ON t1.c1 = t2.c0;
/// ```
/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1]
/// the join condition should become
/// `col_ref(1) = col_ref(3 + 0)`
/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index
/// of the join key column from `right`
/// * `extension_info` - Substrait extension info. Contains registered function information
#[allow(deprecated)]
pub fn to_substrait_rex(
expr: &Expr,
schema: &DFSchemaRef,
col_ref_offset: usize,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
Expand All @@ -583,6 +621,7 @@ pub fn to_substrait_rex(
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
});
Expand All @@ -607,9 +646,12 @@ pub fn to_substrait_rex(
}) => {
if *negated {
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
let substrait_expr = to_substrait_rex(expr, schema, extension_info)?;
let substrait_low = to_substrait_rex(low, schema, extension_info)?;
let substrait_high = to_substrait_rex(high, schema, extension_info)?;
let substrait_expr =
to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
let substrait_low =
to_substrait_rex(low, schema, col_ref_offset, extension_info)?;
let substrait_high =
to_substrait_rex(high, schema, col_ref_offset, extension_info)?;

let l_expr = make_binary_op_scalar_func(
&substrait_expr,
Expand All @@ -632,9 +674,12 @@ pub fn to_substrait_rex(
))
} else {
// `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high)
let substrait_expr = to_substrait_rex(expr, schema, extension_info)?;
let substrait_low = to_substrait_rex(low, schema, extension_info)?;
let substrait_high = to_substrait_rex(high, schema, extension_info)?;
let substrait_expr =
to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
let substrait_low =
to_substrait_rex(low, schema, col_ref_offset, extension_info)?;
let substrait_high =
to_substrait_rex(high, schema, col_ref_offset, extension_info)?;

let l_expr = make_binary_op_scalar_func(
&substrait_low,
Expand All @@ -659,11 +704,11 @@ pub fn to_substrait_rex(
}
Expr::Column(col) => {
let index = schema.index_of_column(col)?;
substrait_field_ref(index)
substrait_field_ref(index + col_ref_offset)
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let l = to_substrait_rex(left, schema, extension_info)?;
let r = to_substrait_rex(right, schema, extension_info)?;
let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?;
let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?;

Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info))
}
Expand All @@ -677,21 +722,41 @@ pub fn to_substrait_rex(
if let Some(e) = expr {
// Base expression exists
ifs.push(IfClause {
r#if: Some(to_substrait_rex(e, schema, extension_info)?),
r#if: Some(to_substrait_rex(
e,
schema,
col_ref_offset,
extension_info,
)?),
then: None,
});
}
// Parse `when`s
for (r#if, then) in when_then_expr {
ifs.push(IfClause {
r#if: Some(to_substrait_rex(r#if, schema, extension_info)?),
then: Some(to_substrait_rex(then, schema, extension_info)?),
r#if: Some(to_substrait_rex(
r#if,
schema,
col_ref_offset,
extension_info,
)?),
then: Some(to_substrait_rex(
then,
schema,
col_ref_offset,
extension_info,
)?),
});
}

// Parse outer `else`
let r#else: Option<Box<Expression>> = match else_expr {
Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)),
Some(e) => Some(Box::new(to_substrait_rex(
e,
schema,
col_ref_offset,
extension_info,
)?)),
None => None,
};

Expand All @@ -707,6 +772,7 @@ pub fn to_substrait_rex(
input: Some(Box::new(to_substrait_rex(
expr,
schema,
col_ref_offset,
extension_info,
)?)),
failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED
Expand All @@ -715,7 +781,9 @@ pub fn to_substrait_rex(
})
}
Expr::Literal(value) => to_substrait_literal(value),
Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info),
Expr::Alias(expr, _alias) => {
to_substrait_rex(expr, schema, col_ref_offset, extension_info)
}
Expr::WindowFunction(WindowFunction {
fun,
args,
Expand All @@ -733,14 +801,15 @@ pub fn to_substrait_rex(
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
});
}
// partition by expressions
let partition_by = partition_by
.iter()
.map(|e| to_substrait_rex(e, schema, extension_info))
.map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info))
.collect::<Result<Vec<_>>>()?;
// order by expressions
let order_by = order_by
Expand Down Expand Up @@ -1325,7 +1394,7 @@ fn substrait_sort_field(
asc,
nulls_first,
}) => {
let e = to_substrait_rex(expr, schema, extension_info)?;
let e = to_substrait_rex(expr, schema, 0, extension_info)?;
let d = match (asc, nulls_first) {
(true, true) => SortDirection::AscNullsFirst,
(true, false) => SortDirection::AscNullsLast,
Expand Down
24 changes: 24 additions & 0 deletions datafusion/substrait/tests/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,30 @@ mod tests {
roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await
}

/// Self-join of `data` on column `a`: the same table appears on both sides of the join.
///
/// Checks the query against the expected plan text via `assert_expected_plan`
/// (presumably after a round trip through Substrait, going by the test name — confirm
/// against the helper's definition).
#[tokio::test]
async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> {
    let sql = "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a";
    let expected_plan = "Projection: data.b, data.c\
    \n Inner Join: data.a = data.a\
    \n TableScan: data projection=[a, b]\
    \n TableScan: data projection=[a, c]";

    assert_expected_plan(sql, expected_plan).await
}

/// Self-join of `data` on column `b`: the same table appears on both sides of the join,
/// and the join key is (per the test name) not at index zero of the table — confirm
/// against the `data` table definition.
///
/// Checks the query against the expected plan text via `assert_expected_plan`.
#[tokio::test]
async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> {
    let sql = "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b";
    let expected_plan = "Projection: data.b, data.c\
    \n Inner Join: data.b = data.b\
    \n TableScan: data projection=[b]\
    \n TableScan: data projection=[b, c]";

    assert_expected_plan(sql, expected_plan).await
}

/// Construct a plan that contains several literals of types that are currently supported.
/// This case ignores:
/// - Date64, for this literal is not supported
Expand Down

0 comments on commit e6265c1

Please sign in to comment.