From cc0fee8cce92adf8235e2c743ef97ae1e1265cae Mon Sep 17 00:00:00 2001 From: Anlin Chen Date: Thu, 6 Feb 2025 16:23:04 -0500 Subject: [PATCH 1/2] fix(substrait): Substrait input plans should be interpreted literally. Do not implicitly add any expressions when building the LogicalPlan. --- datafusion/expr/src/logical_plan/builder.rs | 43 +++- .../substrait/src/logical_plan/consumer.rs | 7 +- .../substrait/tests/cases/logical_plans.rs | 18 ++ .../tests/cases/roundtrip_logical_plan.rs | 11 + .../multilayer_aggregate.substrait.json | 213 ++++++++++++++++++ 5 files changed, 279 insertions(+), 13 deletions(-) create mode 100644 datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index c7cff3ac26b1..55efa31f1771 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1090,11 +1090,31 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, ) -> Result { - let group_expr = normalize_cols(group_expr, &self.plan)?; + self._aggregate(group_expr, aggr_expr, true) + } + + pub fn aggregate_without_implicit_group_by_exprs( + self, + group_expr: impl IntoIterator>, + aggr_expr: impl IntoIterator>, + ) -> Result { + self._aggregate(group_expr, aggr_expr, false) + } + + fn _aggregate( + self, + group_expr: impl IntoIterator>, + aggr_expr: impl IntoIterator>, + include_implicit_group_by_exprs: bool, + ) -> Result { + let mut group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - let group_expr = - add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; + if include_implicit_group_by_exprs { + group_expr = + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; + } + Aggregate::try_new(self.plan, group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::new) @@ -1235,7 +1255,7 @@ impl LogicalPlanBuilder { .map(|(l, r)| { let left_key = l.into(); let right_key = r.into(); - let mut left_using_columns = HashSet::new(); + let mut left_using_columns = HashSet::new(); expr_to_columns(&left_key, &mut left_using_columns)?; let normalized_left_key = normalize_col_with_schemas_and_ambiguity_check( left_key, @@ -1253,12 +1273,12 @@ impl LogicalPlanBuilder { // find valid equijoin find_valid_equijoin_key_pair( - &normalized_left_key, - &normalized_right_key, - self.plan.schema(), - right.schema(), - )?.ok_or_else(|| - plan_datafusion_err!( + &normalized_left_key, + &normalized_right_key, + self.plan.schema(), + right.schema(), + )?.ok_or_else(|| + plan_datafusion_err!( "can't create join plan, join key should belong to one input, error key: ({normalized_left_key},{normalized_right_key})" )) }) @@ -1495,7 +1515,7 @@ pub fn validate_unique_names<'a>( None => { unique_names.insert(name, (position, expr)); Ok(()) - }, + } Some((existing_position, existing_expr)) => { plan_err!("{node_name} require unique expression names \ but the expression \"{existing_expr}\" at position {existing_position} and \"{expr}\" \ @@ -1962,7 +1982,6 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { - use super::*; use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 5a7d70c5e765..5e032ad41b80 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1245,7 +1245,12 @@ pub async fn from_aggregate_rel( }; aggr_exprs.push(agg_func?.as_ref().clone()); } - input.aggregate(group_exprs, aggr_exprs)?.build() + + // Do not include implicit group by expressions (from functional dependencies) when building plans from Substrait. + // Otherwise, the ordinal-based emits applied later will point to incorrect expressions. + input + .aggregate_without_implicit_group_by_exprs(group_exprs, aggr_exprs)? + .build() } else { not_impl_err!("Aggregate without an input is not valid") } diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 65f404bbda55..6f5899595548 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -91,4 +91,22 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn multilayer_aggregate() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/multilayer_aggregate.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: lower(sales.product) AS lower(product), sum(count(sales.product)) AS product_count\ + \n Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]\ + \n Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]\ + \n TableScan: sales" + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 7045729493b1..921fc64a9057 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -300,6 +300,17 @@ async fn aggregate_grouping_rollup() -> Result<()> { ).await } +#[tokio::test] +async fn multilayer_aggregate() -> Result<()> { + assert_expected_plan( + "SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as partial_count_b FROM data GROUP BY a) GROUP BY a", + "Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]]\ + \n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\ + \n TableScan: data projection=[a, b]", + true + ).await +} + #[tokio::test] async fn decimal_literal() -> Result<()> { roundtrip("SELECT * FROM data WHERE b > 2.5").await diff --git a/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json b/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json new file mode 100644 index 000000000000..1f47b916daf0 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json @@ -0,0 +1,213 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_arithmetic.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "count:any" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "sum:i64" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "lower:str" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "product" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "sales" + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 1, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["lower(product)", "product_count"] + } + }], + "expectedTypeUrls": [] +} From ab20e446001b7c1e094b0880875f5ef8c59b5218 Mon Sep 17 00:00:00 2001 From: Anlin Chen Date: Wed, 12 Feb 2025 17:47:35 -0500 Subject: [PATCH 2/2] Rename _aggregate helper to aggregate_inner. Minor code syntax change to maintain variable immutability. --- datafusion/expr/src/logical_plan/builder.rs | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 55efa31f1771..897b346348f7 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1090,7 +1090,7 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, ) -> Result { - self._aggregate(group_expr, aggr_expr, true) + self.aggregate_inner(group_expr, aggr_expr, true) } pub fn aggregate_without_implicit_group_by_exprs( @@ -1098,22 +1098,23 @@ impl LogicalPlanBuilder { group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, ) -> Result { - self._aggregate(group_expr, aggr_expr, false) + self.aggregate_inner(group_expr, aggr_expr, false) } - fn _aggregate( + fn aggregate_inner( self, group_expr: impl IntoIterator>, aggr_expr: impl IntoIterator>, include_implicit_group_by_exprs: bool, ) -> Result { - let mut group_expr = normalize_cols(group_expr, &self.plan)?; + let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - if include_implicit_group_by_exprs { - group_expr = - add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; - } + let group_expr = if include_implicit_group_by_exprs { + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())? + } else { + group_expr + }; Aggregate::try_new(self.plan, group_expr, aggr_expr) .map(LogicalPlan::Aggregate)