diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index f3ec083efb240..acb1c588097ee 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -36,7 +36,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_expr::logical_plan::JoinType; +use datafusion_expr::logical_plan::{Aggregate, JoinType}; use datafusion_physical_expr::expressions::{Column, NoOp}; use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ @@ -1301,10 +1301,25 @@ pub fn ensure_distribution( // Allow subset satisfaction when: // 1. Current partition count >= threshold // 2. Not a partitioned join since must use exact hash matching for joins + // 3. Not a grouping set aggregate (requires exact hash including __grouping_id) let current_partitions = child.plan.output_partitioning().partition_count(); + + // Check if the hash partitioning requirement includes __grouping_id column. + // Grouping set aggregates (ROLLUP, CUBE, GROUPING SETS) require exact hash + // partitioning on all group columns including __grouping_id to ensure partial + // aggregates from different partitions are correctly combined. + let requires_grouping_id = matches!(&requirement, Distribution::HashPartitioned(exprs) + if exprs.iter().any(|expr| { + expr.as_any() + .downcast_ref::<Column>() + .is_some_and(|col| col.name() == Aggregate::INTERNAL_GROUPING_ID) + }) + ); + let allow_subset_satisfy_partitioning = current_partitions >= subset_satisfaction_threshold - && !is_partitioned_join; + && !is_partitioned_join + && !requires_grouping_id; // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. 
diff --git a/datafusion/sqllogictest/test_files/grouping_set_repartition.slt b/datafusion/sqllogictest/test_files/grouping_set_repartition.slt new file mode 100644 index 0000000000000..16ab90651c8b3 --- /dev/null +++ b/datafusion/sqllogictest/test_files/grouping_set_repartition.slt @@ -0,0 +1,246 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +# Tests for ROLLUP/CUBE/GROUPING SETS with multiple partitions +# +# This tests the fix for https://github.com/apache/datafusion/issues/19849 +# where ROLLUP queries produced incorrect results with multiple partitions +# because subset partitioning satisfaction was incorrectly applied. +# +# The bug manifests when: +# 1. UNION ALL of subqueries each with hash-partitioned aggregates +# 2. Outer ROLLUP groups by more columns than inner hash partitioning +# 3. InterleaveExec preserves the inner hash partitioning +# 4. Optimizer incorrectly uses subset satisfaction, skipping necessary repartition +# +# The fix ensures that when hash partitioning includes __grouping_id, +# subset satisfaction is disabled and proper RepartitionExec is inserted. 
+########## + +########## +# SETUP: Create partitioned parquet files to simulate distributed data +########## + +statement ok +set datafusion.execution.target_partitions = 4; + +statement ok +set datafusion.optimizer.repartition_aggregations = true; + +# Create partition 1 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('store', 'nike', 100), + ('store', 'nike', 200), + ('store', 'adidas', 150) +)) +TO 'test_files/scratch/grouping_set_repartition/part=1/data.parquet' +STORED AS PARQUET; + +# Create partition 2 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('store', 'adidas', 250), + ('web', 'nike', 300), + ('web', 'nike', 400) +)) +TO 'test_files/scratch/grouping_set_repartition/part=2/data.parquet' +STORED AS PARQUET; + +# Create partition 3 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('web', 'adidas', 350), + ('web', 'adidas', 450), + ('catalog', 'nike', 500) +)) +TO 'test_files/scratch/grouping_set_repartition/part=3/data.parquet' +STORED AS PARQUET; + +# Create partition 4 +statement ok +COPY (SELECT column1 as channel, column2 as brand, column3 as amount FROM (VALUES + ('catalog', 'nike', 600), + ('catalog', 'adidas', 550), + ('catalog', 'adidas', 650) +)) +TO 'test_files/scratch/grouping_set_repartition/part=4/data.parquet' +STORED AS PARQUET; + +# Create external table pointing to the partitioned data +statement ok +CREATE EXTERNAL TABLE sales (channel VARCHAR, brand VARCHAR, amount INT) +STORED AS PARQUET +PARTITIONED BY (part INT) +LOCATION 'test_files/scratch/grouping_set_repartition/'; + +########## +# TEST 1: UNION ALL + ROLLUP pattern (similar to TPC-DS q14) +# This query pattern triggers the subset satisfaction bug because: +# - Each UNION ALL branch has hash partitioning on (brand) +# - The outer ROLLUP requires hash partitioning on (channel, brand, __grouping_id) +# - Without the fix, subset 
satisfaction incorrectly skips repartition +# +# Verify the physical plan includes RepartitionExec with __grouping_id +########## + +query TT +EXPLAIN SELECT channel, brand, SUM(total) as grand_total +FROM ( + SELECT 'store' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'store' + GROUP BY brand + UNION ALL + SELECT 'web' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'web' + GROUP BY brand + UNION ALL + SELECT 'catalog' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'catalog' + GROUP BY brand +) sub +GROUP BY ROLLUP(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +logical_plan +01)Sort: sub.channel ASC NULLS FIRST, sub.brand ASC NULLS FIRST +02)--Projection: sub.channel, sub.brand, sum(sub.total) AS grand_total +03)----Aggregate: groupBy=[[ROLLUP (sub.channel, sub.brand)]], aggr=[[sum(sub.total)]] +04)------SubqueryAlias: sub +05)--------Union +06)----------Projection: Utf8("store") AS channel, sales.brand, sum(sales.amount) AS total +07)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]] +08)--------------Projection: sales.brand, sales.amount +09)----------------Filter: sales.channel = Utf8View("store") +10)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("store")] +11)----------Projection: Utf8("web") AS channel, sales.brand, sum(sales.amount) AS total +12)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]] +13)--------------Projection: sales.brand, sales.amount +14)----------------Filter: sales.channel = Utf8View("web") +15)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("web")] +16)----------Projection: Utf8("catalog") AS channel, sales.brand, sum(sales.amount) AS total +17)------------Aggregate: groupBy=[[sales.brand]], aggr=[[sum(CAST(sales.amount AS Int64))]] 
+18)--------------Projection: sales.brand, sales.amount +19)----------------Filter: sales.channel = Utf8View("catalog") +20)------------------TableScan: sales projection=[channel, brand, amount], partial_filters=[sales.channel = Utf8View("catalog")] +physical_plan +01)SortPreservingMergeExec: [channel@0 ASC, brand@1 ASC] +02)--SortExec: expr=[channel@0 ASC, brand@1 ASC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[channel@0 as channel, brand@1 as brand, sum(sub.total)@3 as grand_total] +04)------AggregateExec: mode=FinalPartitioned, gby=[channel@0 as channel, brand@1 as brand, __grouping_id@2 as __grouping_id], aggr=[sum(sub.total)] +05)--------RepartitionExec: partitioning=Hash([channel@0, brand@1, __grouping_id@2], 4), input_partitions=4 +06)----------AggregateExec: mode=Partial, gby=[(NULL as channel, NULL as brand), (channel@0 as channel, NULL as brand), (channel@0 as channel, brand@1 as brand)], aggr=[sum(sub.total)] +07)------------InterleaveExec +08)--------------ProjectionExec: expr=[store as channel, brand@0 as brand, sum(sales.amount)@1 as total] +09)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +10)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4 +11)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +12)----------------------FilterExec: channel@0 = store, projection=[brand@1, amount@2] +13)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, 
amount], file_type=parquet, predicate=channel@0 = store, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= store AND store <= channel_max@1, required_guarantees=[channel in (store)] +14)--------------ProjectionExec: expr=[web as channel, brand@0 as brand, sum(sales.amount)@1 as total] +15)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +16)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4 +17)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +18)----------------------FilterExec: channel@0 = web, projection=[brand@1, amount@2] +19)------------------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = web, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= web AND web <= channel_max@1, required_guarantees=[channel in (web)] +20)--------------ProjectionExec: expr=[catalog as channel, brand@0 as brand, sum(sales.amount)@1 as total] +21)----------------AggregateExec: mode=FinalPartitioned, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +22)------------------RepartitionExec: partitioning=Hash([brand@0], 4), input_partitions=4 +23)--------------------AggregateExec: mode=Partial, gby=[brand@0 as brand], aggr=[sum(sales.amount)] +24)----------------------FilterExec: channel@0 = catalog, projection=[brand@1, amount@2] +25)------------------------DataSourceExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=1/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=2/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=3/data.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/grouping_set_repartition/part=4/data.parquet]]}, projection=[channel, brand, amount], file_type=parquet, predicate=channel@0 = catalog, pruning_predicate=channel_null_count@2 != row_count@3 AND channel_min@0 <= catalog AND catalog <= channel_max@1, required_guarantees=[channel in (catalog)] + +query TTI rowsort +SELECT channel, brand, SUM(total) as grand_total +FROM ( + SELECT 'store' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'store' + GROUP BY brand + UNION ALL + SELECT 'web' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'web' + GROUP BY brand + UNION ALL + SELECT 'catalog' as channel, brand, SUM(amount) as total + FROM sales WHERE channel = 'catalog' + GROUP BY brand +) sub +GROUP BY ROLLUP(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +NULL NULL 4500 +catalog NULL 2300 +catalog adidas 1200 +catalog nike 1100 +store NULL 700 +store adidas 400 +store nike 300 +web NULL 1500 +web adidas 800 +web nike 700 + +########## +# TEST 2: Simple ROLLUP (baseline test) +########## + +query TTI rowsort +SELECT channel, brand, SUM(amount) as total +FROM sales +GROUP BY ROLLUP(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS FIRST; +---- +NULL NULL 4500 +catalog NULL 2300 +catalog adidas 1200 +catalog nike 1100 +store NULL 700 +store adidas 400 +store nike 300 +web NULL 1500 +web adidas 800 +web nike 700 + +########## +# TEST 3: Verify CUBE also works correctly +########## + +query TTI rowsort +SELECT channel, brand, SUM(amount) as total +FROM sales +GROUP BY CUBE(channel, brand) +ORDER BY channel NULLS FIRST, brand NULLS 
FIRST; +---- +NULL NULL 4500 +NULL adidas 2400 +NULL nike 2100 +catalog NULL 2300 +catalog adidas 1200 +catalog nike 1100 +store NULL 700 +store adidas 400 +store nike 300 +web NULL 1500 +web adidas 800 +web nike 700 + +########## +# CLEANUP +########## + +statement ok +DROP TABLE sales;