diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index be24206c676c..0ed5fd3bca2b 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -76,9 +76,9 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, - Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, + Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, + FetchType, Filter, JoinType, LogicalPlanBuilder, RecursiveQuery, SkipType, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::{Column, Literal}; @@ -91,12 +91,15 @@ use datafusion_physical_plan::unnest::ListUnnest; use crate::schema_equivalence::schema_satisfied_by; use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; +use datafusion_expr_common::operator::Operator; use futures::{StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::{debug, trace}; use sqlparser::ast::NullTreatment; use tokio::sync::Mutex; +use datafusion_physical_plan::collect; + /// Physical query planner that converts a `LogicalPlan` to an /// `ExecutionPlan` suitable for execution. #[async_trait] @@ -887,7 +890,60 @@ impl DefaultPhysicalPlanner { options.clone(), )) } + LogicalPlan::Pivot(pivot) => { + return if !pivot.pivot_values.is_empty() { + let agg_plan = transform_pivot_to_aggregate( + Arc::new(pivot.input.as_ref().clone()), + &pivot.aggregate_expr, + &pivot.pivot_column, + pivot.pivot_values.clone(), + pivot.default_on_null_expr.as_ref(), + )?; + + self.create_physical_plan(&agg_plan, session_state).await + } else if let Some(subquery) = &pivot.value_subquery { + let optimized_subquery = session_state.optimize(subquery.as_ref())?; + + let subquery_physical_plan = self + .create_physical_plan(&optimized_subquery, session_state) + .await?; + + let subquery_results = collect( + Arc::clone(&subquery_physical_plan), + session_state.task_ctx(), + ) + .await?; + + let mut pivot_values = Vec::new(); + for batch in subquery_results.iter() { + if batch.num_columns() != 1 { + return plan_err!( + "Pivot subquery must return a single column" + ); + } + + let column = batch.column(0); + for row_idx in 0..batch.num_rows() { + if !column.is_null(row_idx) { + pivot_values + .push(ScalarValue::try_from_array(column, row_idx)?); + } + } + } + + let agg_plan = transform_pivot_to_aggregate( + Arc::new(pivot.input.as_ref().clone()), + &pivot.aggregate_expr, + &pivot.pivot_column, + pivot_values, + pivot.default_on_null_expr.as_ref(), + )?; + self.create_physical_plan(&agg_plan, session_state).await + } else { + plan_err!("PIVOT operation requires at least one value to pivot on") + } + } // 2 Children LogicalPlan::Join(Join { left, @@ -1683,6 +1739,136 @@ pub use datafusion_physical_expr::{ create_physical_sort_expr, create_physical_sort_exprs, }; +/// Transform a PIVOT operation into a more standard Aggregate + Projection plan +/// For known pivot values, we create a projection that includes "IS NOT DISTINCT FROM" conditions +/// +/// For example, for SUM(amount) PIVOT(quarter FOR quarter in ('2023_Q1', '2023_Q2')), we create: +/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1" +/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2" +/// +/// If DEFAULT ON NULL is specified, each aggregate expression is wrapped with an outer projection that +/// applies COALESCE to the results. +pub fn transform_pivot_to_aggregate( + input: Arc, + aggregate_expr: &Expr, + pivot_column: &datafusion_common::Column, + pivot_values: Vec, + default_on_null_expr: Option<&Expr>, +) -> Result { + let df_schema = input.schema(); + + let all_columns: Vec = df_schema.columns(); + + // Filter to include only columns we want for GROUP BY + // (exclude pivot column and aggregate expression columns) + let group_by_columns: Vec = all_columns + .into_iter() + .filter(|col: &datafusion_common::Column| { + col.name != pivot_column.name + && !aggregate_expr + .column_refs() + .iter() + .any(|agg_col| agg_col.name == col.name) + }) + .map(|col: datafusion_common::Column| Expr::Column(col)) + .collect(); + + let builder = LogicalPlanBuilder::from(Arc::unwrap_or_clone(input)); + + // Create the aggregate plan with filtered aggregates + let mut aggregate_exprs = Vec::new(); + + for value in &pivot_values { + let filter_condition = Expr::BinaryExpr(BinaryExpr::new( + Box::new(Expr::Column(pivot_column.clone())), + Operator::IsNotDistinctFrom, + Box::new(Expr::Literal(value.clone())), + )); + + let filtered_agg = match aggregate_expr { + Expr::AggregateFunction(agg) => { + let mut new_params = agg.params.clone(); + new_params.filter = Some(Box::new(filter_condition)); + Expr::AggregateFunction(AggregateFunction { + func: Arc::clone(&agg.func), + params: new_params, + }) + } + _ => { + return plan_err!( + "Unsupported aggregate expression should always be AggregateFunction" + ); + } + }; + + // Use the pivot value as the column name + let field_name = value.to_string().trim_matches('\'').to_string(); + let aliased_agg = Expr::Alias(Alias { + expr: Box::new(filtered_agg), + relation: None, + name: field_name, + metadata: None, + }); + + aggregate_exprs.push(aliased_agg); + } + + // Create the plan with the aggregate + let aggregate_plan = builder + .aggregate(group_by_columns, aggregate_exprs)? + .build()?; + + // If DEFAULT ON NULL is specified, add a projection to apply COALESCE + if let Some(default_expr) = default_on_null_expr { + let schema = aggregate_plan.schema(); + let mut projection_exprs = Vec::new(); + + for field in schema.fields() { + if !pivot_values + .iter() + .any(|v| field.name() == v.to_string().trim_matches('\'')) + { + projection_exprs.push(Expr::Column( + datafusion_common::Column::from_name(field.name()), + )); + } + } + + // Apply COALESCE to aggregate columns + for value in &pivot_values { + let field_name = value.to_string().trim_matches('\'').to_string(); + let aggregate_col = + Expr::Column(datafusion_common::Column::from_name(&field_name)); + + // Create COALESCE expression using CASE: CASE WHEN col IS NULL THEN default_value ELSE col END + let coalesce_expr = Expr::Case(datafusion_expr::expr::Case { + expr: None, + when_then_expr: vec![( + Box::new(Expr::IsNull(Box::new(aggregate_col.clone()))), + Box::new(default_expr.clone()), + )], + else_expr: Some(Box::new(aggregate_col)), + }); + + let aliased_coalesce = Expr::Alias(Alias { + expr: Box::new(coalesce_expr), + relation: None, + name: field_name, + metadata: None, + }); + + projection_exprs.push(aliased_coalesce); + } + + // Apply the projection + LogicalPlanBuilder::from(aggregate_plan) + .project(projection_exprs)? + .build() + } else { + Ok(aggregate_plan) + } +} + impl DefaultPhysicalPlanner { /// Handles capturing the various plans for EXPLAIN queries /// @@ -2044,6 +2230,39 @@ impl DefaultPhysicalPlanner { }) .collect::>>()?; + // When we detect a PIVOT-derived plan with a value_subquery, ensure all generated columns are preserved + if let LogicalPlan::Pivot(pivot) = input.as_ref() { + if pivot.value_subquery.is_some() + && input_exec + .as_any() + .downcast_ref::() + .is_some() + { + let agg_exec = + input_exec.as_any().downcast_ref::().unwrap(); + let schema = input_exec.schema(); + let group_by_len = agg_exec.group_expr().expr().len(); + + if group_by_len < schema.fields().len() { + let mut all_exprs = physical_exprs.clone(); + + for (i, field) in + schema.fields().iter().enumerate().skip(group_by_len) + { + if !physical_exprs.iter().any(|(_, name)| name == field.name()) { + all_exprs.push(( + Arc::new(Column::new(field.name(), i)) + as Arc, + field.name().clone(), + )); + } + } + + return Ok(Arc::new(ProjectionExec::try_new(all_exprs, input_exec)?)); + } + } + } + Ok(Arc::new(ProjectionExec::try_new( physical_exprs, input_exec, diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 05a43444d4ae..866a9ca49e1a 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -32,7 +32,7 @@ use crate::expr_rewriter::{ }; use crate::logical_plan::{ Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, - JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, Pivot, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; @@ -1427,6 +1427,23 @@ impl LogicalPlanBuilder { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) } + + pub fn pivot( + self, + aggregate_expr: Expr, + pivot_column: Column, + pivot_values: Vec, + default_on_null: Option, + ) -> Result { + let pivot_plan = Pivot::try_new( + self.plan, + aggregate_expr, + pivot_column, + pivot_values, + default_on_null, + )?; + Ok(Self::new(LogicalPlan::Pivot(pivot_plan))) + } } impl From for LogicalPlanBuilder { @@ -2824,4 +2841,30 @@ mod tests { Ok(()) } + + #[test] + fn plan_builder_pivot() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("region", DataType::Utf8, false), + Field::new("product", DataType::Utf8, false), + Field::new("sales", DataType::Int32, false), + ]); + + let plan = LogicalPlanBuilder::scan("sales", table_source(&schema), None)? + .pivot( + col("sales"), + Column::from_name("product"), + vec![ + ScalarValue::Utf8(Some("widget".to_string())), + ScalarValue::Utf8(Some("gadget".to_string())), + ], + None, + )? + .build()?; + + let expected = "Pivot: sales FOR product IN (widget, gadget)\n TableScan: sales"; + assert_eq!(expected, format!("{plan}")); + + Ok(()) + } } diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 14758b61e859..07a069cbb400 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -22,7 +22,7 @@ use std::fmt; use crate::{ expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr, - Filter, Join, Limit, LogicalPlan, Partitioning, Projection, RecursiveQuery, + Filter, Join, Limit, LogicalPlan, Partitioning, Pivot, Projection, RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan, Unnest, Values, Window, }; @@ -650,6 +650,41 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "StructColumn": expr_vec_fmt!(struct_type_columns), }) } + LogicalPlan::Pivot(Pivot { + aggregate_expr, + pivot_column, + pivot_values, + value_subquery, + default_on_null_expr, + .. + }) => { + let mut object = json!({ + "Node Type": "Pivot", + "Aggregate": format!("{}", aggregate_expr), + "Pivot Column": format!("{}", pivot_column), + }); + + if !pivot_values.is_empty() { + object["Pivot Values"] = serde_json::Value::Array( + pivot_values + .iter() + .map(|v| serde_json::Value::String(v.to_string())) + .collect(), + ); + } + + if value_subquery.is_some() { + object["Value Subquery"] = + serde_json::Value::String("Provided".to_string()); + } + + if default_on_null_expr.is_some() { + object["Default On Null"] = + serde_json::Value::String("Provided".to_string()); + } + + object + } } } } @@ -721,7 +756,10 @@ impl<'n> TreeNodeVisitor<'n> for PgJsonVisitor<'_, '_> { #[cfg(test)] mod tests { - use arrow::datatypes::{DataType, Field}; + use crate::EmptyRelation; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{Column, DFSchema, ScalarValue}; + use std::sync::Arc; use super::*; @@ -743,4 +781,84 @@ mod tests { format!("{}", display_schema(&schema)) ); } + + #[test] + fn test_pivot_to_json_value() { + // Create a mock schema + let schema = Arc::new(DFSchema::empty().to_owned()); + + // Create mock pivot values + let pivot_values = vec![ + ScalarValue::Utf8(Some("A".to_string())), + ScalarValue::Utf8(Some("B".to_string())), + ]; + + // Create a Pivot plan + let pivot = Pivot { + input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: schema.clone(), + })), + aggregate_expr: Expr::Column(Column::from_name("sum_value")), + pivot_column: Column::from_name("category"), + pivot_values, + schema: schema.clone(), + value_subquery: None, + default_on_null_expr: None, + }; + + // Test the to_json_value function + let json_value = PgJsonVisitor::to_json_value(&LogicalPlan::Pivot(pivot)); + + // Check the JSON structure + assert_eq!(json_value["Node Type"], "Pivot"); + assert_eq!(json_value["Aggregate"], "sum_value"); + assert_eq!(json_value["Pivot Column"], "category"); + + // Check the pivot values + let pivot_values = json_value["Pivot Values"].as_array().unwrap(); + assert_eq!(pivot_values.len(), 2); + assert_eq!(pivot_values[0], "A"); + assert_eq!(pivot_values[1], "B"); + + // Check that Value Subquery is not present + assert!(json_value.get("Value Subquery").is_none()); + } + + #[test] + fn test_pivot_with_subquery_to_json_value() { + // Create a mock schema + let schema = Arc::new(DFSchema::empty().to_owned()); + + // Create a Pivot plan with a value subquery + let pivot = Pivot { + input: Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: schema.clone(), + })), + aggregate_expr: Expr::Column(Column::from_name("sum_value")), + pivot_column: Column::from_name("category"), + pivot_values: vec![], + schema: schema.clone(), + value_subquery: Some(Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: schema.clone(), + }))), + default_on_null_expr: None, + }; + + // Test the to_json_value function + let json_value = PgJsonVisitor::to_json_value(&LogicalPlan::Pivot(pivot)); + + // Check the JSON structure + assert_eq!(json_value["Node Type"], "Pivot"); + assert_eq!(json_value["Aggregate"], "sum_value"); + assert_eq!(json_value["Pivot Column"], "category"); + + // Check that pivot values are not present + assert!(json_value.get("Pivot Values").is_none()); + + // Check that Value Subquery is present + assert_eq!(json_value["Value Subquery"], "Provided"); + } } diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index a55f4d97b212..2225f57e66f4 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -39,7 +39,7 @@ pub use dml::{DmlStatement, WriteOp}; pub use plan::{ projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, ExplainFormat, Extension, FetchType, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, + Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, Pivot, PlanType, Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index edf5f1126be9..a284a3ea1068 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -287,6 +287,8 @@ pub enum LogicalPlan { Unnest(Unnest), /// A variadic query (e.g. "Recursive CTEs") RecursiveQuery(RecursiveQuery), + /// Pivot + Pivot(Pivot), } impl Default for LogicalPlan { @@ -351,6 +353,7 @@ impl LogicalPlan { // we take the schema of the static term as the schema of the entire recursive query static_term.schema() } + LogicalPlan::Pivot(Pivot { schema, .. }) => schema, } } @@ -467,7 +470,8 @@ impl LogicalPlan { LogicalPlan::Dml(write) => vec![&write.input], LogicalPlan::Copy(copy) => vec![©.input], LogicalPlan::Ddl(ddl) => ddl.inputs(), - LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], + LogicalPlan::Unnest(Unnest { input, .. }) + | LogicalPlan::Pivot(Pivot { input, .. }) => vec![input], LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, recursive_term, @@ -589,7 +593,8 @@ impl LogicalPlan { | LogicalPlan::Copy(_) | LogicalPlan::Ddl(_) | LogicalPlan::DescribeTable(_) - | LogicalPlan::Unnest(_) => Ok(None), + | LogicalPlan::Unnest(_) + | LogicalPlan::Pivot(_) => Ok(None), } } @@ -749,6 +754,39 @@ impl LogicalPlan { // Update schema with unnested column type. unnest_with_options(Arc::unwrap_or_clone(input), exec_columns, options) } + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + .. + }) => { + // Create Pivot with the same value_subquery + let new_pivot = if let Some(subquery) = value_subquery { + Pivot { + input, + aggregate_expr, + pivot_column: pivot_column.clone(), + pivot_values: pivot_values.clone(), + schema: Arc::clone(&schema), + value_subquery: Some(Arc::clone(&subquery)), + default_on_null_expr: None, + } + } else { + Pivot::try_new( + Arc::clone(&input), + aggregate_expr.clone(), + pivot_column.clone(), + pivot_values.clone(), + default_on_null_expr.clone(), + )? + }; + + Ok(LogicalPlan::Pivot(new_pivot)) + } } } @@ -1141,6 +1179,39 @@ impl LogicalPlan { unnest_with_options(input, columns.clone(), options.clone())?; Ok(new_plan) } + LogicalPlan::Pivot(Pivot { + aggregate_expr: _, + pivot_column, + pivot_values, + schema: _, + value_subquery, + default_on_null_expr, + .. + }) => { + let input = self.only_input(inputs)?; + let new_aggregate_expr = self.only_expr(expr)?; + + // Create Pivot with the same value_subquery + let new_pivot = if let Some(subquery) = value_subquery { + Pivot::try_new_with_subquery( + Arc::new(input), + new_aggregate_expr, + pivot_column.clone(), + Arc::clone(subquery), + default_on_null_expr.clone(), + )? + } else { + Pivot::try_new( + Arc::new(input), + new_aggregate_expr, + pivot_column.clone(), + pivot_values.clone(), + default_on_null_expr.clone(), + )? + }; + + Ok(LogicalPlan::Pivot(new_pivot)) + } } } @@ -1373,7 +1444,8 @@ impl LogicalPlan { | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) | LogicalPlan::Statement(_) - | LogicalPlan::Extension(_) => None, + | LogicalPlan::Extension(_) + | LogicalPlan::Pivot(_) => None, } } @@ -2018,6 +2090,20 @@ impl LogicalPlan { expr_vec_fmt!(list_type_columns), expr_vec_fmt!(struct_type_columns)) } + LogicalPlan::Pivot(Pivot { + aggregate_expr, + pivot_column, + pivot_values, + .. + }) => { + write!( + f, + "Pivot: {} FOR {} IN ({})", + aggregate_expr, + pivot_column, + pivot_values.iter().map(|v| v.to_string()).collect::>().join(", ") + ) + } } } } @@ -2193,6 +2279,155 @@ pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result, + /// Aggregate expression (e.g., SUM(amount)) + pub aggregate_expr: Expr, + /// Column whose values become new columns + pub pivot_column: Column, + /// List of pivot values (distinct values from pivot column) + pub pivot_values: Vec, + /// Output schema after pivot + pub schema: DFSchemaRef, + /// Optional subquery for pivot values + /// When provided, this will be executed during physical planning + /// to dynamically determine the pivot values + pub value_subquery: Option>, + /// Optional default value for replacing NULL values in the pivot result + pub default_on_null_expr: Option, +} + +impl PartialOrd for Pivot { + fn partial_cmp(&self, other: &Self) -> Option { + let self_tuple = ( + &self.input, + &self.aggregate_expr, + &self.pivot_column, + &self.pivot_values, + &self.value_subquery, + &self.default_on_null_expr, + ); + let other_tuple = ( + &other.input, + &other.aggregate_expr, + &other.pivot_column, + &other.pivot_values, + &other.value_subquery, + &other.default_on_null_expr, + ); + self_tuple.partial_cmp(&other_tuple) + } +} + +impl Pivot { + pub fn try_new( + input: Arc, + aggregate_expr: Expr, + pivot_column: Column, + pivot_values: Vec, + default_on_null_expr: Option, + ) -> Result { + let schema = pivot_schema( + input.schema(), + &aggregate_expr, + &pivot_column, + &pivot_values, + )?; + + Ok(Self { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema: Arc::new(schema), + value_subquery: None, + default_on_null_expr, + }) + } + + /// Create a new Pivot with a subquery for pivot values + pub fn try_new_with_subquery( + input: Arc, + aggregate_expr: Expr, + pivot_column: Column, + value_subquery: Arc, + default_on_null_expr: Option, + ) -> Result { + let schema = + pivot_schema_without_values(input.schema(), &aggregate_expr, &pivot_column)?; + + Ok(Self { + input, + aggregate_expr, + pivot_column, + pivot_values: Vec::new(), + schema: Arc::new(schema), + value_subquery: Some(value_subquery), + default_on_null_expr, + }) + } +} + +fn pivot_schema_without_values( + input_schema: &DFSchemaRef, + aggregate_expr: &Expr, + pivot_column: &Column, +) -> Result { + let mut fields = vec![]; + + // Include all fields except pivot and value columns + for field in input_schema.fields() { + if !aggregate_expr + .column_refs() + .iter() + .any(|col| col.name() == field.name()) + && field.name() != pivot_column.name() + { + fields.push(Arc::clone(field)); + } + } + + let fields_with_table_ref: Vec<(Option, Arc)> = + fields.into_iter().map(|field| (None, field)).collect(); + + DFSchema::new_with_metadata(fields_with_table_ref, input_schema.metadata().clone()) +} + +fn pivot_schema( + input_schema: &DFSchemaRef, + aggregate_expr: &Expr, + pivot_column: &Column, + pivot_values: &[ScalarValue], +) -> Result { + let mut fields = vec![]; + + for field in input_schema.fields() { + if !aggregate_expr + .column_refs() + .iter() + .any(|col| col.name() == field.name()) + && field.name() != pivot_column.name() + { + fields.push(Arc::clone(field)); + } + } + + for pivot_value in pivot_values { + let field_name = format!("{}", pivot_value); + let data_type = aggregate_expr.get_type(input_schema)?; + fields.push(Arc::new(Field::new(field_name, data_type, true))); + } + + let fields_with_table_ref: Vec<(Option, Arc)> = + fields.into_iter().map(|field| (None, field)).collect(); + + DFSchema::new_with_metadata(fields_with_table_ref, input_schema.metadata().clone()) +} + /// Aliased subquery #[derive(Debug, Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 7f6e1e025387..88101137e031 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -40,8 +40,8 @@ use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, - Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, - Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, + Limit, LogicalPlan, Partitioning, Pivot, Prepare, Projection, RecursiveQuery, + Repartition, Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, }; use datafusion_common::tree_node::TreeNodeRefContainer; @@ -328,6 +328,25 @@ impl TreeNode for LogicalPlan { options, }) }), + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + }) => input.map_elements(f)?.update_data(|input| { + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + }) + }), LogicalPlan::RecursiveQuery(RecursiveQuery { name, static_term, @@ -467,6 +486,7 @@ impl LogicalPlan { } _ => Ok(TreeNodeRecursion::Continue), }, + LogicalPlan::Pivot(Pivot { aggregate_expr, .. }) => f(aggregate_expr), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) @@ -644,6 +664,25 @@ impl LogicalPlan { LogicalPlan::Limit(Limit { skip, fetch, input }) }) } + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + }) => f(aggregate_expr)?.update_data(|aggregate_expr| { + LogicalPlan::Pivot(Pivot { + input, + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + }) + }), LogicalPlan::Statement(stmt) => match stmt { Statement::Execute(e) => { e.parameters.map_elements(f)?.update_data(|parameters| { diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 142fdf815a7e..ccea816ccf78 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -166,7 +166,7 @@ mod tests { use arrow::datatypes::DataType; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; use datafusion_common::ScalarValue; - use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; #[test] diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 69b5fbb9f8c0..09b3fbeef25f 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -564,7 +564,8 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) => { + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Pivot(_) => { // This rule handles recursion itself in a `ApplyOrder::TopDown` like // manner. plan.map_children(|c| self.rewrite(c, config))? diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 4452b2d4ce03..ba9af7baff6a 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -382,6 +382,9 @@ fn optimize_projections( dependency_indices.clone(), )] } + LogicalPlan::Pivot(_) => { + return Ok(Transformed::no(plan)); + } }; // Required indices are currently ordered (child0, child1, ...) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 39236da3b9a8..4f96c8ab7a63 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -62,6 +62,7 @@ message LogicalPlanNode { RecursiveQueryNode recursive_query = 31; CteWorkTableScanNode cte_work_table_scan = 32; DmlNode dml = 33; + PivotNode pivot = 34; } } @@ -1285,3 +1286,14 @@ message CteWorkTableScanNode { string name = 1; datafusion_common.Schema schema = 2; } + +message PivotNode { + LogicalPlanNode input = 1; + LogicalExprNode aggregate_expr = 2; + datafusion_common.Column pivot_column = 3; + repeated datafusion_common.ScalarValue pivot_values = 4; + datafusion_common.DfSchema schema = 5; + LogicalPlanNode value_subquery = 6; + LogicalExprNode default_on_null_expr = 7; + +} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 6166b6ec4796..6db2979f839a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -11039,6 +11039,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::Dml(v) => { struct_ser.serialize_field("dml", v)?; } + logical_plan_node::LogicalPlanType::Pivot(v) => { + struct_ser.serialize_field("pivot", v)?; + } } } struct_ser.end() @@ -11098,6 +11101,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "cte_work_table_scan", "cteWorkTableScan", "dml", + "pivot", ]; #[allow(clippy::enum_variant_names)] @@ -11134,6 +11138,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { RecursiveQuery, CteWorkTableScan, Dml, + Pivot, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11187,6 +11192,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "recursiveQuery" | "recursive_query" => Ok(GeneratedField::RecursiveQuery), "cteWorkTableScan" | "cte_work_table_scan" => Ok(GeneratedField::CteWorkTableScan), "dml" => Ok(GeneratedField::Dml), + "pivot" => Ok(GeneratedField::Pivot), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11431,6 +11437,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("dml")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Dml) +; + } + GeneratedField::Pivot => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("pivot")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Pivot) ; } } @@ -16983,6 +16996,204 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { deserializer.deserialize_struct("datafusion.PhysicalWindowExprNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PivotNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.aggregate_expr.is_some() { + len += 1; + } + if self.pivot_column.is_some() { + len += 1; + } + if !self.pivot_values.is_empty() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if self.value_subquery.is_some() { + len += 1; + } + if self.default_on_null_expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PivotNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.aggregate_expr.as_ref() { + struct_ser.serialize_field("aggregateExpr", v)?; + } + if let Some(v) = self.pivot_column.as_ref() { + struct_ser.serialize_field("pivotColumn", v)?; + } + if !self.pivot_values.is_empty() { + struct_ser.serialize_field("pivotValues", &self.pivot_values)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if let Some(v) = self.value_subquery.as_ref() { + struct_ser.serialize_field("valueSubquery", v)?; + } + if let Some(v) = self.default_on_null_expr.as_ref() { + struct_ser.serialize_field("defaultOnNullExpr", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PivotNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "aggregate_expr", + "aggregateExpr", + "pivot_column", + "pivotColumn", + "pivot_values", + "pivotValues", + "schema", + "value_subquery", + "valueSubquery", + "default_on_null_expr", + "defaultOnNullExpr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + AggregateExpr, + PivotColumn, + PivotValues, + Schema, + ValueSubquery, + DefaultOnNullExpr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), + "pivotColumn" | "pivot_column" => Ok(GeneratedField::PivotColumn), + "pivotValues" | "pivot_values" => Ok(GeneratedField::PivotValues), + "schema" => Ok(GeneratedField::Schema), + "valueSubquery" | "value_subquery" => Ok(GeneratedField::ValueSubquery), + "defaultOnNullExpr" | "default_on_null_expr" => Ok(GeneratedField::DefaultOnNullExpr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PivotNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PivotNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut aggregate_expr__ = None; + let mut pivot_column__ = None; + let mut pivot_values__ = None; + let mut schema__ = None; + let mut value_subquery__ = None; + let mut default_on_null_expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::AggregateExpr => { + if aggregate_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregateExpr")); + } + aggregate_expr__ = map_.next_value()?; + } + GeneratedField::PivotColumn => { + if pivot_column__.is_some() { + return Err(serde::de::Error::duplicate_field("pivotColumn")); + } + pivot_column__ = map_.next_value()?; + } + GeneratedField::PivotValues => { + if pivot_values__.is_some() { + return Err(serde::de::Error::duplicate_field("pivotValues")); + } + pivot_values__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::ValueSubquery => { + if value_subquery__.is_some() { + return Err(serde::de::Error::duplicate_field("valueSubquery")); + } + value_subquery__ = map_.next_value()?; + } + GeneratedField::DefaultOnNullExpr => { + if default_on_null_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("defaultOnNullExpr")); + } + default_on_null_expr__ = map_.next_value()?; + } + } + } + Ok(PivotNode { + input: input__, + aggregate_expr: aggregate_expr__, + pivot_column: pivot_column__, + pivot_values: pivot_values__.unwrap_or_default(), + schema: schema__, + value_subquery: value_subquery__, + default_on_null_expr: default_on_null_expr__, + }) + } + } + deserializer.deserialize_struct("datafusion.PivotNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PlaceholderNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 41c60b22e3bc..763b84f03940 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -5,7 +5,7 @@ pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34" )] pub logical_plan_type: ::core::option::Option, } @@ -77,6 +77,8 @@ pub mod logical_plan_node { CteWorkTableScan(super::CteWorkTableScanNode), #[prost(message, tag = "33")] Dml(::prost::alloc::boxed::Box), + #[prost(message, tag = "34")] + Pivot(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1932,6 +1934,25 @@ pub struct CteWorkTableScanNode { #[prost(message, optional, tag = "2")] pub schema: ::core::option::Option, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PivotNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub aggregate_expr: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub pivot_column: ::core::option::Option, + #[prost(message, repeated, tag = "4")] + pub pivot_values: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "5")] + pub schema: ::core::option::Option, + #[prost(message, optional, boxed, tag = "6")] + pub value_subquery: ::core::option::Option< + ::prost::alloc::boxed::Box, + >, + #[prost(message, optional, tag = "7")] + pub default_on_null_expr: ::core::option::Option, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum WindowFrameUnits { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index a39e6dac37c1..1f3c256ed618 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -64,7 +64,7 @@ use datafusion_expr::{ logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, DdlStatement, Distinct, EmptyRelation, - Extension, Join, JoinConstraint, Prepare, Projection, Repartition, Sort, + Extension, Join, JoinConstraint, Pivot, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, }, DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, @@ -994,6 +994,61 @@ impl AsLogicalPlan for LogicalPlanNode { Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?), ), )), + LogicalPlanType::Pivot(pivot) => { + let aggregate_expr = pivot + .aggregate_expr + .as_ref() + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) + .transpose()? + .ok_or_else(|| { + DataFusionError::Internal("aggregate_expr required".to_string()) + })?; + let pivot_column = pivot + .pivot_column + .as_ref() + .map(|col| col.clone().into()) + .ok_or_else(|| { + DataFusionError::Internal("pivot_column required".to_string()) + })?; + let pivot_values = pivot + .pivot_values + .iter() + .map(|val| val.try_into()) + .collect::, _>>( + )?; + let schema = Arc::new(convert_required!(pivot.schema)?); + let value_subquery = if pivot.value_subquery.is_some() { + Some(Arc::new(into_logical_plan!( + pivot.value_subquery, + ctx, + extension_codec + )?)) + } else { + None + }; + let default_on_null_expr = if pivot.default_on_null_expr.is_some() { + pivot + .default_on_null_expr + .as_ref() + .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) + .transpose()? + } else { + None + }; + Ok(LogicalPlan::Pivot(Pivot { + input: Arc::new(into_logical_plan!( + pivot.input, + ctx, + extension_codec + )?), + aggregate_expr, + pivot_column, + pivot_values, + schema, + value_subquery, + default_on_null_expr, + })) + } } } @@ -1805,6 +1860,9 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } + LogicalPlan::Pivot(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for Statement", + )), } } } diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 8425af61c080..4bdff68baadb 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -21,7 +21,8 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - not_impl_err, plan_err, DFSchema, Diagnostic, Result, Span, Spans, TableReference, + not_impl_err, plan_err, Column, DFSchema, Diagnostic, Result, Span, Spans, + TableReference, }; use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; @@ -169,7 +170,136 @@ impl SqlToRel<'_, S> { "UNNEST table factor with offset is not supported yet" ); } - // @todo Support TableFactory::TableFunction? + TableFactor::Pivot { + table, + aggregate_functions, + value_column, + value_source, + default_on_null, + alias, + } => { + let input_plan = self.create_relation(*table, planner_context)?; + + if aggregate_functions.len() != 1 { + return plan_err!("PIVOT requires exactly one aggregate function"); + } + + let agg_expr = self.sql_expr_to_logical_expr( + aggregate_functions[0].expr.clone(), + input_plan.schema(), + planner_context, + )?; + + if value_column.is_empty() { + return plan_err!("PIVOT value column is required"); + } + + let column_name = value_column.last().unwrap().value.clone(); + let pivot_column = Column::new(None::<&str>, column_name); + + let default_on_null_expr = default_on_null + .map(|expr| { + self.sql_expr_to_logical_expr( + expr, + input_plan.schema(), // Default expression should be context-independent or use input schema + planner_context, + ) + }) + .transpose()?; + + match value_source { + sqlparser::ast::PivotValueSource::List(exprs) => { + let pivot_values = exprs + .iter() + .map(|expr| { + let logical_expr = self.sql_expr_to_logical_expr( + expr.expr.clone(), + input_plan.schema(), + planner_context, + )?; + + match logical_expr { + Expr::Literal(scalar) => Ok(scalar), + _ => plan_err!("PIVOT values must be literals"), + } + }) + .collect::>>()?; + + let input_arc = Arc::new(input_plan); + + let pivot_plan = datafusion_expr::Pivot::try_new( + input_arc, + agg_expr, + pivot_column, + pivot_values, + default_on_null_expr.clone(), + )?; + + (LogicalPlan::Pivot(pivot_plan), alias) + } + sqlparser::ast::PivotValueSource::Any(order_by) => { + let input_arc = Arc::new(input_plan); + + let mut subquery_builder = + LogicalPlanBuilder::from(input_arc.as_ref().clone()) + .project(vec![Expr::Column(pivot_column.clone())])? + .distinct()?; + + if !order_by.is_empty() { + let sort_exprs = order_by + .iter() + .map(|item| { + let input_schema = subquery_builder.schema(); + + let expr = self.sql_expr_to_logical_expr( + item.expr.clone(), + input_schema, + planner_context, + ); + + expr.map(|e| { + e.sort( + item.options.asc.unwrap_or(true), + item.options.nulls_first.unwrap_or(false), + ) + }) + }) + .collect::>>()?; + + subquery_builder = subquery_builder.sort(sort_exprs)?; + } + + let subquery_plan = subquery_builder.build()?; + + let pivot_plan = datafusion_expr::Pivot::try_new_with_subquery( + input_arc, + agg_expr, + pivot_column, + Arc::new(subquery_plan), + default_on_null_expr.clone(), + )?; + + (LogicalPlan::Pivot(pivot_plan), alias) + } + sqlparser::ast::PivotValueSource::Subquery(subquery) => { + let subquery_plan = + self.query_to_plan(*subquery.clone(), planner_context)?; + + let input_arc = Arc::new(input_plan); + + let pivot_plan = datafusion_expr::Pivot::try_new_with_subquery( + input_arc, + agg_expr, + pivot_column, + Arc::new(subquery_plan), + default_on_null_expr.clone(), + )?; + + (LogicalPlan::Pivot(pivot_plan), alias) + } + } + } + // @todo: Support TableFactory::TableFunction _ => { return not_impl_err!( "Unsupported ast node {relation:?} in create_relation" diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 77edef4f8602..9e1bda9cf807 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -118,6 +118,7 @@ impl Unparser<'_> { LogicalPlan::Extension(extension) => { self.extension_to_statement(extension.node.as_ref()) } + LogicalPlan::Pivot(_) => not_impl_err!("Unsupported plan Pivot: {plan:?}"), LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) | LogicalPlan::Ddl(_) diff --git a/datafusion/sqllogictest/test_files/pivot.slt b/datafusion/sqllogictest/test_files/pivot.slt new file mode 100644 index 000000000000..d076e0704bac --- /dev/null +++ b/datafusion/sqllogictest/test_files/pivot.slt @@ -0,0 +1,431 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +####### +# Setup test data table +####### +statement ok +CREATE TABLE quarterly_sales( + empid INT, + amount INT, + quarter TEXT) + AS SELECT * FROM VALUES + (1, 10000, '2023_Q1'), + (1, 400, '2023_Q1'), + (2, 4500, '2023_Q1'), + (2, 35000, '2023_Q1'), + (1, 5000, '2023_Q2'), + (1, 3000, '2023_Q2'), + (2, 200, '2023_Q2'), + (2, 90500, '2023_Q2'), + (1, 6000, '2023_Q3'), + (1, 5000, '2023_Q3'), + (2, 2500, '2023_Q3'), + (2, 9500, '2023_Q3'), + (3, 2700, '2023_Q3'), + (1, 8000, '2023_Q4'), + (1, 10000, '2023_Q4'), + (2, 800, '2023_Q4'), + (2, 4500, '2023_Q4'), + (3, 2700, '2023_Q4'), + (3, 16000, '2023_Q4'), + (3, 10200, '2023_Q4'); + +query IIIII +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q3', '2023_Q4')) +ORDER BY empid; +---- +1 10400 8000 11000 18000 +2 39500 90700 12000 5300 +3 NULL NULL 2700 28900 + +# PIVOT with NULL handling +query III +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) + FOR quarter IN ('2023_Q1', '2023_Q2') + DEFAULT ON NULL (1001)) +ORDER BY empid; +---- +1 10400 8000 +2 39500 90700 +3 1001 1001 + +# PIVOT with automatic detection of all distinct column values using ANY +query TIII +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) FOR empid IN (ANY ORDER BY empid)) +ORDER BY quarter; +---- +2023_Q1 10400 39500 NULL +2023_Q2 8000 90700 NULL +2023_Q3 11000 12000 2700 +2023_Q4 18000 5300 28900 + +# PIVOT with ANY that includes output column reordering +query IIIII +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) FOR quarter IN (ANY ORDER BY quarter DESC)) +ORDER BY empid; +---- +1 18000 11000 8000 10400 +2 5300 12000 90700 39500 +3 28900 2700 NULL NULL + +# PIVOT with a subquery to specify the values +query III +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) + FOR quarter IN ( + SELECT DISTINCT quarter FROM quarterly_sales WHERE quarter LIKE '%Q1' OR quarter LIKE '%Q3' + )) +ORDER BY empid; +---- +1 10400 11000 +2 39500 12000 +3 NULL 2700 + +query IIIII +WITH sales_without_discount AS + (SELECT empid, amount, quarter FROM quarterly_sales) +SELECT * +FROM sales_without_discount +PIVOT(SUM(amount) FOR quarter IN (ANY ORDER BY quarter)) +ORDER BY empid; +---- +1 10400 8000 11000 18000 +2 39500 90700 12000 5300 +3 NULL NULL 2700 28900 + + +# Non-existent column in the FOR clause +query error DataFusion error: Schema error: No field named non_existent_column\. Valid fields are quarterly_sales\.empid, quarterly_sales\.amount, quarterly_sales\.quarter\. +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) FOR non_existent_column IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +# Non-existent column in the aggregate function +query error DataFusion error: Schema error: No field named non_existent_column\. Valid fields are quarterly_sales\.empid, quarterly_sales\.amount, quarterly_sales\.quarter\. +SELECT * +FROM quarterly_sales +PIVOT(SUM(non_existent_column) FOR quarter IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +# Trying to use non-aggregate function +query error DataFusion error: Error during planning: Unsupported aggregate expression should always be AggregateFunction +SELECT * +FROM quarterly_sales +PIVOT(ABS(amount) FOR quarter IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +# Invalid subquery in the IN list - multiple columns +query error DataFusion error: Error during planning: Pivot subquery must return a single column +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) + FOR quarter IN (SELECT quarter, empid FROM quarterly_sales LIMIT 2)) +ORDER BY empid; + +# Invalid DEFAULT ON NULL value (dependent on pivot/aggregation columns) +query error DataFusion error: Schema error: No field named quarterly_sales\.amount\. Valid fields are quarterly_sales\.empid, "2023_Q1", "2023_Q2"\. +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) + FOR quarter IN ('2023_Q1', '2023_Q2') + DEFAULT ON NULL (amount)) +ORDER BY empid; + +# PIVOT after a PIVOT +query error DataFusion error: Schema error: No field named empid\. Valid fields are "2023_Q2", "0", "10000", "20000", "2023_Q2", "0", "10000", "20000"\. +SELECT * +FROM ( + SELECT * + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2')) +) +PIVOT(AVG(empid) FOR "2023_Q1" IN (0, 10000, 20000)) +ORDER BY empid; + +# PIVOT with window functions in the pivot expression +query error DataFusion error: Schema error: No field named empid\. Valid fields are "2023_Q1", "2023_Q2", "2023_Q1", "2023_Q2"\. +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount) OVER (PARTITION BY empid) FOR quarter IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +# PIVOT with ORDER BY in the aggregate function +query error DataFusion error: Schema error: No field named empid\. Valid fields are "2023_Q1", "2023_Q2", "2023_Q1", "2023_Q2"\. +SELECT * +FROM quarterly_sales +PIVOT(SUM(amount ORDER BY empid) FOR quarter IN ('2023_Q1', '2023_Q2')) +ORDER BY empid; + +statement ok +CREATE TABLE employees( + empid INT, + name TEXT, + department TEXT, + hire_date DATE) + AS SELECT * FROM VALUES + (1, 'Alice', 'Sales', '2020-01-15'), + (2, 'Bob', 'Sales', '2021-03-10'), + (3, 'Charlie', 'Marketing', '2022-06-22'), + (4, 'David', 'Engineering', '2019-11-08'), + (5, 'Eve', 'Marketing', '2023-02-01'); + +statement ok +CREATE TABLE product_sales( + product_id INT, + category TEXT, + sale_amount INT, + sale_date DATE) + AS SELECT * FROM VALUES + (101, 'Electronics', 1200, '2023-01-10'), + (102, 'Clothing', 500, '2023-01-15'), + (103, 'Home', 800, '2023-01-20'), + (104, 'Electronics', 1500, '2023-02-05'), + (105, 'Clothing', 600, '2023-02-12'), + (106, 'Home', 900, '2023-02-25'), + (107, 'Electronics', 2000, '2023-03-08'), + (108, 'Clothing', 700, '2023-03-15'), + (109, 'Home', 1100, '2023-03-22'), + (110, 'Electronics', 1800, '2023-04-05'), + (111, 'Clothing', 550, '2023-04-14'), + (112, 'Home', 950, '2023-04-28'); + +query TIIIII +SELECT e.name, s.* +FROM employees e +JOIN ( + SELECT empid, "2023_Q1", "2023_Q2", "2023_Q3", "2023_Q4" + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q3', '2023_Q4')) +) s ON e.empid = s.empid +ORDER BY e.empid; +---- +Alice 1 10400 8000 11000 18000 +Bob 2 39500 90700 12000 5300 +Charlie 3 NULL NULL 2700 28900 + +# PIVOT with filtered subquery +query III +SELECT * +FROM ( + SELECT empid, amount, quarter + FROM quarterly_sales + WHERE amount > 5000 +) +PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) +ORDER BY empid; +---- +1 10000 18000 +2 35000 NULL +3 NULL 26200 + +query TII +SELECT + category, + "Q1", + "Q2" +FROM ( + SELECT + category, + CASE + WHEN EXTRACT(QUARTER FROM sale_date) = 1 THEN 'Q1' + WHEN EXTRACT(QUARTER FROM sale_date) = 2 THEN 'Q2' + END AS quarter, + sale_amount + FROM product_sales +) +PIVOT( + SUM(sale_amount) AS total + FOR quarter IN ('Q1' AS "Q1", 'Q2' AS "Q2") +) +ORDER BY category; +---- +Clothing 1800 550 +Electronics 4700 1800 +Home 2800 950 + +# PIVOT with arithmetic operations on the aggregated values +query TIIIR +SELECT + e.name, + p."2023_Q1", + p."2023_Q4", + p."2023_Q4" - p."2023_Q1" AS q4_minus_q1, + CASE + WHEN p."2023_Q1" = 0 THEN NULL + ELSE (p."2023_Q4" - p."2023_Q1")*100.0/p."2023_Q1" + END AS percent_change +FROM employees e +LEFT JOIN ( + SELECT empid, "2023_Q1", "2023_Q4" + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) +) p ON e.empid = p.empid +ORDER BY e.name; +---- +Alice 10400 18000 7600 73.076923076923 +Bob 39500 5300 -34200 -86.582278481013 +Charlie NULL 28900 NULL NULL +David NULL NULL NULL NULL +Eve NULL NULL NULL NULL + +# PIVOT with HAVING clause +query TII +WITH dept_pivot AS ( + SELECT + e.department, + q."2023_Q1", + q."2023_Q4" + FROM employees e + LEFT JOIN ( + SELECT empid, "2023_Q1", "2023_Q4" + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) + ) q ON e.empid = q.empid +) +SELECT department, SUM("2023_Q1") as q1_total, SUM("2023_Q4") as q4_total +FROM dept_pivot +GROUP BY department +HAVING SUM("2023_Q4") > 0 +ORDER BY department; +---- +Marketing NULL 28900 +Sales 49900 23300 + +# PIVOT with CASE expressions for custom grouping +query III +SELECT * +FROM ( + SELECT + empid, + amount, + CASE + WHEN quarter IN ('2023_Q1', '2023_Q2') THEN 'H1' + WHEN quarter IN ('2023_Q3', '2023_Q4') THEN 'H2' + END AS half_year + FROM quarterly_sales +) +PIVOT(SUM(amount) FOR half_year IN ('H1', 'H2')) +ORDER BY empid; +---- +1 18400 29000 +2 130200 17300 +3 NULL 31600 + +# PIVOT WITH UNION +query TIRRR +SELECT 'Average sale amount' AS aggregate, * + FROM quarterly_sales + PIVOT(AVG(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +UNION +SELECT 'Highest value sale' AS aggregate, * + FROM quarterly_sales + PIVOT(MAX(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +UNION +SELECT 'Lowest value sale' AS aggregate, * + FROM quarterly_sales + PIVOT(MIN(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +UNION +SELECT 'Number of sales' AS aggregate, * + FROM quarterly_sales + PIVOT(COUNT(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +UNION +SELECT 'Total amount' AS aggregate, * + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q4')) +ORDER BY aggregate, empid; +---- +Average sale amount 1 5200 4000 9000 +Average sale amount 2 19750 45350 2650 +Average sale amount 3 NULL NULL 9633.333333333334 +Highest value sale 1 10000 5000 10000 +Highest value sale 2 35000 90500 4500 +Highest value sale 3 NULL NULL 16000 +Lowest value sale 1 400 3000 8000 +Lowest value sale 2 4500 200 800 +Lowest value sale 3 NULL NULL 2700 +Number of sales 1 2 2 2 +Number of sales 2 2 2 2 +Number of sales 3 0 0 3 +Total amount 1 10400 8000 18000 +Total amount 2 39500 90700 5300 +Total amount 3 NULL NULL 28900 + + +query TIIII +WITH sales_sum AS ( + SELECT + empid, + "2023_Q1" AS "Q1_Sales", + "2023_Q4" AS "Q4_Sales" + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) +), +sales_count AS ( + SELECT + empid, + "2023_Q1" AS "Q1_Count", + "2023_Q4" AS "Q4_Count" + FROM quarterly_sales + PIVOT(COUNT(amount) FOR quarter IN ('2023_Q1', '2023_Q4')) +), +combined_sales AS ( + SELECT + ss.empid, + ss."Q1_Sales", + ss."Q4_Sales", + sc."Q1_Count", + sc."Q4_Count" + FROM sales_sum ss + JOIN sales_count sc ON ss.empid = sc.empid +) +SELECT dept.* +FROM ( + SELECT + e.department, + s."Q1_Sales", + s."Q4_Sales", + s."Q1_Count", + s."Q4_Count" + FROM employees e JOIN combined_sales s ON e.empid = s.empid +) dept +WHERE dept.department IN ('Sales', 'Marketing') +ORDER BY dept.department +---- +Marketing NULL 28900 0 3 +Sales 39500 5300 2 2 +Sales 10400 18000 2 2 + +# Test PIVOT subquery with projection +query TIRRRR +SELECT 'Average sale amount' AS aggregate, * + FROM quarterly_sales + PIVOT(AVG(amount) FOR quarter IN (ANY ORDER BY quarter)) ORDER by empid +---- +Average sale amount 1 5200 4000 5500 9000 +Average sale amount 2 19750 45350 6000 2650 +Average sale amount 3 NULL NULL 2700 9633.333333333334 diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 07bf0cb96aa3..099f418afcae 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -531,6 +531,7 @@ pub fn to_substrait_rel( LogicalPlan::RecursiveQuery(plan) => { not_impl_err!("Unsupported plan type: {plan:?}")? } + LogicalPlan::Pivot(plan) => not_impl_err!("Unsupported plan type: {plan:?}")?, } }