Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 222 additions & 3 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ use datafusion_expr::expr::{
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
use datafusion_expr::{
Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType,
Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame,
WindowFrameBound, WriteOp,
Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension,
FetchType, Filter, JoinType, LogicalPlanBuilder, RecursiveQuery, SkipType,
StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::{Column, Literal};
Expand All @@ -91,12 +91,15 @@ use datafusion_physical_plan::unnest::ListUnnest;
use crate::schema_equivalence::schema_satisfied_by;
use async_trait::async_trait;
use datafusion_datasource::file_groups::FileGroup;
use datafusion_expr_common::operator::Operator;
use futures::{StreamExt, TryStreamExt};
use itertools::{multiunzip, Itertools};
use log::{debug, trace};
use sqlparser::ast::NullTreatment;
use tokio::sync::Mutex;

use datafusion_physical_plan::collect;

/// Physical query planner that converts a `LogicalPlan` to an
/// `ExecutionPlan` suitable for execution.
#[async_trait]
Expand Down Expand Up @@ -887,7 +890,60 @@ impl DefaultPhysicalPlanner {
options.clone(),
))
}
LogicalPlan::Pivot(pivot) => {
return if !pivot.pivot_values.is_empty() {
let agg_plan = transform_pivot_to_aggregate(
Arc::new(pivot.input.as_ref().clone()),
&pivot.aggregate_expr,
&pivot.pivot_column,
pivot.pivot_values.clone(),
pivot.default_on_null_expr.as_ref(),
)?;

self.create_physical_plan(&agg_plan, session_state).await
} else if let Some(subquery) = &pivot.value_subquery {
let optimized_subquery = session_state.optimize(subquery.as_ref())?;

let subquery_physical_plan = self
.create_physical_plan(&optimized_subquery, session_state)
.await?;

let subquery_results = collect(
Arc::clone(&subquery_physical_plan),
session_state.task_ctx(),
)
.await?;

let mut pivot_values = Vec::new();
for batch in subquery_results.iter() {
if batch.num_columns() != 1 {
return plan_err!(
"Pivot subquery must return a single column"
);
}

let column = batch.column(0);
for row_idx in 0..batch.num_rows() {
if !column.is_null(row_idx) {
pivot_values
.push(ScalarValue::try_from_array(column, row_idx)?);
}
}
}

let agg_plan = transform_pivot_to_aggregate(
Arc::new(pivot.input.as_ref().clone()),
&pivot.aggregate_expr,
&pivot.pivot_column,
pivot_values,
pivot.default_on_null_expr.as_ref(),
)?;

self.create_physical_plan(&agg_plan, session_state).await
} else {
plan_err!("PIVOT operation requires at least one value to pivot on")
}
}
// 2 Children
LogicalPlan::Join(Join {
left,
Expand Down Expand Up @@ -1683,6 +1739,136 @@ pub use datafusion_physical_expr::{
create_physical_sort_expr, create_physical_sort_exprs,
};

/// Transform a PIVOT operation into a standard Aggregate (plus an optional
/// Projection) logical plan.
///
/// For each known pivot value we emit one filtered aggregate using an
/// `IS NOT DISTINCT FROM` condition (NULL-safe equality), aliased to the
/// pivot value's name.
///
/// For example, for SUM(amount) PIVOT(quarter FOR quarter in ('2023_Q1', '2023_Q2')), we create:
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1"
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2"
///
/// If `default_on_null_expr` is provided (DEFAULT ON NULL), the aggregate is
/// wrapped in an outer projection that replaces NULL aggregate results with
/// the default expression via a CASE expression.
///
/// # Errors
/// Returns a plan error if `aggregate_expr` is not an aggregate function call.
pub fn transform_pivot_to_aggregate(
    input: Arc<LogicalPlan>,
    aggregate_expr: &Expr,
    pivot_column: &datafusion_common::Column,
    pivot_values: Vec<ScalarValue>,
    default_on_null_expr: Option<&Expr>,
) -> Result<LogicalPlan> {
    // PIVOT only supports a single aggregate function; reject anything else
    // up front instead of re-checking inside the per-value loop.
    let Expr::AggregateFunction(agg) = aggregate_expr else {
        return plan_err!(
            "Unsupported PIVOT aggregate expression: expected an aggregate function"
        );
    };

    // Output column name for a pivot value. `ScalarValue`'s Display wraps
    // string literals in single quotes, which we strip.
    // NOTE(review): assumes pivot value names are distinct after trimming —
    // duplicate names would collide in the output schema; confirm upstream.
    let pivot_field_name =
        |v: &ScalarValue| v.to_string().trim_matches('\'').to_string();

    // GROUP BY every input column except the pivot column and any column
    // referenced by the aggregate expression. Hoist `column_refs()` out of
    // the filter so it is computed once, not per column.
    let agg_cols = aggregate_expr.column_refs();
    let group_by_columns: Vec<Expr> = input
        .schema()
        .columns()
        .into_iter()
        .filter(|col: &datafusion_common::Column| {
            col.name != pivot_column.name
                && !agg_cols.iter().any(|agg_col| agg_col.name == col.name)
        })
        .map(Expr::Column)
        .collect();

    let builder = LogicalPlanBuilder::from(Arc::unwrap_or_clone(input));

    // One filtered aggregate per pivot value, aliased to the value's name.
    let aggregate_exprs: Vec<Expr> = pivot_values
        .iter()
        .map(|value| {
            // NULL-safe equality so a NULL pivot value matches exactly.
            let filter_condition = Expr::BinaryExpr(BinaryExpr::new(
                Box::new(Expr::Column(pivot_column.clone())),
                Operator::IsNotDistinctFrom,
                Box::new(Expr::Literal(value.clone())),
            ));

            let mut new_params = agg.params.clone();
            new_params.filter = Some(Box::new(filter_condition));
            let filtered_agg = Expr::AggregateFunction(AggregateFunction {
                func: Arc::clone(&agg.func),
                params: new_params,
            });

            Expr::Alias(Alias {
                expr: Box::new(filtered_agg),
                relation: None,
                name: pivot_field_name(value),
                metadata: None,
            })
        })
        .collect();

    let aggregate_plan = builder
        .aggregate(group_by_columns, aggregate_exprs)?
        .build()?;

    // Without DEFAULT ON NULL the aggregate plan is the final plan.
    let Some(default_expr) = default_on_null_expr else {
        return Ok(aggregate_plan);
    };

    // DEFAULT ON NULL: wrap the aggregate in a projection that passes the
    // GROUP BY columns through and replaces NULL aggregate results with the
    // default expression.
    let pivot_names: Vec<String> =
        pivot_values.iter().map(|v| pivot_field_name(v)).collect();

    let schema = aggregate_plan.schema();
    let mut projection_exprs: Vec<Expr> =
        Vec::with_capacity(schema.fields().len());

    // Pass through every non-pivot (GROUP BY) column unchanged.
    for field in schema.fields() {
        if !pivot_names.iter().any(|name| name == field.name()) {
            projection_exprs.push(Expr::Column(
                datafusion_common::Column::from_name(field.name()),
            ));
        }
    }

    // COALESCE via CASE: CASE WHEN col IS NULL THEN default ELSE col END
    for field_name in pivot_names {
        let aggregate_col =
            Expr::Column(datafusion_common::Column::from_name(&field_name));

        let coalesce_expr = Expr::Case(datafusion_expr::expr::Case {
            expr: None,
            when_then_expr: vec![(
                Box::new(Expr::IsNull(Box::new(aggregate_col.clone()))),
                Box::new(default_expr.clone()),
            )],
            else_expr: Some(Box::new(aggregate_col)),
        });

        projection_exprs.push(Expr::Alias(Alias {
            expr: Box::new(coalesce_expr),
            relation: None,
            name: field_name,
            metadata: None,
        }));
    }

    LogicalPlanBuilder::from(aggregate_plan)
        .project(projection_exprs)?
        .build()
}

impl DefaultPhysicalPlanner {
/// Handles capturing the various plans for EXPLAIN queries
///
Expand Down Expand Up @@ -2044,6 +2230,39 @@ impl DefaultPhysicalPlanner {
})
.collect::<Result<Vec<_>>>()?;

// When we detect a PIVOT-derived plan with a value_subquery, ensure all generated columns are preserved
if let LogicalPlan::Pivot(pivot) = input.as_ref() {
if pivot.value_subquery.is_some()
&& input_exec
.as_any()
.downcast_ref::<AggregateExec>()
.is_some()
{
let agg_exec =
input_exec.as_any().downcast_ref::<AggregateExec>().unwrap();
let schema = input_exec.schema();
let group_by_len = agg_exec.group_expr().expr().len();

if group_by_len < schema.fields().len() {
let mut all_exprs = physical_exprs.clone();

for (i, field) in
schema.fields().iter().enumerate().skip(group_by_len)
{
if !physical_exprs.iter().any(|(_, name)| name == field.name()) {
all_exprs.push((
Arc::new(Column::new(field.name(), i))
as Arc<dyn PhysicalExpr>,
field.name().clone(),
));
}
}

return Ok(Arc::new(ProjectionExec::try_new(all_exprs, input_exec)?));
}
}
}

Ok(Arc::new(ProjectionExec::try_new(
physical_exprs,
input_exec,
Expand Down
45 changes: 44 additions & 1 deletion datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::expr_rewriter::{
};
use crate::logical_plan::{
Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join,
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, Pivot, PlanType, Prepare,
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
Window,
};
Expand Down Expand Up @@ -1427,6 +1427,23 @@ impl LogicalPlanBuilder {
unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options)
.map(Self::new)
}

/// Apply a PIVOT operation, turning distinct values of `pivot_column`
/// into output columns.
///
/// `aggregate_expr` is evaluated once per value in `pivot_values`; when
/// `default_on_null` is provided it replaces NULL aggregate results
/// (the SQL `DEFAULT ON NULL` clause).
pub fn pivot(
    self,
    aggregate_expr: Expr,
    pivot_column: Column,
    pivot_values: Vec<ScalarValue>,
    default_on_null: Option<Expr>,
) -> Result<Self> {
    Pivot::try_new(
        self.plan,
        aggregate_expr,
        pivot_column,
        pivot_values,
        default_on_null,
    )
    .map(LogicalPlan::Pivot)
    .map(Self::new)
}
}

impl From<LogicalPlan> for LogicalPlanBuilder {
Expand Down Expand Up @@ -2824,4 +2841,30 @@ mod tests {

Ok(())
}

#[test]
fn plan_builder_pivot() -> Result<()> {
    // Input relation: sales(region, product, sales).
    let sales_schema = Schema::new(vec![
        Field::new("region", DataType::Utf8, false),
        Field::new("product", DataType::Utf8, false),
        Field::new("sales", DataType::Int32, false),
    ]);

    // Pivot `product` into one column per value, aggregating `sales`.
    let values = vec![
        ScalarValue::Utf8(Some(String::from("widget"))),
        ScalarValue::Utf8(Some(String::from("gadget"))),
    ];
    let plan = LogicalPlanBuilder::scan("sales", table_source(&sales_schema), None)?
        .pivot(col("sales"), Column::from_name("product"), values, None)?
        .build()?;

    // The plan display should show the pivot node above the scan.
    assert_eq!(
        format!("{plan}"),
        "Pivot: sales FOR product IN (widget, gadget)\n TableScan: sales"
    );

    Ok(())
}
}
Loading
Loading