Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ use crate::execution::context::{SessionState, TaskContext};
use crate::execution::FunctionRegistry;
use crate::logical_expr::utils::find_window_exprs;
use crate::logical_expr::{
col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, TableType,
col, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions,
Partitioning, TableType,
};
use crate::physical_plan::{
collect, collect_partitioned, execute_stream, execute_stream_partitioned,
Expand Down Expand Up @@ -526,7 +527,10 @@ impl DataFrame {
) -> Result<DataFrame> {
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
let aggr_expr_len = aggr_expr.len();
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
let plan = LogicalPlanBuilder::from(self.plan)
.with_options(options)
.aggregate(group_expr, aggr_expr)?
.build()?;
let plan = if is_grouping_set {
Expand Down
121 changes: 112 additions & 9 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{
exec_err, get_target_functional_dependencies, internal_err, not_impl_err,
plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError,
Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
DataFusionError, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;

Expand All @@ -63,6 +63,26 @@ use indexmap::IndexSet;
/// Default table name for unnamed table
pub const UNNAMED_TABLE: &str = "?table?";

/// Options for [`LogicalPlanBuilder`]
#[derive(Default, Debug, Clone)]
pub struct LogicalPlanBuilderOptions {
/// Flag indicating whether the plan builder should add
/// functionally dependent expressions as additional aggregation groupings.
add_implicit_group_by_exprs: bool,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the default now false? That's a breaking change I guess? Is that fine, or should we make the default be true to maintain earlier behavior?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the default is now false.

My intuition was that this made the most sense as I'd prefer a composable base builder that suits all needs, that allows opt-ins to advanced features and optimizations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although admittedly, I was not aware that this was going to be a breaking change, since SQL and Dataframes maintain the same behavior.

@vbarua informed me that it may be possible that there are those depending on the builder directly. In which case, I can see an argument for maintaining the default true behavior...

I'm a new contributor to this project so I'm lacking context. I might lean on you folks to make a more informed decision here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I lean towards having the default behaviour be false for this, even if it's a breaking change, because it makes the builder less surprising IMO. Specifically, when invoking the builder for an aggregate with a specific set of grouping expressions, my expectation is that it should produce an aggregate with those specific grouping expressions. If I wanted additional grouping expressions, I would have included them.

There's definitely room and value for optimizations like what is going on here, but I think those need to be opt-in to avoid situation like this were the plan builder tries to be smart along one specific axis and inadvertently shoots you in the foot in another. In the past, I think we've leaned towards having the builder be as straightforward as possible and then handling optimizations in the optimizer.

Copy link
Contributor

@alamb alamb Feb 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it would be less surprising, but I think we should consider it in a follow on PR. This PR fixes the issue with substrait and is a good change on its own. I have filed a follow on ticket to track the idea:

Update: I see this PR changes the defaults. I updated the title of the PR and will mark it as an API change

}

impl LogicalPlanBuilderOptions {
pub fn new() -> Self {
Default::default()
}

/// Should the builder add functionally dependent expressions as additional aggregation groupings.
pub fn with_add_implicit_group_by_exprs(mut self, add: bool) -> Self {
self.add_implicit_group_by_exprs = add;
self
}
}

/// Builder for logical plans
///
/// # Example building a simple plan
Expand Down Expand Up @@ -103,19 +123,29 @@ pub const UNNAMED_TABLE: &str = "?table?";
#[derive(Debug, Clone)]
pub struct LogicalPlanBuilder {
plan: Arc<LogicalPlan>,
options: LogicalPlanBuilderOptions,
}

impl LogicalPlanBuilder {
/// Create a builder from an existing plan
pub fn new(plan: LogicalPlan) -> Self {
Self {
plan: Arc::new(plan),
options: LogicalPlanBuilderOptions::default(),
}
}

/// Create a builder from an existing plan
pub fn new_from_arc(plan: Arc<LogicalPlan>) -> Self {
Self { plan }
Self {
plan,
options: LogicalPlanBuilderOptions::default(),
}
}

pub fn with_options(mut self, options: LogicalPlanBuilderOptions) -> Self {
self.options = options;
self
}

/// Return the output schema of the plan build so far
Expand Down Expand Up @@ -1138,8 +1168,12 @@ impl LogicalPlanBuilder {
let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;

let group_expr =
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?;
let group_expr = if self.options.add_implicit_group_by_exprs {
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?
} else {
group_expr
};

Aggregate::try_new(self.plan, group_expr, aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Self::new)
Expand Down Expand Up @@ -1550,6 +1584,7 @@ pub fn add_group_by_exprs_from_dependencies(
}
Ok(group_expr)
}

/// Errors if one or more expressions have equal names.
pub fn validate_unique_names<'a>(
node_name: &str,
Expand Down Expand Up @@ -1685,7 +1720,21 @@ pub fn table_scan_with_filter_and_fetch(

pub fn table_source(table_schema: &Schema) -> Arc<dyn TableSource> {
let table_schema = Arc::new(table_schema.clone());
Arc::new(LogicalTableSource { table_schema })
Arc::new(LogicalTableSource {
table_schema,
constraints: Default::default(),
})
}

pub fn table_source_with_constraints(
table_schema: &Schema,
constraints: Constraints,
) -> Arc<dyn TableSource> {
let table_schema = Arc::new(table_schema.clone());
Arc::new(LogicalTableSource {
table_schema,
constraints,
})
}

/// Wrap projection for a plan, if the join keys contains normal expression.
Expand Down Expand Up @@ -1756,12 +1805,21 @@ pub fn wrap_projection_for_join_if_necessary(
/// DefaultTableSource.
pub struct LogicalTableSource {
table_schema: SchemaRef,
constraints: Constraints,
}

impl LogicalTableSource {
/// Create a new LogicalTableSource
pub fn new(table_schema: SchemaRef) -> Self {
Self { table_schema }
Self {
table_schema,
constraints: Constraints::default(),
}
}

pub fn with_constraints(mut self, constraints: Constraints) -> Self {
self.constraints = constraints;
self
}
}

Expand All @@ -1774,6 +1832,10 @@ impl TableSource for LogicalTableSource {
Arc::clone(&self.table_schema)
}

fn constraints(&self) -> Option<&Constraints> {
Some(&self.constraints)
}

fn supports_filters_pushdown(
&self,
filters: &[&Expr],
Expand Down Expand Up @@ -2023,12 +2085,12 @@ pub fn unnest_with_options(

#[cfg(test)]
mod tests {

use super::*;
use crate::logical_plan::StringifiedPlan;
use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery};

use datafusion_common::{RecursionUnnestOption, SchemaError};
use crate::test::function_stub::sum;
use datafusion_common::{Constraint, RecursionUnnestOption, SchemaError};

#[test]
fn plan_builder_simple() -> Result<()> {
Expand Down Expand Up @@ -2575,4 +2637,45 @@ mod tests {

Ok(())
}

#[test]
fn plan_builder_aggregate_without_implicit_group_by_exprs() -> Result<()> {
let constraints =
Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
let table_source = table_source_with_constraints(&employee_schema(), constraints);

let plan =
LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))?
.aggregate(vec![col("id")], vec![sum(col("salary"))])?
.build()?;

let expected =
"Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]\
\n TableScan: employee_csv projection=[id, state, salary]";
assert_eq!(expected, format!("{plan}"));

Ok(())
}

#[test]
fn plan_builder_aggregate_with_implicit_group_by_exprs() -> Result<()> {
let constraints =
Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
let table_source = table_source_with_constraints(&employee_schema(), constraints);

let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
let plan =
LogicalPlanBuilder::scan("employee_csv", table_source, Some(vec![0, 3, 4]))?
.with_options(options)
.aggregate(vec![col("id")], vec![sum(col("salary"))])?
.build()?;

let expected =
"Aggregate: groupBy=[[employee_csv.id, employee_csv.state, employee_csv.salary]], aggr=[[sum(employee_csv.salary)]]\
\n TableScan: employee_csv projection=[id, state, salary]";
assert_eq!(expected, format!("{plan}"));

Ok(())
}
}
2 changes: 1 addition & 1 deletion datafusion/expr/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub mod tree_node;

pub use builder::{
build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary,
LogicalPlanBuilder, LogicalTableSource, UNNAMED_TABLE,
LogicalPlanBuilder, LogicalPlanBuilderOptions, LogicalTableSource, UNNAMED_TABLE,
};
pub use ddl::{
CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction,
Expand Down
9 changes: 8 additions & 1 deletion datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ use datafusion_expr::utils::{
};
use datafusion_expr::{
qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter,
GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning,
GroupingSet, LogicalPlan, LogicalPlanBuilder, LogicalPlanBuilderOptions,
Partitioning,
};

use indexmap::IndexMap;
Expand Down Expand Up @@ -371,7 +372,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
let agg_expr = agg.aggr_expr.clone();
let (new_input, new_group_by_exprs) =
self.try_process_group_by_unnest(agg)?;
let options = LogicalPlanBuilderOptions::new()
.with_add_implicit_group_by_exprs(true);
LogicalPlanBuilder::from(new_input)
.with_options(options)
.aggregate(new_group_by_exprs, agg_expr)?
.build()
}
Expand Down Expand Up @@ -744,7 +748,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
aggr_exprs: &[Expr],
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> {
// create the aggregate plan
let options =
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
let plan = LogicalPlanBuilder::from(input.clone())
.with_options(options)
.aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())?
.build()?;
let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan {
Expand Down
18 changes: 18 additions & 0 deletions datafusion/substrait/tests/cases/logical_plans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,22 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn multilayer_aggregate() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/multilayer_aggregate.substrait.json");
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;

assert_eq!(
format!("{}", plan),
"Projection: lower(sales.product) AS lower(product), sum(count(sales.product)) AS product_count\
\n Aggregate: groupBy=[[sales.product]], aggr=[[sum(count(sales.product))]]\
\n Aggregate: groupBy=[[sales.product]], aggr=[[count(sales.product)]]\
\n TableScan: sales"
);

Ok(())
}
}
11 changes: 11 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,17 @@ async fn aggregate_grouping_rollup() -> Result<()> {
).await
}

#[tokio::test]
async fn multilayer_aggregate() -> Result<()> {
assert_expected_plan(
"SELECT a, sum(partial_count_b) FROM (SELECT a, count(b) as partial_count_b FROM data GROUP BY a) GROUP BY a",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still a bit confused by this. Why does the original plan created from SQL already contain the implicit groupBy's here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test goes SQL -> Logical Plan -> Substrait Plan -> Logical Plan.

There were no additional groupBys added between SQL -> Logical Plan -> Substrait Plan.

(Prior to this PR) Additional groupBys were added between Substrait Plan -> Logical Plan, violating the round trip plan expectations.

"Aggregate: groupBy=[[data.a]], aggr=[[sum(count(data.b)) AS sum(partial_count_b)]]\
\n Aggregate: groupBy=[[data.a]], aggr=[[count(data.b)]]\
\n TableScan: data projection=[a, b]",
true
).await
}

#[tokio::test]
async fn decimal_literal() -> Result<()> {
roundtrip("SELECT * FROM data WHERE b > 2.5").await
Expand Down
Loading