diff --git a/datafusion/substrait/src/extensions.rs b/datafusion/substrait/src/extensions.rs index 459d0e0c5ae58..9bff79371dca0 100644 --- a/datafusion/substrait/src/extensions.rs +++ b/datafusion/substrait/src/extensions.rs @@ -33,6 +33,7 @@ pub struct Extensions { pub functions: HashMap, // anchor -> function name pub types: HashMap, // anchor -> type name pub type_variations: HashMap, // anchor -> type variation name + pub names: Option>, } impl Extensions { @@ -75,6 +76,11 @@ impl Extensions { } } } + /// with the predefined names + pub fn with_projection_names(mut self, names: Vec) -> Self { + self.names = Some(names); + self + } } impl TryFrom<&Vec> for Extensions { @@ -107,6 +113,7 @@ impl TryFrom<&Vec> for Extensions { functions, types, type_variations, + names: None, }) } } diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b1b510f1792de..890a7c4e87706 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -200,7 +200,7 @@ pub async fn from_substrait_plan( plan: &Plan, ) -> Result { // Register function extension - let extensions = Extensions::try_from(&plan.extensions)?; + let mut extensions = Extensions::try_from(&plan.extensions)?; if !extensions.type_variations.is_empty() { return not_impl_err!("Type variation extensions are not supported"); } @@ -214,6 +214,9 @@ pub async fn from_substrait_plan( Ok(from_substrait_rel(ctx, rel, &extensions).await?) }, plan_rel::RelType::Root(root) => { + if !root.names.is_empty() { + extensions = extensions.with_projection_names(root.names.clone()); + } let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names @@ -228,14 +231,32 @@ pub async fn from_substrait_plan( match plan { // If the last node of the plan produces expressions, bake the renames into those expressions. // This isn't necessary for correctness, but helps with roundtrip tests. - LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, p.input)?)), + LogicalPlan::Projection(p) => { + Ok(LogicalPlan::Projection( + Projection::try_new( + rename_expressions(p.expr, p.input.schema(), &renamed_schema)?, + p.input + )? + )) + }, LogicalPlan::Aggregate(a) => { let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), &renamed_schema)?; - Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)?)) + Ok(LogicalPlan::Aggregate( + Aggregate::try_new(a.input, a.group_expr, new_aggr_exprs)? + )) }, // There are probably more plans where we could bake things in, can add them later as needed. // Otherwise, add a new Project to handle the renaming. - _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), &renamed_schema)?, Arc::new(plan))?)) + _ => Ok(LogicalPlan::Projection( + Projection::try_new( + rename_expressions( + plan.schema().columns().iter().map(|c| col(c.to_owned())), + plan.schema(), + &renamed_schema + )?, + Arc::new(plan) + )? + )), } } }, @@ -363,7 +384,6 @@ fn make_renamed_schema( } let mut name_idx = 0; - let (qualifiers, fields): (_, Vec) = schema .iter() .map(|(q, f)| { @@ -390,7 +410,6 @@ fn make_renamed_schema( name_idx, dfs_names.len()); } - DFSchema::from_field_specific_qualified_schema( qualifiers, &Arc::new(Schema::new(fields)), @@ -410,19 +429,19 @@ pub async fn from_substrait_rel( let mut input = LogicalPlanBuilder::from( from_substrait_rel(ctx, input, extensions).await?, ); + println!("{:?}", p); let mut names: HashSet = HashSet::new(); - let mut exprs: Vec = vec![]; + let mut exprs: Vec = Vec::new(); for e in &p.expressions { let x = from_substrait_rex(ctx, e, input.clone().schema(), extensions) .await?; + // if the expression is WindowFunction, wrap in a Window relation if let Expr::WindowFunction(_) = &x { - // Adding the same expression here and in the project below - // works because the project's builder uses columnize_expr(..) - // to transform it into a column reference input = input.window(vec![x.clone()])? } + // Ensure the expression has a unique display name, so that project's // validate_unique_names doesn't fail let name = x.schema_name().to_string(); @@ -439,6 +458,27 @@ pub async fn from_substrait_rel( } names.insert(new_name); } + let schema = input.schema(); + if let (Some(extensions_names), true) = + (extensions.names.as_ref(), p.common.is_some()) + { + extensions_names.iter().for_each(|name| { + if let Ok(field) = + schema.qualified_field_with_unqualified_name(name) + { + let expr = Expr::from(Column::from(field)); + let schema_name = expr.schema_name().to_string(); + + if names.insert(schema_name.clone()) { + let position = extensions_names + .iter() + .position(|n| n == name) + .unwrap_or(exprs.len()); + exprs.insert(position, expr); + } + } + }); + } input.project(exprs)?.build() } else { not_impl_err!("Projection without an input is not supported") diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 72b6760be29c1..d756b6c08208e 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2313,6 +2313,7 @@ mod test { INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() )]), type_variations: HashMap::new(), + names: None, } ); @@ -2423,6 +2424,7 @@ mod test { INTERVAL_MONTH_DAY_NANO_TYPE_NAME.to_string() )]), type_variations: HashMap::new(), + names: None, } ); diff --git a/datafusion/substrait/tests/cases/bugs_converage.rs b/datafusion/substrait/tests/cases/bugs_converage.rs new file mode 100644 index 0000000000000..da3d9825f5411 --- /dev/null +++ b/datafusion/substrait/tests/cases/bugs_converage.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for bugs in substrait + +#[cfg(test)] +mod tests { + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::common::Result; + use datafusion::datasource::MemTable; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + use substrait::proto::Plan; + #[tokio::test] + async fn extra_projection_with_input() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Schema::new(vec![ + Field::new("user_id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + Field::new("paid_for_service", DataType::Boolean, false), + ]); + let memory_table = MemTable::try_new(schema.into(), vec![vec![]]).unwrap(); + ctx.register_table("users", Arc::new(memory_table))?; + let path = "tests/testdata/extra_projection_with_input.json"; + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + let plan_str = format!("{}", plan); + assert_eq!(plan_str, "Projection: users.user_id, users.name, users.paid_for_service, row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS row_number\ + \n WindowAggr: windowExpr=[[row_number() ORDER BY [users.name ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: users projection=[user_id, name, paid_for_service]"); + Ok(()) + } +} diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index d3ea7695e4b9e..816790388660c 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +mod bugs_converage; mod consumer_integration; mod function_test; mod logical_plans; diff --git a/datafusion/substrait/tests/testdata/extra_projection_with_input.json b/datafusion/substrait/tests/testdata/extra_projection_with_input.json new file mode 100644 index 0000000000000..41b93a8f2e10f --- /dev/null +++ b/datafusion/substrait/tests/testdata/extra_projection_with_input.json @@ -0,0 +1,113 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "row_number" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "user_id", + "name", + "paid_for_service" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "users" + ] + } + } + }, + "expressions": [ + { + "windowFunction": { + "functionReference": 1, + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_FIRST" + } + ], + "upperBound": { + "unbounded": {} + }, + "lowerBound": { + "unbounded": {} + }, + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL" + } + } + ] + } + }, + "names": [ + "user_id", + "name", + "paid_for_service", + "row_number" + ] + } + } + ], + "version": { + "minorNumber": 52, + "producer": "spark-substrait-gateway" + } +} \ No newline at end of file