Skip to content

Commit

Permalink
include input fields as output for Substrait consumer
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Aug 29, 2024
1 parent 6ffb1f6 commit e041698
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 8 deletions.
87 changes: 79 additions & 8 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use datafusion::logical_expr::{
expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr,
ExprSchemable, LogicalPlan, Operator, Projection, Values,
};
use pbjson_types::field;
use substrait::proto::expression::subquery::set_predicate::PredicateOp;
use url::Url;

Expand Down Expand Up @@ -194,6 +195,27 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
(accum_join_keys, nulls_equal_nulls, join_filter)
}

pub(crate) fn equivalent_names_and_types_ignore_order(
schema1: &DFSchema,
schema2: &Arc<DFSchema>
) -> bool {
if schema1.fields().len() != schema2.fields().len() {
return false;
}

let self_fields: HashSet<_> = schema1.fields()
.iter()
.map(|field| (field.name().to_owned(), field.data_type().clone()))
.collect();

let other_fields: HashSet<_> = schema2.fields()
.iter()
.map(|field| (field.name().to_owned(), field.data_type().clone()))
.collect();

self_fields == other_fields
}

/// Convert Substrait Plan to DataFusion LogicalPlan
pub async fn from_substrait_plan(
ctx: &SessionContext,
Expand All @@ -214,29 +236,76 @@ pub async fn from_substrait_plan(
Ok(from_substrait_rel(ctx, rel, &extensions).await?)
},
plan_rel::RelType::Root(root) => {
let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?;
let mut plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?;
if root.names.is_empty() {
// Backwards compatibility for plans missing names
return Ok(plan);
}
// <https://github.com/apache/datafusion/issues/12204>: if the names don't match the root plan's schema, try to add the missing projections by rebuilding the projection
let plan_schema = plan.schema();
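// We only care about Projection and Aggregate here, both of which have exactly one input.
// NOTE(review): `inputs()[0]` is evaluated for every plan kind and would panic on a plan with no inputs — confirm this cannot happen.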
let plan_input_schema = plan.inputs()[0].schema();
let missed_expr: Vec<Expr> = root.names.iter()
.filter_map(|name| {
if !plan_schema.has_column_with_unqualified_name(name) &&
plan_input_schema.has_column_with_unqualified_name(name) {
// we can safely unwrap here
let (qualifier, field) = plan_input_schema.qualified_field_with_unqualified_name(name).unwrap();
let res = Expr::from(Column::from((qualifier, field)));
println!("res is {:?}",res);
Some(res)
} else {
None
}
}).collect();
println!("input's schema is {:?}", plan.inputs()[0].schema());
if !missed_expr.is_empty() {
if let LogicalPlan::Projection(mut projection) = plan {
projection.expr.extend(missed_expr);
plan = LogicalPlan::Projection(Projection::try_new(projection.expr, projection.input)?)
}
}
println!("new plan is {:?}", plan);
let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?;
if renamed_schema.equivalent_names_and_types(plan.schema()) {
println!("renamed_schema is {:?} and plan.schema() is {:?}", renamed_schema, plan.schema());
if equivalent_names_and_types_ignore_order(&renamed_schema, plan.schema()) {
// Nothing to do if the schema is already equivalent
return Ok(plan);
}

match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)),
LogicalPlan::Projection(p) => {
Ok(LogicalPlan::Projection(
Projection::try_new(
rename_expressions(p.expr, p.input.schema(), &renamed_schema)?,
p.input
)?
))
},

LogicalPlan::Aggregate(a) => {
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?;
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
Ok(LogicalPlan::Aggregate(
Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?
))
},

// There are probably more plans where we could bake things in, can add them later as needed.
// Otherwise, add a new Project to handle the renaming.
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?))
_ => Ok(LogicalPlan::Projection(
Projection::try_new(
rename_expressions(
plan.schema().columns().iter().map(|c| col(c.to_owned())),
plan.schema(),
&renamed_schema
)?,
Arc::new(plan)
)?
)),
}

}
},
None => plan_err!("Cannot parse plan relation: None")
Expand Down Expand Up @@ -358,16 +427,18 @@ fn make_renamed_schema(
name_idx,
)?),
))),
_ => Ok(dtype.to_owned()),
_ => {
Ok(dtype.to_owned())
}
}
}

let mut name_idx = 0;

let (qualifiers, fields): (_, Vec<Field>) = schema
.iter()
.map(|(q, f)| {
let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
println!("f is {:?}", f);
Ok((
q.cloned(),
(**f)
Expand All @@ -390,7 +461,7 @@ fn make_renamed_schema(
name_idx,
dfs_names.len());
}

println!("all fields are {:?}", fields);
DFSchema::from_field_specific_qualified_schema(
qualifiers,
&Arc::new(Schema::new(fields)),
Expand Down
52 changes: 52 additions & 0 deletions datafusion/substrait/tests/cases/bugs_converage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Tests for bugs in substrait

#[cfg(test)]
mod tests {
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::common::Result;
    use datafusion::datasource::MemTable;
    use datafusion::prelude::SessionContext;
    use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
    use std::fs::File;
    use std::io::BufReader;
    use std::sync::Arc;
    use substrait::proto::Plan;

    /// Consumes `extra_projection_with_input.json`, whose root names list the
    /// three `users` columns plus `row_number`, and checks that the conversion
    /// to a logical plan completes without an error.
    ///
    /// The resulting plan is only printed; its shape is not asserted.
    #[tokio::test]
    async fn extra_projection_with_input() -> Result<()> {
        let ctx = SessionContext::new();

        // Register a `users` table with a single empty partition; the plan's
        // read relation refers to it by name.
        let users_schema = Arc::new(Schema::new(vec![
            Field::new("user_id", DataType::Utf8, false),
            Field::new("name", DataType::Utf8, false),
            Field::new("paid_for_service", DataType::Boolean, false),
        ]));
        let users = MemTable::try_new(users_schema, vec![vec![]]).unwrap();
        ctx.register_table("users", Arc::new(users))?;

        // Load the Substrait plan from its JSON fixture.
        let plan_file = File::open("tests/testdata/extra_projection_with_input.json")
            .expect("file not found");
        let proto: Plan = serde_json::from_reader(BufReader::new(plan_file))
            .expect("failed to parse json");

        let plan = from_substrait_plan(&ctx, &proto).await?;
        println!("{:?}", plan.to_string());
        Ok(())
    }
}
1 change: 1 addition & 0 deletions datafusion/substrait/tests/cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

mod bugs_converage;
mod consumer_integration;
mod function_test;
mod logical_plans;
Expand Down
113 changes: 113 additions & 0 deletions datafusion/substrait/tests/testdata/extra_projection_with_input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
{
"extensionUris": [
{
"extensionUriAnchor": 1,
"uri": "/functions_arithmetic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 1,
"name": "row_number"
}
}
],
"relations": [
{
"root": {
"input": {
"project": {
"common": {
"direct": {}
},
"input": {
"read": {
"common": {
"direct": {}
},
"baseSchema": {
"names": [
"user_id",
"name",
"paid_for_service"
],
"struct": {
"types": [
{
"string": {
"nullability": "NULLABILITY_REQUIRED"
}
},
{
"string": {
"nullability": "NULLABILITY_REQUIRED"
}
},
{
"bool": {
"nullability": "NULLABILITY_REQUIRED"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": [
"users"
]
}
}
},
"expressions": [
{
"windowFunction": {
"functionReference": 1,
"sorts": [
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
}
],
"upperBound": {
"unbounded": {}
},
"lowerBound": {
"unbounded": {}
},
"outputType": {
"i64": {
"nullability": "NULLABILITY_REQUIRED"
}
},
"invocation": "AGGREGATION_INVOCATION_ALL"
}
}
]
}
},
"names": [
"user_id",
"name",
"paid_for_service",
"row_number"
]
}
}
],
"version": {
"minorNumber": 52,
"producer": "spark-substrait-gateway"
}
}

0 comments on commit e041698

Please sign in to comment.