diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs
index 633d933eb845..c725b49bc303 100644
--- a/datafusion/sql/src/query.rs
+++ b/datafusion/sql/src/query.rs
@@ -20,17 +20,18 @@
 use std::sync::Arc;
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use crate::stack::StackGuard;
+use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
 use datafusion_common::{not_impl_err, Constraints, DFSchema, Result};
-use datafusion_expr::expr::{Sort, WildcardOptions};
+use datafusion_expr::expr::{AggregateFunction, Sort, WildcardOptions};
 use datafusion_expr::select_expr::SelectExpr;
 use datafusion_expr::{
     CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder,
 };
 use sqlparser::ast::{
     Expr as SQLExpr, ExprWithAliasAndOrderBy, Ident, LimitClause, Offset, OffsetRows,
-    OrderBy, OrderByExpr, OrderByKind, PipeOperator, Query, SelectInto, SetExpr,
-    SetOperator, SetQuantifier, TableAlias,
+    OrderBy, OrderByExpr, OrderByKind, PipeOperator, PivotValueSource, Query, SelectInto,
+    SetExpr, SetOperator, SetQuantifier, TableAlias,
 };
 use sqlparser::tokenizer::Span;
 
@@ -194,6 +195,22 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 group_by_expr,
                 planner_context,
             ),
+            PipeOperator::Join(join) => {
+                self.parse_relation_join(plan, join, planner_context)
+            }
+            PipeOperator::Pivot {
+                aggregate_functions,
+                value_column,
+                value_source,
+                alias,
+            } => self.pipe_operator_pivot(
+                plan,
+                aggregate_functions,
+                value_column,
+                value_source,
+                alias,
+                planner_context,
+            ),
             x => not_impl_err!("`{x}` pipe operator is not supported yet"),
         }
 
@@ -336,6 +353,167 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
             .build()
     }
 
+    /// Handle the `PIVOT` pipe operator.
+    ///
+    /// `input |> PIVOT(agg(..) [AS name], .. FOR pivot_col IN (v1 [AS n1], ..))`
+    /// is planned as a single aggregation over `plan`:
+    ///
+    /// * every input column that is neither the pivot column nor referenced by
+    ///   one of the aggregate calls becomes a grouping key;
+    /// * for every pivot value (outer loop) and every aggregate (inner loop) a
+    ///   copy of the aggregate is emitted with `pivot_col = value` AND-ed into
+    ///   its `FILTER` clause and aliased `{aggregate name}_{pivot value name}`.
+    ///
+    /// The result is wrapped in a table alias when `alias` is given.
+    ///
+    /// # Errors
+    ///
+    /// Returns a "not implemented" error when `value_source` is not a
+    /// `PivotValueSource::List`, when `value_column` does not hold exactly one
+    /// identifier, when that identifier does not plan to a plain column, or when
+    /// an entry of `aggregate_functions` is not an (optionally aliased)
+    /// aggregate function call.
+    fn pipe_operator_pivot(
+        &self,
+        plan: LogicalPlan,
+        aggregate_functions: Vec<sqlparser::ast::ExprWithAlias>,
+        value_column: Vec<Ident>,
+        value_source: PivotValueSource,
+        alias: Option<Ident>,
+        planner_context: &mut PlannerContext,
+    ) -> Result<LogicalPlan> {
+        let PivotValueSource::List(pivot_values) = value_source else {
+            return not_impl_err!(
+                "Only static pivot value lists are supported currently"
+            );
+        };
+
+        let [pivot_ident] = value_column.as_slice() else {
+            return not_impl_err!("Multi-column pivot is not supported yet");
+        };
+
+        let input_schema = plan.schema();
+
+        // Resolve the pivot column through the regular identifier path so that
+        // it is normalized and qualified against the input schema. Building it
+        // with `col(name)` would re-parse the raw identifier text (lower-casing
+        // it, splitting it on dots), while the group-by filter below compared
+        // the raw text with schema field names; the two could disagree.
+        let pivot_col_expr = self.sql_to_expr(
+            SQLExpr::Identifier(pivot_ident.clone()),
+            input_schema,
+            planner_context,
+        )?;
+        let Expr::Column(pivot_column) = &pivot_col_expr else {
+            return not_impl_err!("Only plain columns are supported as PIVOT column");
+        };
+
+        // Plan every aggregate exactly once. While the planned expression is
+        // still whole, record the columns it reads: these "measure" columns are
+        // consumed by the aggregate and must not become grouping keys.
+        let mut measure_columns = std::collections::HashSet::new();
+        let mut aggregates = Vec::with_capacity(aggregate_functions.len());
+        for sql_agg in aggregate_functions {
+            let expr =
+                self.sql_to_expr_with_alias(sql_agg, input_schema, planner_context)?;
+            expr.apply(|e| {
+                if let Expr::Column(column) = e {
+                    measure_columns.insert(column.name.clone());
+                }
+                Ok(TreeNodeRecursion::Continue)
+            })?;
+            aggregates.push(Self::split_pivot_aggregate(expr)?);
+        }
+
+        // Group by all remaining input columns. Taking them from the schema
+        // keeps their qualifiers, so same-named columns of different relations
+        // (for example after a JOIN) are not ambiguous.
+        let group_by_cols: Vec<Expr> = input_schema
+            .columns()
+            .into_iter()
+            .filter(|column| {
+                column.name != pivot_column.name
+                    && !measure_columns.contains(&column.name)
+            })
+            .map(Expr::Column)
+            .collect();
+
+        // One output column per (pivot value, aggregate) pair. Pivot values
+        // drive the outer loop to get the desired column order.
+        let mut aggr_exprs = Vec::with_capacity(pivot_values.len() * aggregates.len());
+        for pivot_value in pivot_values {
+            let pivot_value_name = match &pivot_value.alias {
+                Some(value_alias) => value_alias.value.clone(),
+                // Use the pivot value as column name, stripping quotes
+                None => pivot_value.expr.to_string().trim_matches('\'').to_string(),
+            };
+            let pivot_value_expr =
+                self.sql_to_expr(pivot_value.expr, input_schema, planner_context)?;
+            let pivot_filter = pivot_col_expr.clone().eq(pivot_value_expr);
+
+            for (alias_name, aggregate) in &aggregates {
+                let mut agg_fn = aggregate.clone();
+                // Keep a FILTER already present on the aggregate and narrow it
+                let filter = match agg_fn.params.filter.take() {
+                    Some(existing) => (*existing).and(pivot_filter.clone()),
+                    None => pivot_filter.clone(),
+                };
+                agg_fn.params.filter = Some(Box::new(filter));
+
+                let agg_expr = Expr::AggregateFunction(agg_fn);
+                // Only derive a name from the expression when no alias exists
+                let agg_name = match alias_name {
+                    Some(name) => name.clone(),
+                    None => agg_expr.name_for_alias()?,
+                };
+                let column_name = format!("{agg_name}_{pivot_value_name}");
+                aggr_exprs.push(agg_expr.alias(column_name));
+            }
+        }
+
+        let pivoted = LogicalPlanBuilder::from(plan)
+            .aggregate(group_by_cols, aggr_exprs)?
+            .build()?;
+
+        // Apply table alias if provided
+        match alias {
+            Some(name) => self.apply_table_alias(
+                pivoted,
+                TableAlias {
+                    name,
+                    columns: vec![],
+                },
+            ),
+            None => Ok(pivoted),
+        }
+    }
+
+    /// Split a planned PIVOT aggregate into the name it should be reported
+    /// under (its outermost alias, if any) and the aggregate call itself.
+    fn split_pivot_aggregate(
+        expr: Expr,
+    ) -> Result<(Option<String>, AggregateFunction)> {
+        let mut alias_name = None;
+        let mut current = expr;
+        loop {
+            current = match current {
+                Expr::Alias(alias) => {
+                    // The outermost alias is the user-visible name; aliases
+                    // nested below it are skipped
+                    if alias_name.is_none() {
+                        alias_name = Some(alias.name);
+                    }
+                    *alias.expr
+                }
+                Expr::AggregateFunction(agg_fn) => return Ok((alias_name, agg_fn)),
+                _ if alias_name.is_some() => {
+                    return not_impl_err!("Only function expressions are supported in PIVOT aggregate functions");
+                }
+                _ => return not_impl_err!("Expected aggregate function"),
+            };
+        }
+    }
+
     /// Wrap the logical plan in a `SelectInto`
     fn select_into(
         &self,
diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs
index 10491963e3ce..f8603e29bdcf 100644
--- a/datafusion/sql/src/relation/join.rs
+++ b/datafusion/sql/src/relation/join.rs
@@ -43,7 +43,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         Ok(left)
     }
 
-    fn parse_relation_join(
+    pub(crate) fn parse_relation_join(
         &self,
         left: LogicalPlan,
         join: Join,
diff --git a/datafusion/sqllogictest/test_files/pipe_operator.slt b/datafusion/sqllogictest/test_files/pipe_operator.slt
index 57d1fc064201..406d8b9b324b 100644
--- a/datafusion/sqllogictest/test_files/pipe_operator.slt
+++ b/datafusion/sqllogictest/test_files/pipe_operator.slt
@@ -177,3 +177,120 @@ query TII rowsort
 |> WHERE num_items > 1;
 ----
 apples 2 9
+
+# JOIN pipe
+query TII
+(
+  SELECT 'apples' AS item, 2 AS sales
+  UNION ALL
+  SELECT 'bananas' AS item, 5 AS sales
+)
+|> AS produce_sales
+|> LEFT JOIN
+     (
+       SELECT "apples" AS item, 123 AS id
+     ) AS produce_data
+   ON produce_sales.item = produce_data.item
+|> SELECT produce_sales.item, sales, id;
+----
+apples 2 123
+bananas 5 NULL
+
+# PIVOT pipe
+
+statement ok
+CREATE TABLE pipe_test(
+    product VARCHAR,
+    sales INT,
+    quarter VARCHAR,
+    year INT
+) AS VALUES
+    ('Kale', 51, 'Q1', 2020),
+    ('Kale', 23, 'Q2', 2020),
+    ('Kale', 45, 'Q3', 2020),
+    ('Kale', 3, 'Q4', 2020),
+    ('Kale', 70, 'Q1', 2021),
+    ('Kale', 85, 'Q2', 2021),
+    ('Apple', 77, 'Q1', 2020),
+    ('Apple', 0, 'Q2', 2020),
+    ('Apple', 1, 'Q1', 2021)
+;
+
+query TIIIII rowsort
+SELECT * FROM pipe_test
+|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4'));
+----
+Apple 2020 77 0 NULL NULL
+Apple 2021 1 NULL NULL NULL
+Kale 2020 51 23 45 3
+Kale 2021 70 85 NULL NULL
+
+query TIIII rowsort
+SELECT
* FROM pipe_test +|> select product, sales, quarter +|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4')); +---- +Apple 78 0 NULL NULL +Kale 121 108 45 3 + +query TIII rowsort +SELECT * FROM pipe_test +|> select product, sales, quarter +|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2', 'Q3')); +---- +Apple 78 0 NULL +Kale 121 108 45 + +query TIIII rowsort +SELECT * FROM pipe_test +|> select product, sales, quarter +|> PIVOT(SUM(sales) as total_sales, count(*) as num_records FOR quarter IN ('Q1', 'Q2')); +---- +Apple 78 2 0 1 +Kale 121 2 108 2 + + +query TT +EXPLAIN SELECT * FROM pipe_test +|> select product, sales, quarter +|> PIVOT(SUM(sales) as total_sales, count(*) as num_records FOR quarter IN ('Q1', 'Q2')); +---- +logical_plan +01)Aggregate: groupBy=[[pipe_test.product]], aggr=[[sum(__common_expr_1) FILTER (WHERE __common_expr_2) AS total_sales_Q1, count(Int64(1)) FILTER (WHERE __common_expr_2) AS num_records_Q1, sum(__common_expr_1) FILTER (WHERE __common_expr_3) AS total_sales_Q2, count(Int64(1)) FILTER (WHERE __common_expr_3) AS num_records_Q2]] +02)--Projection: CAST(pipe_test.sales AS Int64) AS __common_expr_1, pipe_test.quarter = Utf8View("Q1") AS __common_expr_2, pipe_test.quarter = Utf8View("Q2") AS __common_expr_3, pipe_test.product +03)----TableScan: pipe_test projection=[product, sales, quarter] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[product@0 as product], aggr=[total_sales_Q1, num_records_Q1, total_sales_Q2, num_records_Q2] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----RepartitionExec: partitioning=Hash([product@0], 4), input_partitions=4 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[product@3 as product], aggr=[total_sales_Q1, num_records_Q1, total_sales_Q2, num_records_Q2] +06)----------ProjectionExec: expr=[CAST(sales@1 AS Int64) as __common_expr_1, quarter@2 = Q1 as __common_expr_2, quarter@2 = Q2 as __common_expr_3, product@0 
as product] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + +# With explicit pivot value alias +query TT +EXPLAIN SELECT * FROM pipe_test +|> select product, sales, quarter +|> PIVOT(SUM(sales) as total_sales, count(*) as num_records FOR quarter IN ('Q1' as q1, 'Q2')); +---- +logical_plan +01)Aggregate: groupBy=[[pipe_test.product]], aggr=[[sum(__common_expr_1) FILTER (WHERE __common_expr_2) AS total_sales_q1, count(Int64(1)) FILTER (WHERE __common_expr_2) AS num_records_q1, sum(__common_expr_1) FILTER (WHERE __common_expr_3) AS total_sales_Q2, count(Int64(1)) FILTER (WHERE __common_expr_3) AS num_records_Q2]] +02)--Projection: CAST(pipe_test.sales AS Int64) AS __common_expr_1, pipe_test.quarter = Utf8View("Q1") AS __common_expr_2, pipe_test.quarter = Utf8View("Q2") AS __common_expr_3, pipe_test.product +03)----TableScan: pipe_test projection=[product, sales, quarter] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[product@0 as product], aggr=[total_sales_q1, num_records_q1, total_sales_Q2, num_records_Q2] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----RepartitionExec: partitioning=Hash([product@0], 4), input_partitions=4 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[product@3 as product], aggr=[total_sales_q1, num_records_q1, total_sales_Q2, num_records_Q2] +06)----------ProjectionExec: expr=[CAST(sales@1 AS Int64) as __common_expr_1, quarter@2 = Q1 as __common_expr_2, quarter@2 = Q2 as __common_expr_3, product@0 as product] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] + +# Aggregation functions with multiple parameters +query TTT rowsort +SELECT product, sales, quarter FROM pipe_test +|> PIVOT(string_agg(sales, '_' order by sales) as agg FOR quarter IN ('Q1', 'Q2')); +---- +Apple 1_77 0 +Kale 51_70 23_85 diff --git a/docs/source/user-guide/sql/select.md b/docs/source/user-guide/sql/select.md index 
8c1bc401d3aa..e854d67cc1f6 100644
--- a/docs/source/user-guide/sql/select.md
+++ b/docs/source/user-guide/sql/select.md
@@ -350,6 +350,8 @@ DataFusion currently supports the following pipe operators:
 - [INTERSECT](#pipe_intersect)
 - [EXCEPT](#pipe_except)
 - [AGGREGATE](#pipe_aggregate)
+- [PIVOT](#pipe_pivot)
+- [JOIN](#pipe_join)
 
 (pipe_where)=
 
@@ -514,3 +516,61 @@ select * from range(0,3)
 | 3     |
 +-------+
 ```
+
+(pipe_pivot)=
+
+### PIVOT
+
+Rotates rows into columns: for every value listed after `IN`, each aggregate is
+computed over the rows whose pivot column equals that value, and all remaining
+columns are used as grouping keys. Output columns are named
+`<aggregate alias>_<pivot value or its alias>`.
+
+```sql
+> (
+  SELECT 'kale' AS product, 51 AS sales, 'Q1' AS quarter
+  UNION ALL
+  SELECT 'kale' AS product, 4 AS sales, 'Q1' AS quarter
+  UNION ALL
+  SELECT 'kale' AS product, 45 AS sales, 'Q2' AS quarter
+  UNION ALL
+  SELECT 'apple' AS product, 8 AS sales, 'Q1' AS quarter
+  UNION ALL
+  SELECT 'apple' AS product, 10 AS sales, 'Q2' AS quarter
+)
+|> PIVOT(SUM(sales) AS total_sales FOR quarter IN ('Q1', 'Q2'));
++---------+----------------+----------------+
+| product | total_sales_Q1 | total_sales_Q2 |
++---------+----------------+----------------+
+| apple   | 8              | 10             |
+| kale    | 55             | 45             |
++---------+----------------+----------------+
+```
+
+(pipe_join)=
+
+### JOIN
+
+Joins the piped input with another relation. Name the input with `|> AS` first
+when the join condition has to refer to its columns.
+
+```sql
+> (
+  SELECT 'apples' AS item, 2 AS sales
+  UNION ALL
+  SELECT 'bananas' AS item, 5 AS sales
+)
+|> AS produce_sales
+|> LEFT JOIN
+     (
+       SELECT 'apples' AS item, 123 AS id
+     ) AS produce_data
+   ON produce_sales.item = produce_data.item
+|> SELECT produce_sales.item, sales, id;
++---------+-------+------+
+| item    | sales | id   |
++---------+-------+------+
+| apples  | 2     | 123  |
+| bananas | 5     | NULL |
++---------+-------+------+
+```