Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use datafusion_physical_expr_common::sort_expr::{
LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr,
PhysicalSortRequirement,
};
use datafusion_physical_plan::aggregates::AggregateExec;
use datafusion_physical_plan::execution_plan::CardinalityEffect;
use datafusion_physical_plan::filter::FilterExec;
use datafusion_physical_plan::joins::utils::{
Expand Down Expand Up @@ -353,6 +354,8 @@ fn pushdown_requirement_to_children(
Ok(None)
}
}
} else if let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() {
handle_aggregate_pushdown(aggregate_exec, parent_required)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment explains why this is needed: #19287 (comment)

Basically the generic version is not correct for AggregateExec, but this was masked due to some limitations that were lifted in #19287

} else if maintains_input_order.is_empty()
|| !maintains_input_order.iter().any(|o| *o)
|| plan.as_any().is::<RepartitionExec>()
Expand Down Expand Up @@ -388,6 +391,77 @@ fn pushdown_requirement_to_children(
// TODO: Add support for Projection push down
}

/// Attempt to push a parent sort requirement through an [`AggregateExec`].
///
/// An `AggregateExec` can only preserve its input ordering on group-by
/// columns (aggregate outputs are computed from arbitrary expressions over
/// the input, so no input ordering carries over to them).
///
/// This function rewrites the parent's required ordering in terms of the
/// aggregate's **input**: for each required sort key that refers to a
/// group-by output column, it substitutes the original input expression for
/// that group. The rewritten requirement describes an input ordering that
/// would also satisfy the parent's requirement.
///
/// Returns `None` when no such mapping exists (e.g. the sort references
/// aggregate output columns, or uses non-column expressions), or when the
/// input's existing ordering does not already satisfy the mapped requirement.
fn handle_aggregate_pushdown(
    aggregate_exec: &AggregateExec,
    parent_required: OrderingRequirements,
) -> Result<Option<Vec<Option<OrderingRequirements>>>> {
    // Bail out unless the aggregate preserves input order on at least one
    // of its children.
    let preserves_order = aggregate_exec
        .maintains_input_order()
        .into_iter()
        .any(|o| o);
    if !preserves_order {
        return Ok(None);
    }

    let group_expr = aggregate_exec.group_expr();
    // GROUPING SETS introduce additional output columns and NULL
    // substitutions; skip pushdown until those cases can be mapped safely.
    if group_expr.has_grouping_set() {
        return Ok(None);
    }

    let group_input_exprs = group_expr.input_exprs();
    let parent_keys = parent_required.into_single();
    let mut mapped_keys = Vec::with_capacity(parent_keys.len());

    for req in parent_keys {
        // A sort above AggregateExec references its output columns. Map each
        // group-by output column back to its originating input expression;
        // anything else (non-column expressions, or columns past the group-by
        // range, i.e. aggregate outputs) cannot be pushed through.
        match req.expr.as_any().downcast_ref::<Column>() {
            Some(col) if col.index() < group_input_exprs.len() => {
                mapped_keys.push(PhysicalSortRequirement::new(
                    Arc::clone(&group_input_exprs[col.index()]),
                    req.options,
                ));
            }
            _ => return Ok(None),
        }
    }

    let Some(child_requirement) = LexRequirement::new(mapped_keys) else {
        return Ok(None);
    };

    // Only remove the sort above the aggregate when the input's existing
    // ordering already satisfies the mapped requirement.
    let already_satisfied = aggregate_exec
        .input()
        .equivalence_properties()
        .ordering_satisfy_requirement(child_requirement.iter().cloned())?;

    Ok(already_satisfied
        .then(|| vec![Some(OrderingRequirements::new(child_requirement))]))
}

/// Return true if pushing the sort requirements through a node would violate
/// the input sorting requirements for the plan
fn pushdown_would_violate_requirements(
Expand Down
184 changes: 184 additions & 0 deletions datafusion/sqllogictest/test_files/sort_pushdown.slt
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,184 @@ LIMIT 3;
5 4
2 -3

# Test 3.7: Aggregate ORDER BY expression should keep SortExec
# Source pattern declared on parquet scan: [x ASC, y ASC].
# Requested pattern in ORDER BY: [x ASC, CAST(y AS BIGINT) % 2 ASC].
# Example for x=1 input y order 1,2,3 gives bucket order 1,0,1, which does not
# match requested bucket ASC order. SortExec is required above AggregateExec.
statement ok
SET datafusion.execution.target_partitions = 1;

statement ok
CREATE TABLE agg_expr_data(x INT, y INT, v INT) AS VALUES
(1, 1, 10),
(1, 2, 20),
(1, 3, 30),
(2, 1, 40),
(2, 2, 50),
(2, 3, 60);

query I
COPY (SELECT * FROM agg_expr_data ORDER BY x, y)
TO 'test_files/scratch/sort_pushdown/agg_expr_sorted.parquet';
----
6

statement ok
CREATE EXTERNAL TABLE agg_expr_parquet(x INT, y INT, v INT)
STORED AS PARQUET
LOCATION 'test_files/scratch/sort_pushdown/agg_expr_sorted.parquet'
WITH ORDER (x ASC, y ASC);

query TT
EXPLAIN SELECT
x,
CAST(y AS BIGINT) % 2,
SUM(v)
FROM agg_expr_parquet
GROUP BY x, CAST(y AS BIGINT) % 2
ORDER BY x, CAST(y AS BIGINT) % 2;
----
logical_plan
01)Sort: agg_expr_parquet.x ASC NULLS LAST, agg_expr_parquet.y % Int64(2) ASC NULLS LAST
02)--Aggregate: groupBy=[[agg_expr_parquet.x, CAST(agg_expr_parquet.y AS Int64) % Int64(2)]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
03)----TableScan: agg_expr_parquet projection=[x, y, v]
physical_plan
01)SortExec: expr=[x@0 ASC NULLS LAST, agg_expr_parquet.y % Int64(2)@1 ASC NULLS LAST], preserve_partitioning=[false]
02)--AggregateExec: mode=Single, gby=[x@0 as x, CAST(y@1 AS Int64) % 2 as agg_expr_parquet.y % Int64(2)], aggr=[sum(agg_expr_parquet.v)], ordering_mode=PartiallySorted([0])
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, y, v], output_ordering=[x@0 ASC NULLS LAST, y@1 ASC NULLS LAST], file_type=parquet

# Expected output pattern from ORDER BY [x, bucket]:
# rows grouped by x, and within each x bucket appears as 0 then 1.
query III
SELECT
x,
CAST(y AS BIGINT) % 2,
SUM(v)
FROM agg_expr_parquet
GROUP BY x, CAST(y AS BIGINT) % 2
ORDER BY x, CAST(y AS BIGINT) % 2;
----
1 0 20
1 1 40
2 0 50
2 1 100

# Test 3.8: Aggregate ORDER BY monotonic expression can push down (no SortExec)
query TT
EXPLAIN SELECT
x,
CAST(y AS BIGINT),
SUM(v)
FROM agg_expr_parquet
GROUP BY x, CAST(y AS BIGINT)
ORDER BY x, CAST(y AS BIGINT);
----
logical_plan
01)Sort: agg_expr_parquet.x ASC NULLS LAST, agg_expr_parquet.y ASC NULLS LAST
02)--Aggregate: groupBy=[[agg_expr_parquet.x, CAST(agg_expr_parquet.y AS Int64)]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
03)----TableScan: agg_expr_parquet projection=[x, y, v]
physical_plan
01)AggregateExec: mode=Single, gby=[x@0 as x, CAST(y@1 AS Int64) as agg_expr_parquet.y], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, y, v], output_ordering=[x@0 ASC NULLS LAST, y@1 ASC NULLS LAST], file_type=parquet

query III
SELECT
x,
CAST(y AS BIGINT),
SUM(v)
FROM agg_expr_parquet
GROUP BY x, CAST(y AS BIGINT)
ORDER BY x, CAST(y AS BIGINT);
----
1 1 10
1 2 20
1 3 30
2 1 40
2 2 50
2 3 60

# Test 3.9: Aggregate ORDER BY aggregate output should keep SortExec
query TT
EXPLAIN SELECT x, SUM(v)
FROM agg_expr_parquet
GROUP BY x
ORDER BY SUM(v);
----
logical_plan
01)Sort: sum(agg_expr_parquet.v) ASC NULLS LAST
02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
03)----TableScan: agg_expr_parquet projection=[x, v]
physical_plan
01)SortExec: expr=[sum(agg_expr_parquet.v)@1 ASC NULLS LAST], preserve_partitioning=[false]
02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet

query II
SELECT x, SUM(v)
FROM agg_expr_parquet
GROUP BY x
ORDER BY SUM(v);
----
1 60
2 150

# Test 3.10: Aggregate with non-preserved input order should keep SortExec
# v is not part of the order by
query TT
EXPLAIN SELECT v, SUM(y)
FROM agg_expr_parquet
GROUP BY v
ORDER BY v;
----
logical_plan
01)Sort: agg_expr_parquet.v ASC NULLS LAST
02)--Aggregate: groupBy=[[agg_expr_parquet.v]], aggr=[[sum(CAST(agg_expr_parquet.y AS Int64))]]
03)----TableScan: agg_expr_parquet projection=[y, v]
physical_plan
01)SortExec: expr=[v@0 ASC NULLS LAST], preserve_partitioning=[false]
02)--AggregateExec: mode=Single, gby=[v@1 as v], aggr=[sum(agg_expr_parquet.y)]
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[y, v], file_type=parquet

query II
SELECT v, SUM(y)
FROM agg_expr_parquet
GROUP BY v
ORDER BY v;
----
10 1
20 2
30 3
40 1
50 2
60 3

# Test 3.11: Aggregate ORDER BY non-column expression (unsatisfied) keeps SortExec
# (though note in theory DataFusion could figure out that data sorted by x will also be sorted by x+1)
query TT
EXPLAIN SELECT x, SUM(v)
FROM agg_expr_parquet
GROUP BY x
ORDER BY x + 1 DESC;
----
logical_plan
01)Sort: CAST(agg_expr_parquet.x AS Int64) + Int64(1) DESC NULLS FIRST
02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
03)----TableScan: agg_expr_parquet projection=[x, v]
physical_plan
01)SortExec: expr=[CAST(x@0 AS Int64) + 1 DESC], preserve_partitioning=[false]
02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet

query II
SELECT x, SUM(v)
FROM agg_expr_parquet
GROUP BY x
ORDER BY x + 1 DESC;
----
2 150
1 60

# Cleanup
statement ok
DROP TABLE timestamp_data;
Expand Down Expand Up @@ -882,5 +1060,11 @@ DROP TABLE signed_data;
statement ok
DROP TABLE signed_parquet;

statement ok
DROP TABLE agg_expr_data;

statement ok
DROP TABLE agg_expr_parquet;

statement ok
SET datafusion.optimizer.enable_sort_pushdown = true;