diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index a923aaf31abb..fada827875b0 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -61,7 +61,9 @@ use substrait::proto::expression::literal::{ use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::read_rel::VirtualTable; -use substrait::proto::{CrossRel, ExchangeRel}; +use substrait::proto::rel_common::EmitKind; +use substrait::proto::rel_common::EmitKind::Emit; +use substrait::proto::{rel_common, CrossRel, ExchangeRel, RelCommon}; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -219,9 +221,20 @@ pub fn to_substrait_rel( .iter() .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions)) .collect::>>()?; + + let emit_kind = create_project_remapping( + expressions.len(), + p.input.as_ref().schema().fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { - common: None, + common: Some(common), input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?), expressions, advanced_extension: None, @@ -432,29 +445,15 @@ pub fn to_substrait_rel( } LogicalPlan::Window(window) => { let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; - // If the input is a Project relation, we can just append the WindowFunction expressions - // before returning - // Otherwise, wrap the input in a Project relation before appending the WindowFunction - // expressions - let mut project_rel: Box = match &input.as_ref().rel_type { - Some(RelType::Project(p)) => Box::new(*p.clone()), - _ => { - // Create Projection with field referencing all output fields in the input relation - let expressions = (0..window.input.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; - Box::new(ProjectRel { - common: None, - input: Some(input), - expressions, - advanced_extension: None, - }) - } - }; - // Parse WindowFunction expression - let mut window_exprs = vec![]; + + // create a field reference for each input field + let mut expressions = (0..window.input.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + // process and add each window function expression for expr in &window.window_expr { - window_exprs.push(to_substrait_rex( + expressions.push(to_substrait_rex( ctx, expr, window.input.schema(), @@ -462,8 +461,23 @@ pub fn to_substrait_rel( extensions, )?); } - // Append parsed WindowFunction expressions - project_rel.expressions.extend(window_exprs); + + let emit_kind = create_project_remapping( + expressions.len(), + window.input.schema().fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + let project_rel = Box::new(ProjectRel { + common: Some(common), + input: Some(input), + expressions, + advanced_extension: None, + }); + Ok(Box::new(Rel { rel_type: Some(RelType::Project(project_rel)), })) @@ -553,6 +567,19 @@ pub fn to_substrait_rel( } } +/// By default, a Substrait Project outputs all input fields followed by all expressions. +/// A DataFusion Projection only outputs expressions. In order to keep the Substrait +/// plan consistent with DataFusion, we must apply an output mapping that skips the input +/// fields so that the Substrait Project will only output the expression fields. +fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { + let expression_field_start = input_field_count; + let expression_field_end = expression_field_start + expr_count; + let output_mapping = (expression_field_start..expression_field_end) + .map(|i| i as i32) + .collect(); + Emit(rel_common::Emit { output_mapping }) +} + fn to_substrait_named_struct( schema: &DFSchemaRef, extensions: &mut Extensions, diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index d792ac33c333..da0898d222c4 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -26,7 +26,11 @@ mod tests { use datafusion::error::Result; use datafusion::prelude::*; + use datafusion_substrait::logical_plan::producer::to_substrait_plan; use std::fs; + use substrait::proto::plan_rel::RelType; + use substrait::proto::rel_common::{Emit, EmitKind}; + use substrait::proto::{rel, RelCommon}; #[tokio::test] async fn serialize_simple_select() -> Result<()> { @@ -63,6 +67,103 @@ mod tests { Ok(()) } + #[tokio::test] + async fn include_remaps_for_projects() -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql("SELECT b, a + a, a FROM data").await?; + let datafusion_plan = df.into_optimized_plan()?; + + assert_eq!( + format!("{}", datafusion_plan), + "Projection: data.b, data.a + data.a, data.a\ + \n TableScan: data projection=[a, b]", + ); + + let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + + let relation = plan.relations.first().unwrap().rel_type.as_ref(); + let root_rel = match relation { + Some(RelType::Root(root)) => root.input.as_ref().unwrap(), + _ => panic!("expected Root"), + }; + if let Some(rel::RelType::Project(p)) = root_rel.rel_type.as_ref() { + // The input has 2 columns [a, b], the Projection has 3 expressions [b, a + a, a] + // The required output mapping is [2,3,4], which skips the 2 input columns. + assert_emit(p.common.as_ref(), vec![2, 3, 4]); + + if let Some(rel::RelType::Read(r)) = + p.input.as_ref().unwrap().rel_type.as_ref() + { + let mask_expression = r.projection.as_ref().unwrap(); + let select = mask_expression.select.as_ref().unwrap(); + assert_eq!( + 2, + select.struct_items.len(), + "Read outputs two columns: a, b" + ); + return Ok(()); + } + } + panic!("plan did not match expected structure") + } + + #[tokio::test] + async fn include_remaps_for_windows() -> Result<()> { + let ctx = create_context().await?; + // let df = ctx.sql("SELECT a, b, lead(b) OVER (PARTITION BY a) FROM data").await?; + let df = ctx + .sql("SELECT b, RANK() OVER (PARTITION BY a), c FROM data;") + .await?; + let datafusion_plan = df.into_optimized_plan()?; + assert_eq!( + format!("{}", datafusion_plan), + "Projection: data.b, RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\ + \n WindowAggr: windowExpr=[[RANK() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: data projection=[a, b, c]", + ); + + let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + + let relation = plan.relations.first().unwrap().rel_type.as_ref(); + let root_rel = match relation { + Some(RelType::Root(root)) => root.input.as_ref().unwrap(), + _ => panic!("expected Root"), + }; + + if let Some(rel::RelType::Project(p1)) = root_rel.rel_type.as_ref() { + // The WindowAggr outputs 4 columns, the Projection has 4 columns + assert_emit(p1.common.as_ref(), vec![4, 5, 6]); + + if let Some(rel::RelType::Project(p2)) = + p1.input.as_ref().unwrap().rel_type.as_ref() + { + // The input has 3 columns, the WindowAggr has 4 expression + assert_emit(p2.common.as_ref(), vec![3, 4, 5, 6]); + + if let Some(rel::RelType::Read(r)) = + p2.input.as_ref().unwrap().rel_type.as_ref() + { + let mask_expression = r.projection.as_ref().unwrap(); + let select = mask_expression.select.as_ref().unwrap(); + assert_eq!( + 3, + select.struct_items.len(), + "Read outputs three columns: a, b, c" + ); + return Ok(()); + } + } + } + panic!("plan did not match expected structure") + } + + fn assert_emit(rel_common: Option<&RelCommon>, output_mapping: Vec) { + assert_eq!( + rel_common.unwrap().emit_kind.clone(), + Some(EmitKind::Emit(Emit { output_mapping })) + ); + } + async fn create_context() -> Result { let ctx = SessionContext::new(); ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new())