Skip to content

Commit

Permalink
[CALCITE-6642] AggregateUnionTransposeRule should account for changes…
Browse files Browse the repository at this point in the history
… in nullability of pushed down aggregates
  • Loading branch information
arkanovicz authored and rubenada committed Oct 27, 2024
1 parent 5b4f32a commit add1f8b
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,32 @@ public AggregateUnionTransposeRule(Class<? extends Aggregate> aggregateClass,

// create corresponding aggregates on top of each union child
final RelBuilder relBuilder = call.builder();
RelDataType origUnionType = union.getRowType();
for (RelNode input : union.getInputs()) {
List<AggregateCall> childAggCalls = new ArrayList<>(aggRel.getAggCallList());
// if the nullability of a specific input column differs from the nullability
// of the union'ed column, we need to re-evaluate the nullability of the aggregate
RelDataType inputRowType = input.getRowType();
for (int i = 0; i < childAggCalls.size(); ++i) {
AggregateCall origCall = aggRel.getAggCallList().get(i);
if (origCall.getAggregation() == SqlStdOperatorTable.COUNT) {
continue;
}
assert origCall.getArgList().size() == 1;
int field = origCall.getArgList().get(0);
if (origUnionType.getFieldList().get(field).getType().isNullable()
!= inputRowType.getFieldList().get(field).getType().isNullable()) {
AggregateCall newCall =
AggregateCall.create(origCall.getParserPosition(), origCall.getAggregation(),
origCall.isDistinct(), origCall.isApproximate(), origCall.ignoreNulls(),
origCall.rexList, origCall.getArgList(), -1, origCall.distinctKeys,
origCall.collation, groupCount, input, null, origCall.getName());
childAggCalls.set(i, newCall);
}
}
relBuilder.push(input);
relBuilder.aggregate(relBuilder.groupKey(aggRel.getGroupSet()),
aggRel.getAggCallList());
childAggCalls);
}

// create a new union whose children are the aggregates created above
Expand Down
15 changes: 15 additions & 0 deletions core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7434,6 +7434,21 @@ private void checkSemiJoinRuleOnAntiJoin(RelOptRule rule) {
.check();
}

/** Test case for
* <a href="https://issues.apache.org/jira/browse/CALCITE-6642">[CALCITE-6642]
* AggregateUnionTransposeRule throws an assertion error when creating children aggregates
* when one input only has a non-nullable column</a>. */
@Test void testAggregateUnionTransposeWithOneInputNonNullable() {
final String sql = "select deptno, SUM(t) from (\n"
+ "select deptno, 1 as t from sales.emp e1\n"
+ "union all\n"
+ "select deptno, nullif(sal, 0) as t from sales.emp e2)\n"
+ "group by deptno";
sql(sql)
.withRule(CoreRules.AGGREGATE_UNION_TRANSPOSE)
.check();
}

/** If all inputs to UNION are already unique, AggregateUnionTransposeRule is
* a no-op. */
@Test void testAggregateUnionTransposeWithAllInputsUnique() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,37 @@ LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
LogicalAggregate(group=[{0, 1}])
LogicalProject(DEPTNO=[$7], T=[2])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
<TestCase name="testAggregateUnionTransposeWithOneInputNonNullable">
<Resource name="sql">
<![CDATA[select deptno, SUM(t) from (
select deptno, 1 as t from sales.emp e1
union all
select deptno, nullif(sal, 0) as t from sales.emp e2)
group by deptno]]>
</Resource>
<Resource name="planBefore">
<![CDATA[
LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
LogicalUnion(all=[true])
LogicalProject(DEPTNO=[$7], T=[1])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalProject(DEPTNO=[$7], T=[CASE(=($5, 0), null:INTEGER, $5)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
<Resource name="planAfter">
<![CDATA[
LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
LogicalUnion(all=[true])
LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
LogicalProject(DEPTNO=[$7], T=[1])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)])
LogicalProject(DEPTNO=[$7], T=[CASE(=($5, 0), null:INTEGER, $5)])
LogicalTableScan(table=[[CATALOG, SALES, EMP]])
]]>
</Resource>
</TestCase>
Expand Down

0 comments on commit add1f8b

Please sign in to comment.