diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 6f540fa02c75..b6949d2eea9c 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -33,7 +33,8 @@ use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; use crate::logical_expr::utils::find_window_exprs; use crate::logical_expr::{ - col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType, + col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions, + Partitioning, TableType, }; use crate::physical_plan::{ collect, collect_partitioned, execute_stream, execute_stream_partitioned, @@ -526,7 +527,10 @@ impl DataFrame { ) -> Result { let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); let aggr_expr_len = aggr_expr.len(); + let options = + LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true); let plan = LogicalPlanBuilder::from(self.plan) + .with_options(options) .aggregate(group_expr, aggr_expr)? .build()?; let plan = if is_grouping_set { diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index da30f2d7a712..4d825c6bfe49 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -53,8 +53,8 @@ use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ exec_err, get_target_functional_dependencies, internal_err, not_impl_err, - plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, - Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, + plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, + DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; use datafusion_expr_common::type_coercion::binary::type_union_resolution; @@ -63,6 +63,26 @@ use indexmap::IndexSet; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; +/// Options for [`LogicalPlanBuilder`] +#[derive(Default, Debug, Clone)] +pub struct LogicalPlanBuilderOptions { + /// Flag indicating whether the plan builder should add + /// functionally dependent expressions as additional aggregation groupings. + add_implicit_group_by_exprs: bool, +} + +impl LogicalPlanBuilderOptions { + pub fn new() -> Self { + Default::default() + } + + /// Should the builder add functionally dependent expressions as additional aggregation groupings. + pub fn with_add_implicit_group_by_exprs(mut self, add: bool) -> Self { + self.add_implicit_group_by_exprs = add; + self + } +} + /// Builder for logical plans /// /// # Example building a simple plan @@ -103,6 +123,7 @@ pub const UNNAMED_TABLE: &str = "?table?"; #[derive(Debug, Clone)] pub struct LogicalPlanBuilder { plan: Arc, + options: LogicalPlanBuilderOptions, } impl LogicalPlanBuilder { @@ -110,12 +131,21 @@ impl LogicalPlanBuilder { pub fn new(plan: LogicalPlan) -> Self { Self { plan: Arc::new(plan), + options: LogicalPlanBuilderOptions::default(), } } /// Create a builder from an existing plan pub fn new_from_arc(plan: Arc) -> Self { - Self { plan } + Self { + plan, + options: LogicalPlanBuilderOptions::default(), + } + } + + pub fn with_options(mut self, options: LogicalPlanBuilderOptions) -> Self { + self.options = options; + self } /// Return the output schema of the plan build so far @@ -1138,8 +1168,12 @@ impl LogicalPlanBuilder { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - let group_expr = - add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; + let group_expr = if self.options.add_implicit_group_by_exprs { + add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())? + } else { + group_expr + }; + Aggregate::try_new(self.plan, group_expr, aggr_expr) .map(LogicalPlan::Aggregate) .map(Self::new) @@ -1550,6 +1584,7 @@ pub fn add_group_by_exprs_from_dependencies( } Ok(group_expr) } + /// Errors if one or more expressions have equal names. pub fn validate_unique_names<'a>( node_name: &str, @@ -1685,7 +1720,21 @@ pub fn table_scan_with_filter_and_fetch( pub fn table_source(table_schema: &Schema) -> Arc { let table_schema = Arc::new(table_schema.clone()); - Arc::new(LogicalTableSource { table_schema }) + Arc::new(LogicalTableSource { + table_schema, + constraints: Default::default(), + }) +} + +pub fn table_source_with_constraints( + table_schema: &Schema, + constraints: Constraints, +) -> Arc { + let table_schema = Arc::new(table_schema.clone()); + Arc::new(LogicalTableSource { + table_schema, + constraints, + }) } /// Wrap projection for a plan, if the join keys contains normal expression. @@ -1756,12 +1805,21 @@ pub fn wrap_projection_for_join_if_necessary( /// DefaultTableSource. pub struct LogicalTableSource { table_schema: SchemaRef, + constraints: Constraints, } impl LogicalTableSource { /// Create a new LogicalTableSource pub fn new(table_schema: SchemaRef) -> Self { - Self { table_schema } + Self { + table_schema, + constraints: Constraints::default(), + } + } + + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self } } @@ -1774,6 +1832,10 @@ impl TableSource for LogicalTableSource { Arc::clone(&self.table_schema) } + fn constraints(&self) -> Option<&Constraints> { + Some(&self.constraints) + } + fn supports_filters_pushdown( &self, filters: &[&Expr], @@ -2023,12 +2085,12 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { - use super::*; use crate::logical_plan::StringifiedPlan; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; - use datafusion_common::{RecursionUnnestOption, SchemaError}; + use crate::test::function_stub::sum; + use datafusion_common::{Constraint, RecursionUnnestOption, SchemaError}; #[test] fn plan_builder_simple() -> Result<()> { @@ -2575,4 +2637,45 @@ mod tests { Ok(()) } + + #[test] + fn plan_builder_aggregate_without_implicit_group_by_exprs() -> Result<()> { + let constraints = + Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]); + let table_source = table_source_with_constraints(&employee_schema(), constraints); + + let plan = + LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))? + .aggregate(vec![col("id")], vec![sum(col("salary"))])? + .build()?; + + let expected = + "Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]\ + \n TableScan: employee_csv projection=[id, state, salary]"; + assert_eq!(expected, format!("{plan}")); + + Ok(()) + } + + #[test] + fn plan_builder_aggregate_with_implicit_group_by_exprs() -> Result<()> { + let constraints = + Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]); + let table_source = table_source_with_constraints(&employee_schema(), constraints); + + let options = + LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true); + let plan = + LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))? + .with_options(options) + .aggregate(vec![col("id")], vec![sum(col("salary"))])? + .build()?; + + let expected = + "Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]\ + \n TableScan: employee_csv projection=[id, state, salary]"; + assert_eq!(expected, format!("{plan}")); + + Ok(()) + } } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 404941378663..916b2131be04 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -28,7 +28,7 @@ pub mod tree_node; pub use builder::{ build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary, - LogicalPlanBuilder, LogicalTableSource, UNNAMED_TABLE, + LogicalPlanBuilder, LogicalPlanBuilderOptions, LogicalTableSource, UNNAMED_TABLE, }; pub use ddl::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 05782e6ecd75..e21def4c3941 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -38,7 +38,8 @@ use datafusion_expr::utils::{ }; use datafusion_expr::{ qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter, - GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, + GroupingSet, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions, + Partitioning, }; use indexmap::IndexMap; @@ -371,7 +372,10 @@ impl SqlToRel<'_, S> { let agg_expr = agg.aggr_expr.clone(); let (new_input, new_group_by_exprs) = self.try_process_group_by_unnest(agg)?; + let options = LogicalPlanBuilderOptions::new() + .with_add_implicit_group_by_exprs(true); LogicalPlanBuilder::from(new_input) + .with_options(options) .aggregate(new_group_by_exprs, agg_expr)? .build() } @@ -744,7 +748,10 @@ impl SqlToRel<'_, S> { aggr_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec, Option)> { // create the aggregate plan + let options = + LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true); let plan = LogicalPlanBuilder::from(input.clone()) + .with_options(options) .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 65f404bbda55..6f5899595548 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -91,4 +91,22 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn multilayer_aggregate() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/multilayer_aggregate.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: lower(sales.product) AS lower(product), sum(count(sales.product)) AS product_count\ + \n Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]\ + \n Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]\ + \n TableScan: sales" + ); + + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 68856117a38c..57363eb390ef 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -308,6 +308,17 @@ async fn aggregate_grouping_rollup() -> Result<()> { ).await } +#[tokio::test] +async fn multilayer_aggregate() -> Result<()> { + assert_expected_plan( + "SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as partial_count_b FROM data GROUP BY a) GROUP BY a", + "Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]]\ + \n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\ + \n TableScan: data projection=[a, b]", + true + ).await +} + #[tokio::test] async fn decimal_literal() -> Result<()> { roundtrip("SELECT * FROM data WHERE b > 2.5").await diff --git a/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json b/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json new file mode 100644 index 000000000000..1f47b916daf0 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/multilayer_aggregate.substrait.json @@ -0,0 +1,213 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_arithmetic.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "count:any" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "sum:i64" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "lower:str" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "product" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "sales" + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }], + "expressionReferences": [] + }], + "measures": [{ + "measure": { + "functionReference": 1, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }], + "groupingExpressions": [] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 2, + "args": [], + "outputType": { + "string": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }], + "options": [] + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["lower(product)", "product_count"] + } + }], + "expectedTypeUrls": [] +}