From e3f3466d1c27faedf2485cc177e0f916bd6e6d41 Mon Sep 17 00:00:00 2001 From: Anlin Chen Date: Mon, 24 Feb 2025 14:58:40 -0500 Subject: [PATCH 1/5] feat: add add_implicit_group_by_exprs option to logical plan builder --- datafusion/expr/src/logical_plan/builder.rs | 48 +++++++++++++++++++-- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index da30f2d7a712..e1ca73228a43 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -63,6 +63,32 @@ use indexmap::IndexSet; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; +#[derive(Debug, Clone)] +pub struct LogicalPlanBuilderOptions { + /// Flag indicating whether the plan builder should add + /// functionally dependent expressions as additional aggregation groupings. + add_implicit_group_by_exprs: bool, +} + +impl Default for LogicalPlanBuilderOptions { + fn default() -> Self { + Self { + add_implicit_group_by_exprs: false, + } + } +} + +impl LogicalPlanBuilderOptions { + pub fn new() -> Self { + Default::default() + } + + pub fn with_add_implicit_group_by_exprs(mut self, add: bool) -> Self { + self.add_implicit_group_by_exprs = add; + self + } +} + /// Builder for logical plans /// /// # Example building a simple plan @@ -103,6 +129,7 @@ pub const UNNAMED_TABLE: &str = "?table?"; #[derive(Debug, Clone)] pub struct LogicalPlanBuilder { plan: Arc, + options: LogicalPlanBuilderOptions, } impl LogicalPlanBuilder { @@ -110,12 +137,21 @@ impl LogicalPlanBuilder { pub fn new(plan: LogicalPlan) -> Self { Self { plan: Arc::new(plan), + options: LogicalPlanBuilderOptions::default(), } } /// Create a builder from an existing plan pub fn new_from_arc(plan: Arc) -> Self { - Self { plan } + Self { + plan, + options: LogicalPlanBuilderOptions::default(), + } + } + + pub fn with_options(mut self, options: LogicalPlanBuilderOptions) -> Self { + self.options = options; + self } /// Return the output schema of the plan build so far @@ -1138,8 +1174,12 @@ impl LogicalPlanBuilder { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - let group_expr = - add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; + let group_expr = if self.options.add_implicit_group_by_exprs { + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())? + } else { + group_expr + }; + Aggregate::try_new(self.plan, group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::new) @@ -1550,6 +1590,7 @@ pub fn add_group_by_exprs_from_dependencies( } Ok(group_expr) } + /// Errors if one or more expressions have equal names. pub fn validate_unique_names<'a>( node_name: &str, @@ -2023,7 +2064,6 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { - use super::*; use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; From 1a2ca1a15b66b100d5dbcd26e424131568a90906 Mon Sep 17 00:00:00 2001 From: Anlin Chen Date: Mon, 24 Feb 2025 14:59:14 -0500 Subject: [PATCH 2/5] fix: do not add implicity group by exprs in substrait path --- datafusion/core/src/dataframe/mod.rs | 6 +++++- datafusion/expr/src/logical_plan/mod.rs | 2 +- datafusion/sql/src/select.rs | 9 ++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 6f540fa02c75..b6949d2eea9c 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -33,7 +33,8 @@ use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; use crate::logical_expr::utils::find_window_exprs; use crate::logical_expr::{ - col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType, + col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions, + Partitioning, TableType, }; use crate::physical_plan::{ collect, collect_partitioned, execute_stream, execute_stream_partitioned, @@ -526,7 +527,10 @@ impl DataFrame { ) -> Result { let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); let aggr_expr_len = aggr_expr.len(); + let options = + LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true); let plan = LogicalPlanBuilder::from(self.plan) + .with_options(options) .aggregate(group_expr, aggr_expr)? .build()?; let plan = if is_grouping_set { diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 404941378663..916b2131be04 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -28,7 +28,7 @@ pub mod tree_node; pub use builder::{ build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary, - LogicalPlanBuilder, LogicalTableSource, UNNAMED_TABLE, + LogicalPlanBuilder, LogicalPlanBuilderOptions, LogicalTableSource, UNNAMED_TABLE, }; pub use ddl::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 05782e6ecd75..e21def4c3941 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -38,7 +38,8 @@ use datafusion_expr::utils::{ }; use datafusion_expr::{ qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter, - GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, + GroupingSet, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions, + Partitioning, }; use indexmap::IndexMap; @@ -371,7 +372,10 @@ impl SqlToRel<'_, S> { let agg_expr = agg.aggr_expr.clone(); let (new_input, new_group_by_exprs) = self.try_process_group_by_unnest(agg)?; + let options = LogicalPlanBuilderOptions::new() + .with_add_implicit_group_by_exprs(true); LogicalPlanBuilder::from(new_input) + .with_options(options) .aggregate(new_group_by_exprs, agg_expr)? .build() } @@ -744,7 +748,10 @@ impl SqlToRel<'_, S> { aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec, Option)> { // create the aggregate plan + let options = + LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true); let plan = LogicalPlanBuilder::from(input.clone()) + .with_options(options) .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { From 99a4c791526696546f257ee61d731b016c3ce039 Mon Sep 17 00:00:00 2001 From: Anlin Chen Date: Mon, 24 Feb 2025 15:08:54 -0500 Subject: [PATCH 3/5] test: add substrait tests --- .../substrait/tests/cases/logical_plans.rs | 18 ++ .../tests/cases/roundtrip_logical_plan.rs | 11 + .../multilayer_aggregate.substrait.json | 213 ++++++++++++++++++ 3 files changed, 242 insertions(+) create mode 100644 datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 65f404bbda55..6f5899595548 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -91,4 +91,22 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn multilayer_aggregate() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/multilayer_aggregate.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: lower(sales.product) AS lower(product), sum(count(sales.product)) AS product_count\ + \n Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]\ + \n Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]\ + \n TableScan: sales" + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 68856117a38c..57363eb390ef 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -308,6 +308,17 @@ async fn aggregate_grouping_rollup() -> Result<()> { ).await } +#[tokio::test] +async fn multilayer_aggregate() -> Result<()> { + assert_expected_plan( + "SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as partial_count_b FROM data GROUP BY a) GROUP BY a", + "Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]]\ + \n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\ + \n TableScan: data projection=[a, b]", + true + ).await +} + #[tokio::test] async fn decimal_literal() -> Result<()> { roundtrip("SELECT * FROM data WHERE b > 2.5").await diff --git a/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json b/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json new file mode 100644 index 000000000000..1f47b916daf0 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json @@ -0,0 +1,213 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_arithmetic.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "count:any" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "sum:i64" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "lower:str" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "product" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "sales" + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 1, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["lower(product)", "product_count"] + } + }], + "expectedTypeUrls": [] +} From 3687bc4b3975d59bbca945515ba8499829cc742c Mon Sep 17 00:00:00 2001 From: Anlin Chen Date: Mon, 24 Feb 2025 15:59:30 -0500 Subject: [PATCH 4/5] test: add builder option tests --- datafusion/expr/src/logical_plan/builder.rs | 78 +++++++++++++++++++-- 1 file changed, 74 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index e1ca73228a43..169b5d667a01 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -53,8 +53,9 @@ use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ exec_err, get_target_functional_dependencies, internal_err, not_impl_err, - plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, - Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, + plan_datafusion_err, plan_err, Column, Constraint, Constraints, DFSchema, + DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, + UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -1726,7 +1727,21 @@ pub fn table_scan_with_filter_and_fetch( pub fn table_source(table_schema: &Schema) -> Arc { let table_schema = Arc::new(table_schema.clone()); - Arc::new(LogicalTableSource { table_schema }) + Arc::new(LogicalTableSource { + table_schema, + constraints: Default::default(), + }) +} + +pub fn table_source_with_constraints( + table_schema: &Schema, + constraints: Constraints, +) -> Arc { + let table_schema = Arc::new(table_schema.clone()); + Arc::new(LogicalTableSource { + table_schema, + constraints, + }) } /// Wrap projection for a plan, if the join keys contains normal expression. @@ -1797,12 +1812,21 @@ pub fn wrap_projection_for_join_if_necessary( /// DefaultTableSource. pub struct LogicalTableSource { table_schema: SchemaRef, + constraints: Constraints, } impl LogicalTableSource { /// Create a new LogicalTableSource pub fn new(table_schema: SchemaRef) -> Self { - Self { table_schema } + Self { + table_schema, + constraints: Constraints::default(), + } + } + + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self } } @@ -1815,6 +1839,10 @@ impl TableSource for LogicalTableSource { Arc::clone(&self.table_schema) } + fn constraints(&self) -> Option<&Constraints> { + Some(&self.constraints) + } + fn supports_filters_pushdown( &self, filters: &[&Expr], @@ -2068,6 +2096,7 @@ mod tests { use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; + use crate::test::function_stub::sum; use datafusion_common::{RecursionUnnestOption, SchemaError}; #[test] @@ -2615,4 +2644,45 @@ mod tests { Ok(()) } + + #[test] + fn plan_builder_aggregate_without_implicit_group_by_exprs() -> Result<()> { + let constraints = + Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]); + let table_source = table_source_with_constraints(&employee_schema(), constraints); + + let plan = + LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))? + .aggregate(vec![col("id")], vec![sum(col("salary"))])? + .build()?; + + let expected = + "Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]\ + \n TableScan: employee_csv projection=[id, state, salary]"; + assert_eq!(expected, format!("{plan}")); + + Ok(()) + } + + #[test] + fn plan_builder_aggregate_with_implicit_group_by_exprs() -> Result<()> { + let constraints = + Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]); + let table_source = table_source_with_constraints(&employee_schema(), constraints); + + let options = + LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true); + let plan = + LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))? + .with_options(options) + .aggregate(vec![col("id")], vec![sum(col("salary"))])? + .build()?; + + let expected = + "Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]\ + \n TableScan: employee_csv projection=[id, state, salary]"; + assert_eq!(expected, format!("{plan}")); + + Ok(()) + } } From 600d61604c50f3c57755d349a54a5093d8eb45b5 Mon Sep 17 00:00:00 2001 From: Anlin Chen Date: Mon, 24 Feb 2025 16:57:11 -0500 Subject: [PATCH 5/5] style: clippy errors --- datafusion/expr/src/logical_plan/builder.rs | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 169b5d667a01..4d825c6bfe49 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -53,9 +53,8 @@ use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ exec_err, get_target_functional_dependencies, internal_err, not_impl_err, - plan_datafusion_err, plan_err, Column, Constraint, Constraints, DFSchema, - DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, - UnnestOptions, + plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, + DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -64,26 +63,20 @@ use indexmap::IndexSet; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; -#[derive(Debug, Clone)] +/// Options for [`LogicalPlanBuilder`] +#[derive(Default, Debug, Clone)] pub struct LogicalPlanBuilderOptions { /// Flag indicating whether the plan builder should add /// functionally dependent expressions as additional aggregation groupings. add_implicit_group_by_exprs: bool, } -impl Default for LogicalPlanBuilderOptions { - fn default() -> Self { - Self { - add_implicit_group_by_exprs: false, - } - } -} - impl LogicalPlanBuilderOptions { pub fn new() -> Self { Default::default() } + /// Should the builder add functionally dependent expressions as additional aggregation groupings. pub fn with_add_implicit_group_by_exprs(mut self, add: bool) -> Self { self.add_implicit_group_by_exprs = add; self @@ -2097,7 +2090,7 @@ mod tests { use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; use crate::test::function_stub::sum; - use datafusion_common::{RecursionUnnestOption, SchemaError}; + use datafusion_common::{Constraint, RecursionUnnestOption, SchemaError}; #[test] fn plan_builder_simple() -> Result<()> {