Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions datafusion/physical-optimizer/src/enforce_distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::error::Result;
use datafusion_common::stats::Precision;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_expr::logical_plan::JoinType;
use datafusion_expr::logical_plan::{Aggregate, JoinType};
use datafusion_physical_expr::expressions::{Column, NoOp};
use datafusion_physical_expr::utils::map_columns_before_projection;
use datafusion_physical_expr::{
Expand Down Expand Up @@ -1301,10 +1301,25 @@ pub fn ensure_distribution(
// Allow subset satisfaction when:
// 1. Current partition count >= threshold
// 2. Not a partitioned join since must use exact hash matching for joins
// 3. Not a grouping set aggregate (requires exact hash including __grouping_id)
let current_partitions = child.plan.output_partitioning().partition_count();

// Check if the hash partitioning requirement includes __grouping_id column.
// Grouping set aggregates (ROLLUP, CUBE, GROUPING SETS) require exact hash
// partitioning on all group columns including __grouping_id to ensure partial
// aggregates from different partitions are correctly combined.
let requires_grouping_id = matches!(&requirement, Distribution::HashPartitioned(exprs)
if exprs.iter().any(|expr| {
expr.as_any()
.downcast_ref::<Column>()
.is_some_and(|col| col.name() == Aggregate::INTERNAL_GROUPING_ID)
})
);

let allow_subset_satisfy_partitioning = current_partitions
>= subset_satisfaction_threshold
&& !is_partitioned_join;
&& !is_partitioned_join
&& !requires_grouping_id;

// When `repartition_file_scans` is set, attempt to increase
// parallelism at the source.
Expand Down
246 changes: 246 additions & 0 deletions datafusion/sqllogictest/test_files/grouping_set_repartition.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

##########
# Tests for ROLLUP/CUBE/GROUPING SETS with multiple partitions
#
# This tests the fix for https://github.com/apache/datafusion/issues/19849
# where ROLLUP queries produced incorrect results with multiple partitions
# because subset partitioning satisfaction was incorrectly applied.
#
# The bug manifests when:
# 1. UNION ALL of subqueries each with hash-partitioned aggregates
# 2. Outer ROLLUP groups by more columns than inner hash partitioning
# 3. InterleaveExec preserves the inner hash partitioning
# 4. Optimizer incorrectly uses subset satisfaction, skipping necessary repartition
#
# The fix ensures that when hash partitioning includes __grouping_id,
# subset satisfaction is disabled and proper RepartitionExec is inserted.
##########

##########
# SETUP: Create partitioned parquet files to simulate distributed data
##########

statement ok
set datafusion.execution.target_partitions = 4;

statement ok
set datafusion.optimizer.repartition_aggregations = true;

# Create partition 1
statement ok
COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES
('store', 'nike', 100),
('store', 'nike', 200),
('store', 'adidas', 150)
))
TO 'test_files/scratch/grouping_set_repartition/part=1/data.parquet'
STORED AS PARQUET;

# Create partition 2
statement ok
COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES
('store', 'adidas', 250),
('web', 'nike', 300),
('web', 'nike', 400)
))
TO 'test_files/scratch/grouping_set_repartition/part=2/data.parquet'
STORED AS PARQUET;

# Create partition 3
statement ok
COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES
('web', 'adidas', 350),
('web', 'adidas', 450),
('catalog', 'nike', 500)
))
TO 'test_files/scratch/grouping_set_repartition/part=3/data.parquet'
STORED AS PARQUET;

# Create partition 4
statement ok
COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES
('catalog', 'nike', 600),
('catalog', 'adidas', 550),
('catalog', 'adidas', 650)
))
TO 'test_files/scratch/grouping_set_repartition/part=4/data.parquet'
STORED AS PARQUET;

# Create external table pointing to the partitioned data
statement ok
CREATE EXTERNAL TABLE sales (channel VARCHAR, brand VARCHAR, amount INT)
STORED AS PARQUET
PARTITIONED BY (part INT)
LOCATION 'test_files/scratch/grouping_set_repartition/';

##########
# TEST 1: UNION ALL + ROLLUP pattern (similar to TPC-DS q14)
# This query pattern triggers the subset satisfaction bug because:
# - Each UNION ALL branch has hash partitioning on (brand)
# - The outer ROLLUP requires hash partitioning on (channel, brand, __grouping_id)
# - Without the fix, subset satisfaction incorrectly skips repartition
#
# Verify the physical plan includes RepartitionExec with __grouping_id
##########

query TT
EXPLAIN SELECT channel, brand, SUM(total) as grand_total
FROM (
SELECT 'store' as channel, brand, SUM(amount) as total
FROM sales WHERE channel = 'store'
GROUP BY brand
UNION ALL
SELECT 'web' as channel, brand, SUM(amount) as total
FROM sales WHERE channel = 'web'
GROUP BY brand
UNION ALL
SELECT 'catalog' as channel, brand, SUM(amount) as total
FROM sales WHERE channel = 'catalog'
GROUP BY brand
) sub
GROUP BY ROLLUP(channel, brand)
ORDER BY channel NULLS FIRST, brand NULLS FIRST;
----
logical_plan
01)Sort: sub.channel ASC NULLS FIRST, sub.brand ASC NULLS FIRST
02)--Projection: sub.channel, sub.brand, sum(sub.total) AS grand_total
03)----Aggregate: groupBy=[[ROLLUP (sub.channel, sub.brand)]], aggr=[[sum(sub.total)]]
04)------SubqueryAlias: sub
05)--------Union
06)----------Projection: Utf8("store") AS channel, sales.brand, sum(sales.amount) AS total
07)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]]
08)--------------Projection: sales.brand, sales.amount
09)----------------Filter: sales.channel = Utf8View("store")
10)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("store")]
11)----------Projection: Utf8("web") AS channel, sales.brand, sum(sales.amount) AS total
12)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]]
13)--------------Projection: sales.brand, sales.amount
14)----------------Filter: sales.channel = Utf8View("web")
15)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("web")]
16)----------Projection: Utf8("catalog") AS channel, sales.brand, sum(sales.amount) AS total
17)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]]
18)--------------Projection: sales.brand, sales.amount
19)----------------Filter: sales.channel = Utf8View("catalog")
20)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("catalog")]
physical_plan
01)SortPreservingMergeExec: [channel@0 ASC, brand@1 ASC]
02)--SortExec: expr=[channel@0 ASC, brand@1 ASC], preserve_partitioning=[true]
03)----ProjectionExec: expr=[channel@0 as channel, brand@1 as brand, sum(sub.total)@3 as grand_total]
04)------AggregateExec: mode=FinalPartitioned, gby=[channel@0 as channel, brand@1 as brand, __grouping_id@2 as __grouping_id], aggr=[sum(sub.total)]
05)--------RepartitionExec: partitioning=Hash([channel@0, brand@1, __grouping_id@2], 4), input_partitions=4
06)----------AggregateExec: mode=Partial, gby=[(NULL as channel, NULL as brand), (channel@0 as channel, NULL as brand), (channel@0 as channel, brand@1 as brand)], aggr=[sum(sub.total)]
07)------------InterleaveExec
08)--------------ProjectionExec: expr=[store as channel, brand@0 as brand, sum(sales.amount)@1 as total]
09)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
10)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4
11)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
12)----------------------FilterExec: channel@0 = store, projection=[brand@1, amount@2]
13)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = store, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= store AND store <= channel_max@1, required_guarantees=[channel in (store)]
14)--------------ProjectionExec: expr=[web as channel, brand@0 as brand, sum(sales.amount)@1 as total]
15)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
16)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4
17)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
18)----------------------FilterExec: channel@0 = web, projection=[brand@1, amount@2]
19)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = web, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= web AND web <= channel_max@1, required_guarantees=[channel in (web)]
20)--------------ProjectionExec: expr=[catalog as channel, brand@0 as brand, sum(sales.amount)@1 as total]
21)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
22)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4
23)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)]
24)----------------------FilterExec: channel@0 = catalog, projection=[brand@1, amount@2]
25)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = catalog, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= catalog AND catalog <= channel_max@1, required_guarantees=[channel in (catalog)]

query TTI rowsort
SELECT channel, brand, SUM(total) as grand_total
FROM (
SELECT 'store' as channel, brand, SUM(amount) as total
FROM sales WHERE channel = 'store'
GROUP BY brand
UNION ALL
SELECT 'web' as channel, brand, SUM(amount) as total
FROM sales WHERE channel = 'web'
GROUP BY brand
UNION ALL
SELECT 'catalog' as channel, brand, SUM(amount) as total
FROM sales WHERE channel = 'catalog'
GROUP BY brand
) sub
GROUP BY ROLLUP(channel, brand)
ORDER BY channel NULLS FIRST, brand NULLS FIRST;
----
NULL NULL 4500
catalog NULL 2300
catalog adidas 1200
catalog nike 1100
store NULL 700
store adidas 400
store nike 300
web NULL 1500
web adidas 800
web nike 700

##########
# TEST 2: Simple ROLLUP (baseline test)
##########

query TTI rowsort
SELECT channel, brand, SUM(amount) as total
FROM sales
GROUP BY ROLLUP(channel, brand)
ORDER BY channel NULLS FIRST, brand NULLS FIRST;
----
NULL NULL 4500
catalog NULL 2300
catalog adidas 1200
catalog nike 1100
store NULL 700
store adidas 400
store nike 300
web NULL 1500
web adidas 800
web nike 700

##########
# TEST 3: Verify CUBE also works correctly
##########

query TTI rowsort
SELECT channel, brand, SUM(amount) as total
FROM sales
GROUP BY CUBE(channel, brand)
ORDER BY channel NULLS FIRST, brand NULLS FIRST;
----
NULL NULL 4500
NULL adidas 2400
NULL nike 2100
catalog NULL 2300
catalog adidas 1200
catalog nike 1100
store NULL 700
store adidas 400
store nike 300
web NULL 1500
web adidas 800
web nike 700

##########
# CLEANUP
##########

statement ok
DROP TABLE sales;