diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs
index b1b510f1792de..97c78f7b5aa11 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -33,6 +33,7 @@ use datafusion::logical_expr::{
     expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr,
     ExprSchemable, LogicalPlan, Operator, Projection, Values,
 };
+use pbjson_types::field;
 use substrait::proto::expression::subquery::set_predicate::PredicateOp;
 use url::Url;
 
@@ -194,6 +195,27 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
     (accum_join_keys, nulls_equal_nulls, join_filter)
 }
 
+pub(crate) fn equivalent_names_and_types_ignore_order(
+    schema1: &DFSchema,
+    schema2: &Arc<DFSchema>
+) -> bool {
+    if schema1.fields().len() != schema2.fields().len() {
+        return false;
+    }
+
+    let self_fields: HashSet<_> = schema1.fields()
+        .iter()
+        .map(|field| (field.name().to_owned(), field.data_type().clone()))
+        .collect();
+
+    let other_fields: HashSet<_> = schema2.fields()
+        .iter()
+        .map(|field| (field.name().to_owned(), field.data_type().clone()))
+        .collect();
+
+    self_fields == other_fields
+}
+
 /// Convert Substrait Plan to DataFusion LogicalPlan
 pub async fn from_substrait_plan(
     ctx: &SessionContext,
@@ -214,13 +236,39 @@ pub async fn from_substrait_plan(
                         Ok(from_substrait_rel(ctx, rel, &extensions).await?)
                     },
                     plan_rel::RelType::Root(root) => {
-                        let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?;
+                        let mut plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?;
                         if root.names.is_empty() {
                             // Backwards compatibility for plans missing names
                             return Ok(plan);
                         }
+                        // If the names don't match the root plan's schema, try to add the missing columns via a rebuilt projection
+                        let plan_schema = plan.schema();
+                        // Only Projection and Aggregate are handled here; both have exactly one input
+                        let plan_input_schema = plan.inputs()[0].schema();
+                        let missed_expr: Vec<Expr> = root.names.iter()
+                            .filter_map(|name| {
+                                if !plan_schema.has_column_with_unqualified_name(name) &&
+                                    plan_input_schema.has_column_with_unqualified_name(name) {
+                                    // The name is known to exist in the input schema, so this unwrap is safe
+                                    let (qualifier, field) = plan_input_schema.qualified_field_with_unqualified_name(name).unwrap();
+                                    let res = Expr::from(Column::from((qualifier, field)));
+                                    println!("res is {:?}", res);
+                                    Some(res)
+                                } else {
+                                    None
+                                }
+                            }).collect();
+                        println!("input's schema is {:?}", plan.inputs()[0].schema());
+                        if !missed_expr.is_empty() {
+                            if let LogicalPlan::Projection(mut projection) = plan {
+                                projection.expr.extend(missed_expr);
+                                plan = LogicalPlan::Projection(Projection::try_new(projection.expr, projection.input)?)
+                            }
+                        }
+                        println!("new plan is {:?}", plan);
                         let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?;
-                        if renamed_schema.equivalent_names_and_types(plan.schema()) {
+                        println!("renamed_schema is {:?} and plan.schema() is {:?}", renamed_schema, plan.schema());
+                        if equivalent_names_and_types_ignore_order(&renamed_schema, plan.schema()) {
                             // Nothing to do if the schema is already equivalent
                             return Ok(plan);
                         }
@@ -228,15 +276,36 @@ pub async fn from_substrait_plan(
                         match plan {
                             // If the last node of the plan produces expressions, bake the renames into those expressions.
                             // This isn't necessary for correctness, but helps with roundtrip tests.
-                            LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)),
+                            LogicalPlan::Projection(p) => {
+                                Ok(LogicalPlan::Projection(
+                                    Projection::try_new(
+                                        rename_expressions(p.expr, p.input.schema(), &renamed_schema)?,
+                                        p.input
+                                    )?
+                                ))
+                            },
+
                             LogicalPlan::Aggregate(a) => {
                                 let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?;
-                                Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?))
+                                Ok(LogicalPlan::Aggregate(
+                                    Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?
+                                ))
                             },
+
                             // There are probably more plans where we could bake things in, can add them later as needed.
                             // Otherwise, add a new Project to handle the renaming.
-                            _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?))
+                            _ => Ok(LogicalPlan::Projection(
+                                Projection::try_new(
+                                    rename_expressions(
+                                        plan.schema().columns().iter().map(|c| col(c.to_owned())),
+                                        plan.schema(),
+                                        &renamed_schema
+                                    )?,
+                                    Arc::new(plan)
+                                )?
+                            )),
                         }
                     },
                     None => plan_err!("Cannot parse plan relation: None")
@@ -358,16 +427,18 @@ fn make_renamed_schema(
                     name_idx,
                 )?),
             ))),
-            _ => Ok(dtype.to_owned()),
+            _ => {
+                Ok(dtype.to_owned())
+            }
         }
     }
 
     let mut name_idx = 0;
-
     let (qualifiers, fields): (_, Vec<Field>) = schema
         .iter()
         .map(|(q, f)| {
             let name = next_struct_field_name(0, dfs_names, &mut name_idx)?;
+            println!("f is {:?}", f);
             Ok((
                 q.cloned(),
                 (**f)
@@ -390,7 +461,7 @@ fn make_renamed_schema(
             name_idx,
             dfs_names.len());
     }
-
+    println!("all fields are {:?}", fields);
     DFSchema::from_field_specific_qualified_schema(
         qualifiers,
         &Arc::new(Schema::new(fields)),
diff --git a/datafusion/substrait/tests/cases/bugs_converage.rs b/datafusion/substrait/tests/cases/bugs_converage.rs
new file mode 100644
index 0000000000000..5607e5f72d1ec
--- /dev/null
+++ b/datafusion/substrait/tests/cases/bugs_converage.rs
@@ -0,0 +1,52 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Tests for bugs in the Substrait consumer
+
+#[cfg(test)]
+mod tests {
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::common::Result;
+    use datafusion::datasource::MemTable;
+    use datafusion::prelude::SessionContext;
+    use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
+    use std::fs::File;
+    use std::io::BufReader;
+    use std::sync::Arc;
+    use substrait::proto::Plan;
+    #[tokio::test]
+    async fn extra_projection_with_input() -> Result<()> {
+        let ctx = SessionContext::new();
+        let schema = Schema::new(vec![
+            Field::new("user_id", DataType::Utf8, false),
+            Field::new("name", DataType::Utf8, false),
+            Field::new("paid_for_service", DataType::Boolean, false),
+        ]);
+        let memory_table = MemTable::try_new(schema.into(), vec![vec![]]).unwrap();
+        ctx.register_table("users", Arc::new(memory_table))?;
+        let path = "tests/testdata/extra_projection_with_input.json";
+        let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
+            File::open(path).expect("file not found"),
+        ))
+        .expect("failed to parse json");
+
+        let plan = from_substrait_plan(&ctx, &proto).await?;
+        let plan_str = format!("{}", plan);
+        println!("{:?}", plan_str);
+        Ok(())
+    }
+}
diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs
index d3ea7695e4b9e..816790388660c 100644
--- a/datafusion/substrait/tests/cases/mod.rs
+++ b/datafusion/substrait/tests/cases/mod.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+mod bugs_converage;
 mod consumer_integration;
 mod function_test;
 mod logical_plans;
diff --git a/datafusion/substrait/tests/testdata/extra_projection_with_input.json b/datafusion/substrait/tests/testdata/extra_projection_with_input.json
new file mode 100644
index 0000000000000..41b93a8f2e10f
--- /dev/null
+++ b/datafusion/substrait/tests/testdata/extra_projection_with_input.json
@@ -0,0 +1,113 @@
+{
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1,
+      "uri": "/functions_arithmetic.yaml"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 1,
+        "functionAnchor": 1,
+        "name": "row_number"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "project": {
+            "common": {
+              "direct": {}
+            },
+            "input": {
+              "read": {
+                "common": {
+                  "direct": {}
+                },
+                "baseSchema": {
+                  "names": [
+                    "user_id",
+                    "name",
+                    "paid_for_service"
+                  ],
+                  "struct": {
+                    "types": [
+                      {
+                        "string": {
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      },
+                      {
+                        "string": {
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      },
+                      {
+                        "bool": {
+                          "nullability": "NULLABILITY_REQUIRED"
+                        }
+                      }
+                    ],
+                    "nullability": "NULLABILITY_REQUIRED"
+                  }
+                },
+                "namedTable": {
+                  "names": [
+                    "users"
+                  ]
+                }
+              }
+            },
+            "expressions": [
+              {
+                "windowFunction": {
+                  "functionReference": 1,
+                  "sorts": [
+                    {
+                      "expr": {
+                        "selection": {
+                          "directReference": {
+                            "structField": {
+                              "field": 1
+                            }
+                          },
+                          "rootReference": {}
+                        }
+                      },
+                      "direction": "SORT_DIRECTION_ASC_NULLS_FIRST"
+                    }
+                  ],
+                  "upperBound": {
+                    "unbounded": {}
+                  },
+                  "lowerBound": {
+                    "unbounded": {}
+                  },
+                  "outputType": {
+                    "i64": {
+                      "nullability": "NULLABILITY_REQUIRED"
+                    }
+                  },
+                  "invocation": "AGGREGATION_INVOCATION_ALL"
+                }
+              }
+            ]
+          }
+        },
+        "names": [
+          "user_id",
+          "name",
+          "paid_for_service",
+          "row_number"
+        ]
+      }
+    }
+  ],
+  "version": {
+    "minorNumber": 52,
+    "producer": "spark-substrait-gateway"
+  }
+}
\ No newline at end of file
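
Note on the schema comparison introduced in the patch above: the sketch below is a minimal, standalone illustration of the order-insensitive check that `equivalent_names_and_types_ignore_order` performs; it is not part of the patch. It assumes only the `datafusion` crate and uses plain Arrow `Schema`/`Field` rather than `DFSchema` so it compiles on its own; the helper name `same_fields_ignore_order` and the example field names are illustrative. Like the patch's helper, it compares only (name, data type) pairs, ignoring field order and nullability.

use std::collections::HashSet;
use datafusion::arrow::datatypes::{DataType, Field, Schema};

// Two schemas are treated as equivalent when they contain the same set of
// (field name, data type) pairs, regardless of the order of the fields.
fn same_fields_ignore_order(a: &Schema, b: &Schema) -> bool {
    if a.fields().len() != b.fields().len() {
        return false;
    }
    let as_set = |s: &Schema| -> HashSet<(String, DataType)> {
        s.fields()
            .iter()
            .map(|f| (f.name().clone(), f.data_type().clone()))
            .collect()
    };
    as_set(a) == as_set(b)
}

fn main() {
    // The same fields in a different order are considered equivalent.
    let a = Schema::new(vec![
        Field::new("user_id", DataType::Utf8, false),
        Field::new("row_number", DataType::Int64, false),
    ]);
    let b = Schema::new(vec![
        Field::new("row_number", DataType::Int64, false),
        Field::new("user_id", DataType::Utf8, false),
    ]);
    assert!(same_fields_ignore_order(&a, &b));
}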