Skip to content

Commit

Permalink
feat: Support duplicate column names in Joins in Substrait consumer (#…
Browse files Browse the repository at this point in the history
…11049)

* add tests for joining tables with same name

* alias tables just before joins instead, for a simpler solution

* test cleanup

* docstrings and some cleanup

* Clippy

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
Blizzara and alamb authored Jun 24, 2024
1 parent c7ac8b8 commit 3ff0bfe
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 56 deletions.
78 changes: 61 additions & 17 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,14 @@ pub async fn from_substrait_plan(
match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema)?, p.input)?)),
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)),
LogicalPlan::Aggregate(a) => {
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), renamed_schema)?;
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?;
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
},
// There are probably more plans where we could bake things in, can add them later as needed.
// Otherwise, add a new Project to handle the renaming.
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema)?, Arc::new(plan))?))
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?))
}
}
},
Expand Down Expand Up @@ -308,34 +308,46 @@ pub fn extract_projection(
}
}

/// Ensure the expressions have the right name(s) according to the new schema.
/// This includes the top-level (column) name, which will be renamed through aliasing if needed,
/// as well as nested names (if the expression produces any struct types), which will be renamed
/// through casting if needed.
fn rename_expressions(
exprs: impl IntoIterator<Item = Expr>,
input_schema: &DFSchema,
new_schema: DFSchemaRef,
new_schema: &DFSchema,
) -> Result<Vec<Expr>> {
exprs
.into_iter()
.zip(new_schema.fields())
.map(|(old_expr, new_field)| {
if &old_expr.get_type(input_schema)? == new_field.data_type() {
// Alias column if needed
old_expr.alias_if_changed(new_field.name().into())
} else {
// Use Cast to rename inner struct fields + alias column if needed
// Check if type (i.e. nested struct field names) match, use Cast to rename if needed
let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() {
Expr::Cast(Cast::new(
Box::new(old_expr),
new_field.data_type().to_owned(),
))
.alias_if_changed(new_field.name().into())
} else {
old_expr
};
// Alias column if needed to fix the top-level name
match &new_expr {
// If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier
Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr),
_ => new_expr.alias_if_changed(new_field.name().to_owned()),
}
})
.collect()
}

/// Produce a version of the given schema with names matching the given list of names.
/// Substrait doesn't deal with column (incl. nested struct field) names within the schema,
/// but it does give us the list of expected names at the end of the plan, so we use this
/// to rename the schema to match the expected names.
fn make_renamed_schema(
schema: &DFSchemaRef,
dfs_names: &Vec<String>,
) -> Result<DFSchemaRef> {
) -> Result<DFSchema> {
fn rename_inner_fields(
dtype: &DataType,
dfs_names: &Vec<String>,
Expand Down Expand Up @@ -401,10 +413,10 @@ fn make_renamed_schema(
dfs_names.len());
}

Ok(Arc::new(DFSchema::from_field_specific_qualified_schema(
DFSchema::from_field_specific_qualified_schema(
qualifiers,
&Arc::new(Schema::new(fields)),
)?))
)
}

/// Convert Substrait Rel to DataFusion DataFrame
Expand Down Expand Up @@ -594,6 +606,8 @@ pub async fn from_substrait_rel(
let right = LogicalPlanBuilder::from(
from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?,
);
let (left, right) = requalify_sides_if_needed(left, right)?;

let join_type = from_substrait_jointype(join.r#type)?;
// The join condition expression needs full input schema and not the output schema from join since we lose columns from
// certain join types such as semi and anti joins
Expand Down Expand Up @@ -627,13 +641,15 @@ pub async fn from_substrait_rel(
}
}
Some(RelType::Cross(cross)) => {
let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
let left = LogicalPlanBuilder::from(
from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?,
);
let right =
let right = LogicalPlanBuilder::from(
from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions)
.await?;
left.cross_join(right)?.build()
.await?,
);
let (left, right) = requalify_sides_if_needed(left, right)?;
left.cross_join(right.build()?)?.build()
}
Some(RelType::Read(read)) => match &read.as_ref().read_type {
Some(ReadType::NamedTable(nt)) => {
Expand Down Expand Up @@ -846,6 +862,34 @@ pub async fn from_substrait_rel(
}
}

/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise
/// conflict with the columns from the other.
/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For
/// Substrait the names don't matter since it only refers to columns by indices, however DataFusion
/// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names).
fn requalify_sides_if_needed(
left: LogicalPlanBuilder,
right: LogicalPlanBuilder,
) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder)> {
let left_cols = left.schema().columns();
let right_cols = right.schema().columns();
if left_cols.iter().any(|l| {
right_cols.iter().any(|r| {
l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none()))
})
}) {
// These names have no connection to the original plan, but they'll make the columns
// (mostly) unique. There may be cases where this still causes duplicates, if either left
// or right side itself contains duplicate names with different qualifiers.
Ok((
left.alias(TableReference::bare("left"))?,
right.alias(TableReference::bare("right"))?,
))
} else {
Ok((left, right))
}
}

fn from_substrait_jointype(join_type: i32) -> Result<JoinType> {
if let Ok(substrait_join_type) = join_rel::JoinType::try_from(join_type) {
match substrait_join_type {
Expand Down
98 changes: 59 additions & 39 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,23 @@ async fn roundtrip_outer_join() -> Result<()> {
roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a = data2.a").await
}

#[tokio::test]
async fn roundtrip_self_join() -> Result<()> {
// Substrait does currently NOT maintain the alias of the tables.
// Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide.
// This roundtrip works because we set aliases to what the Substrait consumer will generate.
roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.a = right.a").await?;
roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.b = right.b").await
}

#[tokio::test]
async fn roundtrip_self_implicit_cross_join() -> Result<()> {
// Substrait does currently NOT maintain the alias of the tables.
// Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide.
// This roundtrip works because we set aliases to what the Substrait consumer will generate.
roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right").await
}

#[tokio::test]
async fn roundtrip_arithmetic_ops() -> Result<()> {
roundtrip("SELECT a - a FROM data").await?;
Expand Down Expand Up @@ -610,7 +627,22 @@ async fn simple_intersect() -> Result<()> {

#[tokio::test]
async fn simple_intersect_table_reuse() -> Result<()> {
roundtrip("SELECT count(1) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);").await
// Substrait does currently NOT maintain the alias of the tables.
// Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide.
// In this case the aliasing happens at a different point in the plan, so we cannot use roundtrip.
// Schema check works because we set aliases to what the Substrait consumer will generate.
assert_expected_plan(
"SELECT count(1) FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);",
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\
\n Projection: \
\n LeftSemi Join: left.a = right.a\
\n SubqueryAlias: left\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n SubqueryAlias: right\
\n TableScan: data projection=[a]",
true
).await
}

#[tokio::test]
Expand All @@ -628,32 +660,6 @@ async fn qualified_catalog_schema_table_reference() -> Result<()> {
roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await
}

#[tokio::test]
async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> {
assert_expected_plan(
"SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a",
"Projection: data.b, data.c\
\n Inner Join: data.a = data.a\
\n TableScan: data projection=[a, b]\
\n TableScan: data projection=[a, c]",
false, // "d1" vs "data" field qualifier
)
.await
}

#[tokio::test]
async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> {
assert_expected_plan(
"SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b",
"Projection: data.b, data.c\
\n Inner Join: data.b = data.b\
\n TableScan: data projection=[b]\
\n TableScan: data projection=[b, c]",
false, // "d1" vs "data" field qualifier
)
.await
}

/// Construct a plan that contains several literals of types that are currently supported.
/// This case ignores:
/// - Date64, for this literal is not supported
Expand Down Expand Up @@ -707,20 +713,17 @@ async fn roundtrip_literal_struct() -> Result<()> {
#[tokio::test]
async fn roundtrip_values() -> Result<()> {
// TODO: would be nice to have a struct inside the LargeList, but arrow_cast doesn't support that currently
let values = "(\
assert_expected_plan(
"VALUES \
(\
1, \
'a', \
[[-213.1, NULL, 5.5, 2.0, 1.0], []], \
arrow_cast([1,2,3], 'LargeList(Int64)'), \
STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \
[STRUCT(STRUCT('a' AS string_field) AS struct_field)]\
)";

// Test LogicalPlan::Values
assert_expected_plan(
format!("VALUES \
{values}, \
(NULL, NULL, NULL, NULL, NULL, NULL)").as_str(),
), \
(NULL, NULL, NULL, NULL, NULL, NULL)",
"Values: \
(\
Int64(1), \
Expand All @@ -731,11 +734,28 @@ async fn roundtrip_values() -> Result<()> {
List([{struct_field: {string_field: a}}])\
), \
(Int64(NULL), Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())",
true)
.await?;
true).await
}

#[tokio::test]
async fn roundtrip_values_empty_relation() -> Result<()> {
roundtrip("SELECT * FROM (VALUES ('a')) LIMIT 0").await
}

// Test LogicalPlan::EmptyRelation
roundtrip(format!("SELECT * FROM (VALUES {values}) LIMIT 0").as_str()).await
#[tokio::test]
async fn roundtrip_values_duplicate_column_join() -> Result<()> {
// Substrait does currently NOT maintain the alias of the tables.
// Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide.
// This roundtrip works because we set aliases to what the Substrait consumer will generate.
roundtrip(
"SELECT left.column1 as c1, right.column1 as c2 \
FROM \
(VALUES (1)) AS left \
JOIN \
(VALUES (2)) AS right \
ON left.column1 == right.column1",
)
.await
}

/// Construct a plan that cast columns. Only those SQL types are supported for now.
Expand Down

0 comments on commit 3ff0bfe

Please sign in to comment.