Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 152 additions & 4 deletions datafusion/sql/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,18 @@ use std::sync::Arc;
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};

use crate::stack::StackGuard;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{not_impl_err, Constraints, DFSchema, Result};
use datafusion_expr::expr::{Sort, WildcardOptions};
use datafusion_expr::expr::{AggregateFunction, Sort, WildcardOptions};

use datafusion_expr::select_expr::SelectExpr;
use datafusion_expr::{
CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder,
col, CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder,
};
use sqlparser::ast::{
Expr as SQLExpr, ExprWithAliasAndOrderBy, Ident, LimitClause, Offset, OffsetRows,
OrderBy, OrderByExpr, OrderByKind, PipeOperator, Query, SelectInto, SetExpr,
SetOperator, SetQuantifier, TableAlias,
OrderBy, OrderByExpr, OrderByKind, PipeOperator, PivotValueSource, Query, SelectInto,
SetExpr, SetOperator, SetQuantifier, TableAlias,
};
use sqlparser::tokenizer::Span;

Expand Down Expand Up @@ -194,6 +195,22 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
group_by_expr,
planner_context,
),
PipeOperator::Join(join) => {
self.parse_relation_join(plan, join, planner_context)
}
PipeOperator::Pivot {
aggregate_functions,
value_column,
value_source,
alias,
} => self.pipe_operator_pivot(
plan,
aggregate_functions,
value_column,
value_source,
alias,
planner_context,
),

x => not_impl_err!("`{x}` pipe operator is not supported yet"),
}
Expand Down Expand Up @@ -336,6 +353,137 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.build()
}

/// Handle PIVOT pipe operator
fn pipe_operator_pivot(
&self,
plan: LogicalPlan,
aggregate_functions: Vec<sqlparser::ast::ExprWithAlias>,
value_column: Vec<Ident>,
value_source: PivotValueSource,
alias: Option<Ident>,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
let pivot_values = if let PivotValueSource::List(values) = value_source {
values
} else {
return not_impl_err!(
"Only static pivot value lists are supported currently"
);
};

if value_column.len() != 1 {
return not_impl_err!("Multi-column pivot is not supported yet");
}
let pivot_col_name = &value_column[0].value;
let pivot_col_expr = col(pivot_col_name);

let input_schema = plan.schema();

// Convert sql to DF exprs
let aggregate_functions = aggregate_functions
.into_iter()
.map(|f| self.sql_to_expr_with_alias(f, input_schema, planner_context))
.collect::<Result<Vec<_>, _>>()?;

// Convert aggregate functions to logical expressions to extract measure columns
let mut measure_columns = std::collections::HashSet::new();
for agg_func_with_alias in &aggregate_functions {
agg_func_with_alias.apply(|e| {
if let Expr::Column(col) = e {
measure_columns.insert(col.name.clone());
};
Ok(TreeNodeRecursion::Continue)
})?;
}

// Get all column names from the input plan to determine group-by columns.
// Add all columns except the pivot column and measure columns to group by
let mut group_by_cols = Vec::new();
for field in input_schema.fields() {
let col_name = field.name();
if col_name != pivot_col_name && !measure_columns.contains(col_name) {
group_by_cols.push(col(col_name));
}
}

let mut aggr_exprs = Vec::new();

// For each pivot value and aggregate function combination, create a conditional aggregate
// Process pivot values first to get the desired column order
for pivot_value in pivot_values {
let pivot_value_expr = self.sql_to_expr(
pivot_value.expr.clone(),
input_schema,
planner_context,
)?;
for agg_func_with_alias in &aggregate_functions {
let (alias_name, mut agg_fn) = match agg_func_with_alias {
Expr::Alias(alias) => match *alias.expr.clone() {
Expr::Alias(inner_alias) => {
let Expr::AggregateFunction(
agg_func @ AggregateFunction { .. },
) = *inner_alias.expr.clone()
else {
return not_impl_err!("Only function expressions are supported in PIVOT aggregate functions");
};
(Some(alias.name.clone()), agg_func)
}
Expr::AggregateFunction(agg_func @ AggregateFunction { .. }) => {
(Some(alias.name.clone()), agg_func)
}
_ => {
return not_impl_err!("Only function expressions are supported in PIVOT aggregate functions");
}
},
Expr::AggregateFunction(agg_func) => (None, agg_func.clone()),
_ => {
return not_impl_err!("Expected aggregate function");
}
};

let new_filter = pivot_col_expr.clone().eq(pivot_value_expr.clone());
if let Some(existing_filter) = agg_fn.params.filter {
agg_fn.params.filter =
Some(Box::new(existing_filter.and(new_filter)));
} else {
agg_fn.params.filter = Some(Box::new(new_filter));
}

let agg_expr = Expr::AggregateFunction(agg_fn);
let aggr_func_alias = alias_name.unwrap_or(agg_expr.name_for_alias()?);

let pivot_value_name = if let Some(alias) = &pivot_value.alias {
alias.value.clone()
} else {
// Use the pivot value as column name, stripping quotes
pivot_value.expr.to_string().trim_matches('\'').to_string()
};

aggr_exprs.push(
// Give unique name based on pivot column name
agg_expr.alias(format!("{aggr_func_alias}_{pivot_value_name}")),
);
}
}

let result_plan = LogicalPlanBuilder::from(plan)
.aggregate(group_by_cols, aggr_exprs)?
.build()?;

// Apply table alias if provided
if let Some(table_alias) = alias {
self.apply_table_alias(
result_plan,
TableAlias {
name: table_alias,
columns: vec![],
},
)
} else {
Ok(result_plan)
}
}

/// Wrap the logical plan in a `SelectInto`
fn select_into(
&self,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/src/relation/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
Ok(left)
}

fn parse_relation_join(
pub(crate) fn parse_relation_join(
&self,
left: LogicalPlan,
join: Join,
Expand Down
117 changes: 117 additions & 0 deletions datafusion/sqllogictest/test_files/pipe_operator.slt
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,120 @@ query TII rowsort
|> WHERE num_items > 1;
----
apples 2 9

# JOIN pipe
query TII
(
SELECT 'apples' AS item, 2 AS sales
UNION ALL
SELECT 'bananas' AS item, 5 AS sales
)
|> AS produce_sales
|> LEFT JOIN
(
SELECT "apples" AS item, 123 AS id
) AS produce_data
ON produce_sales.item = produce_data.item
|> SELECT produce_sales.item, sales, id;
----
apples 2 123
bananas 5 NULL

# PIVOT pipe

statement ok
CREATE TABLE pipe_test(
product VARCHAR,
sales INT,
quarter VARCHAR,
year INT
) AS VALUES
('Kale', 51, 'Q1', 2020),
('Kale', 23, 'Q2', 2020),
('Kale', 45, 'Q3', 2020),
('Kale', 3, 'Q4', 2020),
('Kale', 70, 'Q1', 2021),
('Kale', 85, 'Q2', 2021),
('Apple', 77, 'Q1', 2020),
('Apple', 0, 'Q2', 2020),
('Apple', 1, 'Q1', 2021)
;

query TIIIII rowsort
SELECT * FROM pipe_test
|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4'));
----
Apple 2020 77 0 NULL NULL
Apple 2021 1 NULL NULL NULL
Kale 2020 51 23 45 3
Kale 2021 70 85 NULL NULL

query TIIII rowsort
SELECT * FROM pipe_test
|> select product, sales, quarter
|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2', 'Q3', 'Q4'));
----
Apple 78 0 NULL NULL
Kale 121 108 45 3

query TIII rowsort
SELECT * FROM pipe_test
|> select product, sales, quarter
|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2', 'Q3'));
----
Apple 78 0 NULL
Kale 121 108 45

query TIIII rowsort
SELECT * FROM pipe_test
|> select product, sales, quarter
|> PIVOT(SUM(sales) as total_sales, count(*) as num_records FOR quarter IN ('Q1', 'Q2'));
----
Apple 78 2 0 1
Kale 121 2 108 2


query TT
EXPLAIN SELECT * FROM pipe_test
|> select product, sales, quarter
|> PIVOT(SUM(sales) as total_sales, count(*) as num_records FOR quarter IN ('Q1', 'Q2'));
----
logical_plan
01)Aggregate: groupBy=[[pipe_test.product]], aggr=[[sum(__common_expr_1) FILTER (WHERE __common_expr_2) AS total_sales_Q1, count(Int64(1)) FILTER (WHERE __common_expr_2) AS num_records_Q1, sum(__common_expr_1) FILTER (WHERE __common_expr_3) AS total_sales_Q2, count(Int64(1)) FILTER (WHERE __common_expr_3) AS num_records_Q2]]
02)--Projection: CAST(pipe_test.sales AS Int64) AS __common_expr_1, pipe_test.quarter = Utf8View("Q1") AS __common_expr_2, pipe_test.quarter = Utf8View("Q2") AS __common_expr_3, pipe_test.product
03)----TableScan: pipe_test projection=[product, sales, quarter]
physical_plan
01)AggregateExec: mode=FinalPartitioned, gby=[product@0 as product], aggr=[total_sales_Q1, num_records_Q1, total_sales_Q2, num_records_Q2]
02)--CoalesceBatchesExec: target_batch_size=8192
03)----RepartitionExec: partitioning=Hash([product@0], 4), input_partitions=4
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------AggregateExec: mode=Partial, gby=[product@3 as product], aggr=[total_sales_Q1, num_records_Q1, total_sales_Q2, num_records_Q2]
06)----------ProjectionExec: expr=[CAST(sales@1 AS Int64) as __common_expr_1, quarter@2 = Q1 as __common_expr_2, quarter@2 = Q2 as __common_expr_3, product@0 as product]
07)------------DataSourceExec: partitions=1, partition_sizes=[1]

# With explicit pivot value alias
query TT
EXPLAIN SELECT * FROM pipe_test
|> select product, sales, quarter
|> PIVOT(SUM(sales) as total_sales, count(*) as num_records FOR quarter IN ('Q1' as q1, 'Q2'));
----
logical_plan
01)Aggregate: groupBy=[[pipe_test.product]], aggr=[[sum(__common_expr_1) FILTER (WHERE __common_expr_2) AS total_sales_q1, count(Int64(1)) FILTER (WHERE __common_expr_2) AS num_records_q1, sum(__common_expr_1) FILTER (WHERE __common_expr_3) AS total_sales_Q2, count(Int64(1)) FILTER (WHERE __common_expr_3) AS num_records_Q2]]
02)--Projection: CAST(pipe_test.sales AS Int64) AS __common_expr_1, pipe_test.quarter = Utf8View("Q1") AS __common_expr_2, pipe_test.quarter = Utf8View("Q2") AS __common_expr_3, pipe_test.product
03)----TableScan: pipe_test projection=[product, sales, quarter]
physical_plan
01)AggregateExec: mode=FinalPartitioned, gby=[product@0 as product], aggr=[total_sales_q1, num_records_q1, total_sales_Q2, num_records_Q2]
02)--CoalesceBatchesExec: target_batch_size=8192
03)----RepartitionExec: partitioning=Hash([product@0], 4), input_partitions=4
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------AggregateExec: mode=Partial, gby=[product@3 as product], aggr=[total_sales_q1, num_records_q1, total_sales_Q2, num_records_Q2]
06)----------ProjectionExec: expr=[CAST(sales@1 AS Int64) as __common_expr_1, quarter@2 = Q1 as __common_expr_2, quarter@2 = Q2 as __common_expr_3, product@0 as product]
07)------------DataSourceExec: partitions=1, partition_sizes=[1]

# Aggregation functions with multiple parameters
query TTT rowsort
SELECT product, sales, quarter FROM pipe_test
|> PIVOT(string_agg(sales, '_' order by sales) as agg FOR quarter IN ('Q1', 'Q2'));
----
Apple 1_77 0
Kale 51_70 23_85
54 changes: 54 additions & 0 deletions docs/source/user-guide/sql/select.md
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ DataFusion currently supports the following pipe operators:
- [INTERSECT](#pipe_intersect)
- [EXCEPT](#pipe_except)
- [AGGREGATE](#pipe_aggregate)
- [PIVOT](#pipe_pivot)
- [JOIN](#pipe_join)

(pipe_where)=

Expand Down Expand Up @@ -514,3 +516,55 @@ select * from range(0,3)
| 3 |
+-------+
```

(pipe_pivot)=

### PIVOT

Rotates rows into columns.

```sql
> (
SELECT 'kale' AS product, 51 AS sales, 'Q1' AS quarter
UNION ALL
SELECT 'kale' AS product, 4 AS sales, 'Q1' AS quarter
UNION ALL
SELECT 'kale' AS product, 45 AS sales, 'Q2' AS quarter
UNION ALL
SELECT 'apple' AS product, 8 AS sales, 'Q1' AS quarter
UNION ALL
SELECT 'apple' AS product, 10 AS sales, 'Q2' AS quarter
)
|> PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'));
+---------+-----+-----+
| product | Q1 | Q2 |
+---------+-----+-----+
| apple | 8 | 10 |
| kale | 55 | 45 |
+---------+-----+-----+
```

(pipe_join)=

### JOIN

```sql
> (
SELECT 'apples' AS item, 2 AS sales
UNION ALL
SELECT 'bananas' AS item, 5 AS sales
)
|> AS produce_sales
|> LEFT JOIN
(
SELECT 'apples' AS item, 123 AS id
) AS produce_data
ON produce_sales.item = produce_data.item
|> SELECT produce_sales.item, sales, id;
+--------+-------+------+
| item | sales | id |
+--------+-------+------+
| apples | 2 | 123 |
| bananas| 5 | NULL |
+--------+-------+------+
```