Fix column indices in EnforceDistribution optimizer in Partial AggregateMode (#4878)

* Fix column indices in EnforceDistribution optimizer in Partial AggregateMode

Column expressions need to be updated to correspond with the partial aggregation schema rather than the input schema (see the sketch below).

* Simplify new_group_by calculation
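
To make the first point concrete, here is a minimal, self-contained sketch of the index remapping in plain Rust (the names ColumnRef and final_group_exprs are illustrative, not the DataFusion API): the partial aggregation emits its group columns at positions 0..n of its own output schema, so the final aggregation must reference those positions rather than the columns' positions in the original input schema.

// Illustrative stand-in for a physical column reference: a name plus an
// index into the schema of the operator's input.
#[derive(Debug, Clone, PartialEq)]
struct ColumnRef {
    name: String,
    index: usize,
}

// Given the group-by columns of the partial aggregation (indexed against the
// scan schema, e.g. b@1 as b1, a@0 as a1), build the group-by columns for the
// final aggregation. The partial stage emits its group columns first, in
// order, so the final stage must use indices 0..n of the partial output.
fn final_group_exprs(partial_group: &[(ColumnRef, String)]) -> Vec<(ColumnRef, String)> {
    partial_group
        .iter()
        .enumerate()
        .map(|(i, (_expr, alias))| {
            (ColumnRef { name: alias.clone(), index: i }, alias.clone())
        })
        .collect()
}

fn main() {
    // Partial aggregation: gby=[b@1 as b1, a@0 as a1]
    let partial = vec![
        (ColumnRef { name: "b".into(), index: 1 }, "b1".to_string()),
        (ColumnRef { name: "a".into(), index: 0 }, "a1".to_string()),
    ];
    // The final aggregation must be gby=[b1@0 as b1, a1@1 as a1]
    let final_gby = final_group_exprs(&partial);
    assert_eq!(final_gby[0].0, ColumnRef { name: "b1".into(), index: 0 });
    assert_eq!(final_gby[1].0, ColumnRef { name: "a1".into(), index: 1 });
}

This mirrors what the diff below does via partial_agg.output_group_expr(): the final stage's group-by is rebuilt against the partial stage's output schema instead of reusing the input-schema indices.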
jonmmease authored Jan 14, 2023
1 parent a9ddcd3 commit dee0dd8
Showing 2 changed files with 76 additions and 7 deletions.
25 changes: 18 additions & 7 deletions datafusion/core/src/physical_optimizer/dist_enforcement.rs
@@ -431,11 +431,22 @@ fn reorder_aggregate_keys(
         None
     };
     if let Some(partial_agg) = new_partial_agg {
-        let mut new_group_exprs = vec![];
-        for idx in positions.into_iter() {
-            new_group_exprs.push(group_by.expr()[idx].clone());
-        }
-        let new_group_by = PhysicalGroupBy::new_single(new_group_exprs);
+        // Build new group expressions that correspond to the output of partial_agg
+        let new_final_group: Vec<Arc<dyn PhysicalExpr>> =
+            partial_agg.output_group_expr();
+        let new_group_by = PhysicalGroupBy::new_single(
+            new_final_group
+                .iter()
+                .enumerate()
+                .map(|(i, expr)| {
+                    (
+                        expr.clone(),
+                        partial_agg.group_expr().expr()[i].1.clone(),
+                    )
+                })
+                .collect(),
+        );
+
         let new_final_agg = Arc::new(AggregateExec::try_new(
             AggregateMode::FinalPartitioned,
             new_group_by,
@@ -1494,7 +1505,7 @@ mod tests {
     let expected = &[
         "HashJoinExec: mode=Partitioned, join_type=Inner, on=[(Column { name: \"b1\", index: 1 }, Column { name: \"b\", index: 0 }), (Column { name: \"a1\", index: 0 }, Column { name: \"a\", index: 1 })]",
         "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]",
-        "AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]",
+        "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]",
         "RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10)",
         "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]",
         "ParquetExec: limit=None, partitions={1 group: [[x]]}, projection=[a, b, c, d, e]",
@@ -2057,7 +2068,7 @@ mod tests {
"SortExec: [b3@1 ASC,a3@0 ASC]",
"ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]",
"ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]",
"AggregateExec: mode=FinalPartitioned, gby=[b1@1 as b1, a1@0 as a1], aggr=[]",
"AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]",
"RepartitionExec: partitioning=Hash([Column { name: \"b1\", index: 0 }, Column { name: \"a1\", index: 1 }], 10)",
"AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]",
"ParquetExec: limit=None, partitions={1 group: [[x]]}, projection=[a, b, c, d, e]",
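
To read the expected plans above: b1@0 as b1 is the display of a physical Column expression named b1 at index 0 of the child operator's output, aliased as b1. Before the fix, the FinalPartitioned aggregate reused the partial stage's input indices (b1@1, a1@0); after the fix it indexes into the partial stage's output, where the group columns sit at positions 0 and 1. A hedged sketch, assuming the datafusion crate (the module path and the name@index rendering may differ by version):

use datafusion::physical_plan::expressions::Column;

fn main() {
    // The scan exposes a at index 0 and b at index 1, so the partial
    // aggregate groups on b@1 and a@0 (aliased b1 and a1).
    let partial_b = Column::new("b", 1);
    let partial_a = Column::new("a", 0);

    // The partial aggregate's output places the group columns first:
    // b1 at index 0, a1 at index 1 -- these are what the final stage
    // must reference after the fix.
    let final_b1 = Column::new("b1", 0);
    let final_a1 = Column::new("a1", 1);

    // Columns print as name@index, matching the plan strings above.
    println!("partial gby: {partial_b} as b1, {partial_a} as a1");
    println!("final   gby: {final_b1} as b1, {final_a1} as a1");
}

Note that the RepartitionExec hash columns already used those output positions (b1 at 0, a1 at 1), which is why the FinalPartitioned group-by now agrees with the partitioning.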
58 changes: 58 additions & 0 deletions datafusion/core/tests/sql/joins.rs
@@ -2810,3 +2810,61 @@ async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> {

     Ok(())
 }
+
+#[tokio::test]
+async fn test_cross_join_to_groupby_with_different_key_ordering() -> Result<()> {
+    // Regression test for GH #4873
+    let col1 = Arc::new(StringArray::from(vec![
+        "A", "A", "A", "A", "A", "A", "A", "A", "BB", "BB", "BB", "BB",
+    ])) as ArrayRef;
+
+    let col2 =
+        Arc::new(UInt64Array::from(vec![1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6])) as ArrayRef;
+
+    let col3 =
+        Arc::new(UInt64Array::from(vec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])) as ArrayRef;
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("col1", DataType::Utf8, true),
+        Field::new("col2", DataType::UInt64, true),
+        Field::new("col3", DataType::UInt64, true),
+    ])) as SchemaRef;
+
+    let batch = RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap();
+    let mem_table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
+
+    // Create context and register table
+    let ctx = SessionContext::new();
+    ctx.register_table("tbl", Arc::new(mem_table)).unwrap();
+
+    let sql = "select col1, col2, coalesce(sum_col3, 0) as sum_col3 \
+        from (select distinct col2 from tbl) AS q1 \
+        cross join (select distinct col1 from tbl) AS q2 \
+        left outer join (SELECT col1, col2, sum(col3) as sum_col3 FROM tbl GROUP BY col1, col2) AS q3 \
+        USING(col2, col1) \
+        ORDER BY col1, col2";
+
+    let expected = vec![
+        "+------+------+----------+",
+        "| col1 | col2 | sum_col3 |",
+        "+------+------+----------+",
+        "| A    | 1    | 2        |",
+        "| A    | 2    | 2        |",
+        "| A    | 3    | 2        |",
+        "| A    | 4    | 2        |",
+        "| A    | 5    | 0        |",
+        "| A    | 6    | 0        |",
+        "| BB   | 1    | 0        |",
+        "| BB   | 2    | 0        |",
+        "| BB   | 3    | 0        |",
+        "| BB   | 4    | 0        |",
+        "| BB   | 5    | 2        |",
+        "| BB   | 6    | 2        |",
+        "+------+------+----------+",
+    ];
+
+    let results = execute_to_batches(&ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
