Substrait: Fix incorrect join key fields (indices) when same table is being used more than once #6135

Merged
merged 3 commits into from
Jun 11, 2023
2 changes: 1 addition & 1 deletion datafusion/substrait/src/logical_plan/consumer.rs
@@ -365,7 +365,7 @@ pub async fn from_substrait_rel(
)),
},
_ => Err(DataFusionError::Internal(
"invalid join condition expresssion".to_string(),
"invalid join condition expression".to_string(),
)),
}
}
191 changes: 130 additions & 61 deletions datafusion/substrait/src/logical_plan/producer.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;

use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
@@ -32,7 +32,7 @@ use datafusion::logical_expr::expr::{
BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction,
};
use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator};
use datafusion::prelude::{binary_expr, Expr};
use datafusion::prelude::Expr;
use prost_types::Any as ProtoAny;
use substrait::{
proto::{
@@ -156,7 +156,7 @@ pub fn to_substrait_rel(
let expressions = p
.expr
.iter()
.map(|e| to_substrait_rex(e, p.input.schema(), extension_info))
.map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Project(Box::new(ProjectRel {
@@ -172,6 +172,7 @@ pub fn to_substrait_rel(
let filter_expr = to_substrait_rex(
&filter.predicate,
filter.input.schema(),
0,
extension_info,
)?;
Ok(Box::new(Rel {
@@ -218,7 +219,7 @@ pub fn to_substrait_rel(
let grouping = agg
.group_expr
.iter()
.map(|e| to_substrait_rex(e, agg.input.schema(), extension_info))
.map(|e| to_substrait_rex(e, agg.input.schema(), 0, extension_info))
.collect::<Result<Vec<_>>>()?;
let measures = agg
.aggr_expr
@@ -281,45 +282,24 @@ pub fn to_substrait_rel(
} else {
Operator::Eq
};
let join_expression = join
.on
.iter()
.map(|(l, r)| binary_expr(l.clone(), eq_op, r.clone()))
.reduce(|acc: Expr, expr: Expr| acc.and(expr));
// join schema from left and right to maintain all nececesary columns from inputs
// note that we cannot simple use join.schema here since we discard some input columns
// when performing semi and anti joins
let join_schema = match join.left.schema().join(join.right.schema()) {
Ok(schema) => Ok(schema),
Err(DataFusionError::SchemaError(
datafusion::common::SchemaError::DuplicateQualifiedField {
qualifier: _,
name: _,
},
)) => Ok(join.schema.as_ref().clone()),
Err(e) => Err(e),
};
if let Some(e) = join_expression {
Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
common: None,
left: Some(left),
right: Some(right),
r#type: join_type as i32,
expression: Some(Box::new(to_substrait_rex(
&e,
&Arc::new(join_schema?),
extension_info,
)?)),
post_join_filter: None,
advanced_extension: None,
}))),
}))
} else {
Err(DataFusionError::NotImplemented(
"Empty join condition".to_string(),
))
}

Ok(Box::new(Rel {
rel_type: Some(RelType::Join(Box::new(JoinRel {
common: None,
left: Some(left),
right: Some(right),
r#type: join_type as i32,
expression: Some(Box::new(to_substrait_join_expr(
&join.on,
eq_op,
join.left.schema(),
join.right.schema(),
extension_info,
)?)),
post_join_filter: None,
advanced_extension: None,
}))),
}))
}
LogicalPlan::SubqueryAlias(alias) => {
// Do nothing if encounters SubqueryAlias
@@ -353,6 +333,7 @@ pub fn to_substrait_rel(
window_exprs.push(to_substrait_rex(
expr,
window.input.schema(),
0,
extension_info,
)?);
}
@@ -403,6 +384,40 @@ pub fn to_substrait_rel(
}
}

fn to_substrait_join_expr(
join_conditions: &Vec<(Expr, Expr)>,
eq_op: Operator,
left_schema: &DFSchemaRef,
right_schema: &DFSchemaRef,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
),
) -> Result<Expression> {
// Only support AND conjunction for each binary expression in join conditions
let mut exprs: Vec<Expression> = vec![];
for (left, right) in join_conditions {
// Parse left
let l = to_substrait_rex(left, left_schema, 0, extension_info)?;
// Parse right
let r = to_substrait_rex(
right,
right_schema,
left_schema.fields().len(), // offset to return the correct index
extension_info,
)?;
// AND with existing expression
exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info));
}
let join_expr: Expression = exprs
.into_iter()
.reduce(|acc: Expression, e: Expression| {
make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info)
})
.unwrap();
Ok(join_expr)
}

fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType {
match join_type {
JoinType::Inner => join_rel::JoinType::Inner,
@@ -459,7 +474,7 @@ pub fn to_substrait_agg_measure(
Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by: _order_by }) => {
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) });
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
}
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
@@ -478,7 +493,7 @@ pub fn to_substrait_agg_measure(
options: vec![],
}),
filter: match filter {
Some(f) => Some(to_substrait_rex(f, schema, extension_info)?),
Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
None => None
}
})
@@ -566,10 +581,33 @@ pub fn make_binary_op_scalar_func(
}

/// Convert DataFusion Expr to Substrait Rex
///
Contributor

👍

/// # Arguments
///
/// * `expr` - DataFusion expression to be parsed into a Substrait expression
/// * `schema` - DataFusion input schema for looking up field qualifiers
/// * `col_ref_offset` - Offset for calculating Substrait field reference indices.
/// This should only be set by a caller with more than one input relation, i.e. Join.
/// Substrait expects one set of indices when joining two relations.
/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right`
/// relation will have column indices from `0` to `n-1`; however, Substrait will expect
/// the `right` indices to be offset by the `left`. This means Substrait will expect to
/// evaluate the join condition expression on indices [0 .. m-1, m .. m+n-1]. For example:
/// ```SELECT *
/// FROM t1
/// JOIN t2
/// ON t1.c1 = t2.c0;```
/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1]
/// the join condition should become
/// `col_ref(1) = col_ref(3 + 0)`
/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index
/// of the join key column from `right`
/// * `extension_info` - Substrait extension info. Contains registered function information
#[allow(deprecated)]
pub fn to_substrait_rex(
expr: &Expr,
schema: &DFSchemaRef,
col_ref_offset: usize,
extension_info: &mut (
Vec<extensions::SimpleExtensionDeclaration>,
HashMap<String, u32>,
@@ -583,6 +621,7 @@ pub fn to_substrait_rex(
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
});
@@ -607,9 +646,12 @@ pub fn to_substrait_rex(
}) => {
if *negated {
// `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
let substrait_expr = to_substrait_rex(expr, schema, extension_info)?;
let substrait_low = to_substrait_rex(low, schema, extension_info)?;
let substrait_high = to_substrait_rex(high, schema, extension_info)?;
let substrait_expr =
to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
let substrait_low =
to_substrait_rex(low, schema, col_ref_offset, extension_info)?;
let substrait_high =
to_substrait_rex(high, schema, col_ref_offset, extension_info)?;

let l_expr = make_binary_op_scalar_func(
&substrait_expr,
@@ -632,9 +674,12 @@ pub fn to_substrait_rex(
))
} else {
// `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high)
let substrait_expr = to_substrait_rex(expr, schema, extension_info)?;
let substrait_low = to_substrait_rex(low, schema, extension_info)?;
let substrait_high = to_substrait_rex(high, schema, extension_info)?;
let substrait_expr =
to_substrait_rex(expr, schema, col_ref_offset, extension_info)?;
let substrait_low =
to_substrait_rex(low, schema, col_ref_offset, extension_info)?;
let substrait_high =
to_substrait_rex(high, schema, col_ref_offset, extension_info)?;

let l_expr = make_binary_op_scalar_func(
&substrait_low,
@@ -659,11 +704,11 @@ pub fn to_substrait_rex(
}
Expr::Column(col) => {
let index = schema.index_of_column(col)?;
substrait_field_ref(index)
substrait_field_ref(index + col_ref_offset)
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let l = to_substrait_rex(left, schema, extension_info)?;
let r = to_substrait_rex(right, schema, extension_info)?;
let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?;
let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?;

Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info))
}
@@ -677,21 +722,41 @@ pub fn to_substrait_rex(
if let Some(e) = expr {
// Base expression exists
ifs.push(IfClause {
r#if: Some(to_substrait_rex(e, schema, extension_info)?),
r#if: Some(to_substrait_rex(
e,
schema,
col_ref_offset,
extension_info,
)?),
then: None,
});
}
// Parse `when`s
for (r#if, then) in when_then_expr {
ifs.push(IfClause {
r#if: Some(to_substrait_rex(r#if, schema, extension_info)?),
then: Some(to_substrait_rex(then, schema, extension_info)?),
r#if: Some(to_substrait_rex(
r#if,
schema,
col_ref_offset,
extension_info,
)?),
then: Some(to_substrait_rex(
then,
schema,
col_ref_offset,
extension_info,
)?),
});
}

// Parse outer `else`
let r#else: Option<Box<Expression>> = match else_expr {
Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)),
Some(e) => Some(Box::new(to_substrait_rex(
e,
schema,
col_ref_offset,
extension_info,
)?)),
None => None,
};

@@ -707,6 +772,7 @@ pub fn to_substrait_rex(
input: Some(Box::new(to_substrait_rex(
expr,
schema,
col_ref_offset,
extension_info,
)?)),
failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED
@@ -715,7 +781,9 @@ pub fn to_substrait_rex(
})
}
Expr::Literal(value) => to_substrait_literal(value),
Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info),
Expr::Alias(expr, _alias) => {
to_substrait_rex(expr, schema, col_ref_offset, extension_info)
}
Expr::WindowFunction(WindowFunction {
fun,
args,
@@ -733,14 +801,15 @@ pub fn to_substrait_rex(
arg_type: Some(ArgType::Value(to_substrait_rex(
arg,
schema,
col_ref_offset,
extension_info,
)?)),
});
}
// partition by expressions
let partition_by = partition_by
.iter()
.map(|e| to_substrait_rex(e, schema, extension_info))
.map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info))
.collect::<Result<Vec<_>>>()?;
// order by expressions
let order_by = order_by
@@ -1325,7 +1394,7 @@ fn substrait_sort_field(
asc,
nulls_first,
}) => {
let e = to_substrait_rex(expr, schema, extension_info)?;
let e = to_substrait_rex(expr, schema, 0, extension_info)?;
let d = match (asc, nulls_first) {
(true, true) => SortDirection::AscNullsFirst,
(true, false) => SortDirection::AscNullsLast,
24 changes: 24 additions & 0 deletions datafusion/substrait/tests/roundtrip_logical_plan.rs
@@ -412,6 +412,30 @@ mod tests {
roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await
}

#[tokio::test]
async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> {
assert_expected_plan(
"SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a",
"Projection: data.b, data.c\
\n Inner Join: data.a = data.a\
\n TableScan: data projection=[a, b]\
\n TableScan: data projection=[a, c]",
Comment on lines +418 to +422
Member

I tried this case against the main branch and it passes. I guess this is not the original query that caused the error?

BTW, if two tables have the same name, does that mean they are the same table? (In this case, d1 and d2 are the same table, data.) If so, I think we need not distinguish "different" columns, because they will eventually refer to the same column. Please let me know if I missed anything.

Contributor Author

If I recall correctly, this would have failed in an earlier version of DF, where DuplicateQualifiedField was thrown in DFSchema::new_with_metadata(). In later versions that error is no longer thrown, and only DuplicateUnqualifiedField gets thrown here.

However, even if the error is not thrown and the test passes, that does not mean the produced Substrait plan is correct. If we join a table to itself, the JoinRel in the Substrait plan expects two input relations. AFAIK, there is no notion of a pointer, so we need two ReadRels created from the same table. As far as the JoinRel is concerned, left and right are two separate relations. Since Substrait uses indices rather than name qualifiers, the JoinRel expects the available input indices to be 0 to size(left output)-1 and size(left output) to size(left output) + size(right output)-1. For example, if the input table has 5 columns, it expects indices 0 to 9. Say we join the first column from the left to the second column from the right: the join condition should be col_0 = col_6. Without this PR, the join condition would be col_0 = col_1, because from the DF name qualifiers the producer sees the left and right columns as coming from the same table, and that sends the wrong message to Substrait.
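
To make the index arithmetic above concrete, here is a minimal standalone sketch. It is not the producer code from this PR: `JoinSide`, `join_field_index`, and `join_key_indices` are hypothetical names introduced only for illustration. It assumes, as described above, that the join condition is evaluated over the left output fields followed by the right output fields.

```rust
/// Which input of the join a column comes from.
/// (Hypothetical type for illustration; not part of DataFusion or Substrait.)
#[derive(Clone, Copy, Debug, PartialEq)]
enum JoinSide {
    Left,
    Right,
}

/// Map a column's index within its own input relation to its index in the
/// combined [left fields ++ right fields] list that the join condition sees.
fn join_field_index(side: JoinSide, local_index: usize, left_width: usize) -> usize {
    match side {
        // Left columns keep their local index.
        JoinSide::Left => local_index,
        // Right columns are shifted by the number of left columns.
        JoinSide::Right => left_width + local_index,
    }
}

/// Field indices for the condition `left[l] = right[r]`.
fn join_key_indices(l: usize, r: usize, left_width: usize) -> (usize, usize) {
    (
        join_field_index(JoinSide::Left, l, left_width),
        join_field_index(JoinSide::Right, r, left_width),
    )
}
```

With a 5-column table joined to itself on the first left column and the second right column, `join_key_indices(0, 1, 5)` gives `(0, 6)`, matching the `col_0 = col_6` condition above. Applying no offset would give `(0, 1)`, which is the incorrect condition described for the pre-PR behaviour.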

)
.await
}

#[tokio::test]
async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> {
assert_expected_plan(
"SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b",
"Projection: data.b, data.c\
\n Inner Join: data.b = data.b\
\n TableScan: data projection=[b]\
\n TableScan: data projection=[b, c]",
)
.await
}

/// Construct a plan that contains several literals of types that are currently supported.
/// This case ignores:
/// - Date64, for this literal is not supported