Skip to content

Commit 694aaee

Browse files
author
shuo.cs
committed
fix test
1 parent 0573b65 commit 694aaee

File tree

4 files changed

+34
-6
lines changed

4 files changed

+34
-6
lines changed

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/AvgAggFunction.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
2525
import org.apache.flink.table.types.DataType;
2626
import org.apache.flink.table.types.logical.DecimalType;
27+
import org.apache.flink.table.types.logical.LogicalTypeRoot;
2728
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
2829

2930
import java.math.BigDecimal;
@@ -95,7 +96,11 @@ public Expression[] mergeExpressions() {
9596
}
9697

9798
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
98-
return cast(sumExpr, typeLiteral(getSumType()));
99+
if (getResultType().getLogicalType().getTypeRoot() == LogicalTypeRoot.DECIMAL) {
100+
return cast(sumExpr, typeLiteral(getResultType()));
101+
} else {
102+
return sumExpr;
103+
}
99104
}
100105

101106
/** If all input are nulls, count will be 0 and we will get null after the division. */

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/Sum0AggFunction.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,23 @@
2020

2121
import org.apache.flink.table.api.DataTypes;
2222
import org.apache.flink.table.expressions.Expression;
23+
import org.apache.flink.table.expressions.UnresolvedCallExpression;
2324
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
2425
import org.apache.flink.table.types.DataType;
2526
import org.apache.flink.table.types.logical.DecimalType;
27+
import org.apache.flink.table.types.logical.LogicalTypeRoot;
2628
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
2729

2830
import java.math.BigDecimal;
2931

3032
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
33+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast;
3134
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
3235
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
3336
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.literal;
3437
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus;
3538
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus;
39+
import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
3640

3741
/** built-in sum0 aggregate function. */
3842
public abstract class Sum0AggFunction extends DeclarativeAggregateFunction {
@@ -56,20 +60,29 @@ public DataType[] getAggBufferTypes() {
5660
@Override
5761
public Expression[] accumulateExpressions() {
5862
return new Expression[] {
59-
/* sum0 = */ ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0)))
63+
/* sum0 = */ adjustSumType(ifThenElse(isNull(operand(0)), sum0, plus(sum0, operand(0))))
6064
};
6165
}
6266

6367
@Override
6468
public Expression[] retractExpressions() {
6569
return new Expression[] {
66-
/* sum0 = */ ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0)))
70+
/* sum0 = */ adjustSumType(
71+
ifThenElse(isNull(operand(0)), sum0, minus(sum0, operand(0))))
6772
};
6873
}
6974

7075
@Override
7176
public Expression[] mergeExpressions() {
72-
return new Expression[] {/* sum0 = */ plus(sum0, mergeOperand(sum0))};
77+
return new Expression[] {/* sum0 = */ adjustSumType(plus(sum0, mergeOperand(sum0)))};
78+
}
79+
80+
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
81+
if (getResultType().getLogicalType().getTypeRoot() == LogicalTypeRoot.DECIMAL) {
82+
return cast(sumExpr, typeLiteral(getResultType()));
83+
} else {
84+
return sumExpr;
85+
}
7386
}
7487

7588
@Override

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
2626
import org.apache.flink.table.types.DataType;
2727
import org.apache.flink.table.types.logical.DecimalType;
28+
import org.apache.flink.table.types.logical.LogicalTypeRoot;
2829
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
2930

3031
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
@@ -89,7 +90,11 @@ public Expression[] mergeExpressions() {
8990
}
9091

9192
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
92-
return cast(sumExpr, typeLiteral(getResultType()));
93+
if (getResultType().getLogicalType().getTypeRoot() == LogicalTypeRoot.DECIMAL) {
94+
return cast(sumExpr, typeLiteral(getResultType()));
95+
} else {
96+
return sumExpr;
97+
}
9398
}
9499

95100
@Override

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumWithRetractAggFunction.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
2525
import org.apache.flink.table.types.DataType;
2626
import org.apache.flink.table.types.logical.DecimalType;
27+
import org.apache.flink.table.types.logical.LogicalTypeRoot;
2728
import org.apache.flink.table.types.logical.utils.LogicalTypeMerging;
2829

2930
import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
@@ -103,7 +104,11 @@ public Expression[] mergeExpressions() {
103104
}
104105

105106
private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) {
106-
return cast(sumExpr, typeLiteral(getResultType()));
107+
if (getResultType().getLogicalType().getTypeRoot() == LogicalTypeRoot.DECIMAL) {
108+
return cast(sumExpr, typeLiteral(getResultType()));
109+
} else {
110+
return sumExpr;
111+
}
107112
}
108113

109114
@Override

0 commit comments

Comments
 (0)