Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
Lordworms committed Aug 29, 2024
1 parent 6ffb1f6 commit d8d8abb
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 7 deletions.
58 changes: 51 additions & 7 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,34 @@ pub async fn from_substrait_plan(
Ok(from_substrait_rel(ctx, rel, &extensions).await?)
},
plan_rel::RelType::Root(root) => {
let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?;
let mut plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?;
if root.names.is_empty() {
// Backwards compatibility for plans missing names
return Ok(plan);
}
// <https://github.com/apache/datafusion/issues/12204> if the names didn't match the root plan's schema, we try to add projections visa rebuild
if !plan.inputs().is_empty() {
let plan_schema = plan.schema();
// only cares Projection and Aggregation which all has 1 input
let plan_input_schema = plan.inputs()[0].schema();
let mut missed_expr: Vec<Expr> = root.names.iter()
.filter_map(|name| {
if !plan_schema.has_column_with_unqualified_name(name) &&
plan_input_schema.has_column_with_unqualified_name(name) {
// we can safely unwrap here
let (qualifier, field) = plan_input_schema.qualified_field_with_unqualified_name(name).unwrap();
Some(Expr::from(Column::from((qualifier, field))))
} else {
None
}
}).collect();
if !missed_expr.is_empty() {
if let LogicalPlan::Projection(projection) = plan {
missed_expr.extend(projection.expr);
plan = LogicalPlan::Projection(Projection::try_new(missed_expr, projection.input)?)
}
}
}
let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?;
if renamed_schema.equivalent_names_and_types(plan.schema()) {
// Nothing to do if the schema is already equivalent
Expand All @@ -228,15 +251,36 @@ pub async fn from_substrait_plan(
match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)),
LogicalPlan::Projection(p) => {
Ok(LogicalPlan::Projection(
Projection::try_new(
rename_expressions(p.expr, p.input.schema(), &renamed_schema)?,
p.input
)?
))
},

LogicalPlan::Aggregate(a) => {
let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?;
Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
Ok(LogicalPlan::Aggregate(
Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?
))
},

// There are probably more plans where we could bake things in, can add them later as needed.
// Otherwise, add a new Project to handle the renaming.
_ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?))
_ => Ok(LogicalPlan::Projection(
Projection::try_new(
rename_expressions(
plan.schema().columns().iter().map(|c| col(c.to_owned())),
plan.schema(),
&renamed_schema
)?,
Arc::new(plan)
)?
)),
}

}
},
None => plan_err!("Cannot parse plan relation: None")
Expand Down Expand Up @@ -358,12 +402,13 @@ fn make_renamed_schema(
name_idx,
)?),
))),
_ => Ok(dtype.to_owned()),
_ => {
Ok(dtype.to_owned())
}
}
}

let mut name_idx = 0;

let (qualifiers, fields): (_, Vec<Field>) = schema
.iter()
.map(|(q, f)| {
Expand All @@ -390,7 +435,6 @@ fn make_renamed_schema(
name_idx,
dfs_names.len());
}

DFSchema::from_field_specific_qualified_schema(
qualifiers,
&Arc::new(Schema::new(fields)),
Expand Down
54 changes: 54 additions & 0 deletions datafusion/substrait/tests/cases/bugs_converage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Tests for bugs in substrait
#[cfg(test)]
mod tests {
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::Result;
use datafusion::datasource::MemTable;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
use substrait::proto::Plan;
#[tokio::test]
async fn extra_projection_with_input() -> Result<()> {
let ctx = SessionContext::new();
let schema = Schema::new(vec![
Field::new("user_id", DataType::Utf8, false),
Field::new("name", DataType::Utf8, false),
Field::new("paid_for_service", DataType::Boolean, false),
]);
let memory_table = MemTable::try_new(schema.into(), vec![vec![]]).unwrap();
ctx.register_table("users", Arc::new(memory_table))?;
let path = "tests/testdata/extra_projection_with_input.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");

let plan = from_substrait_plan(&ctx, &proto).await?;
let plan_str = format!("{}", plan);
assert_eq!(plan_str, "Projection: users.user_id, users.name, users.paid_for_service, row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS row_number\
\n WindowAggr: windowExpr=[[row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\
\n TableScan: users projection=[user_id, name, paid_for_service]");
Ok(())
}
}
1 change: 1 addition & 0 deletions datafusion/substrait/tests/cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

mod bugs_converage;
mod consumer_integration;
mod function_test;
mod logical_plans;
Expand Down
113 changes: 113 additions & 0 deletions datafusion/substrait/tests/testdata/extra_projection_with_input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
{
"extensionUris": [
{
"extensionUriAnchor": 1,
"uri": "/functions_arithmetic.yaml"
}
],
"extensions": [
{
"extensionFunction": {
"extensionUriReference": 1,
"functionAnchor": 1,
"name": "row_number"
}
}
],
"relations": [
{
"root": {
"input": {
"project": {
"common": {
"direct": {}
},
"input": {
"read": {
"common": {
"direct": {}
},
"baseSchema": {
"names": [
"user_id",
"name",
"paid_for_service"
],
"struct": {
"types": [
{
"string": {
"nullability": "NULLABILITY_REQUIRED"
}
},
{
"string": {
"nullability": "NULLABILITY_REQUIRED"
}
},
{
"bool": {
"nullability": "NULLABILITY_REQUIRED"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": [
"users"
]
}
}
},
"expressions": [
{
"windowFunction": {
"functionReference": 1,
"sorts": [
{
"expr": {
"selection": {
"directReference": {
"structField": {
"field": 1
}
},
"rootReference": {}
}
},
"direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
}
],
"upperBound": {
"unbounded": {}
},
"lowerBound": {
"unbounded": {}
},
"outputType": {
"i64": {
"nullability": "NULLABILITY_REQUIRED"
}
},
"invocation": "AGGREGATION_INVOCATION_ALL"
}
}
]
}
},
"names": [
"user_id",
"name",
"paid_for_service",
"row_number"
]
}
}
],
"version": {
"minorNumber": 52,
"producer": "spark-substrait-gateway"
}
}

0 comments on commit d8d8abb

Please sign in to comment.