Skip to content

Commit 02fe863

Browse files
committed
[FLINK-19449][table-planner] LEAD/LAG cannot work correctly in streaming mode
1 parent 928871b commit 02fe863

File tree

16 files changed

+646
-38
lines changed

16 files changed

+646
-38
lines changed

docs/data/sql_functions.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -674,10 +674,10 @@ aggregate:
674674
- sql: ROW_NUMER()
675675
description: Assigns a unique, sequential number to each row, starting with one, according to the ordering of rows within the window partition. ROW_NUMBER and RANK are similar. ROW_NUMBER numbers all rows sequentially (for example 1, 2, 3, 4, 5). RANK provides the same numeric value for ties (for example 1, 2, 2, 4, 5).
676676
- sql: LEAD(expression [, offset] [, default])
677-
description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
678-
- sql: LAG(expression [, offset] [, default])
679677
description: Returns the value of expression at the offsetth row after the current row in the window. The default value of offset is 1 and the default value of default is NULL.
680-
- sql: FIRST_VALUE(expression)
678+
- sql: LAG(expression [, offset] [, default])
679+
description: Returns the value of expression at the offsetth row before the current row in the window. The default value of offset is 1 and the default value of default is NULL.
680+
- sql: FIRST_VALUE(expression)
681681
description: Returns the first value in an ordered set of values.
682682
- sql: LAST_VALUE(expression)
683683
description: Returns the last value in an ordered set of values.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.planner.functions.aggfunctions;
20+
21+
import org.apache.flink.api.common.typeutils.TypeSerializer;
22+
import org.apache.flink.table.api.DataTypes;
23+
import org.apache.flink.table.api.TableException;
24+
import org.apache.flink.table.functions.AggregateFunction;
25+
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
26+
import org.apache.flink.table.runtime.typeutils.InternalSerializers;
27+
import org.apache.flink.table.runtime.typeutils.LinkedListSerializer;
28+
import org.apache.flink.table.types.DataType;
29+
import org.apache.flink.table.types.logical.LogicalType;
30+
import org.apache.flink.table.types.logical.LogicalTypeRoot;
31+
import org.apache.flink.table.types.utils.DataTypeUtils;
32+
33+
import java.util.Arrays;
34+
import java.util.LinkedList;
35+
import java.util.List;
36+
import java.util.Objects;
37+
38+
/** Lag {@link AggregateFunction}. */
39+
public class LagAggFunction<T> extends BuiltInAggregateFunction<T, LagAggFunction.LagAcc<T>> {
40+
41+
private final transient DataType[] valueDataTypes;
42+
43+
@SuppressWarnings("unchecked")
44+
public LagAggFunction(LogicalType[] valueTypes) {
45+
this.valueDataTypes =
46+
Arrays.stream(valueTypes)
47+
.map(DataTypeUtils::toInternalDataType)
48+
.toArray(DataType[]::new);
49+
if (valueDataTypes.length == 3
50+
&& valueDataTypes[2].getLogicalType().getTypeRoot() != LogicalTypeRoot.NULL) {
51+
if (valueDataTypes[0].getConversionClass() != valueDataTypes[2].getConversionClass()) {
52+
throw new TableException(
53+
String.format(
54+
"Please explicitly cast default value %s to %s.",
55+
valueDataTypes[2], valueDataTypes[1]));
56+
}
57+
}
58+
}
59+
60+
// --------------------------------------------------------------------------------------------
61+
// Planning
62+
// --------------------------------------------------------------------------------------------
63+
64+
@Override
65+
public List<DataType> getArgumentDataTypes() {
66+
return Arrays.asList(valueDataTypes);
67+
}
68+
69+
@Override
70+
public DataType getAccumulatorDataType() {
71+
return DataTypes.STRUCTURED(
72+
LagAcc.class,
73+
DataTypes.FIELD("offset", DataTypes.INT()),
74+
DataTypes.FIELD("defaultValue", valueDataTypes[0]),
75+
DataTypes.FIELD("buffer", getLinkedListType()));
76+
}
77+
78+
@SuppressWarnings({"unchecked", "rawtypes"})
79+
private DataType getLinkedListType() {
80+
TypeSerializer<T> serializer =
81+
InternalSerializers.create(getOutputDataType().getLogicalType());
82+
return DataTypes.RAW(
83+
LinkedList.class, (TypeSerializer) new LinkedListSerializer<>(serializer));
84+
}
85+
86+
@Override
87+
public DataType getOutputDataType() {
88+
return valueDataTypes[0];
89+
}
90+
91+
// --------------------------------------------------------------------------------------------
92+
// Runtime
93+
// --------------------------------------------------------------------------------------------
94+
95+
public void accumulate(LagAcc<T> acc, T value) throws Exception {
96+
acc.buffer.add(value);
97+
while (acc.buffer.size() > acc.offset + 1) {
98+
acc.buffer.removeFirst();
99+
}
100+
}
101+
102+
public void accumulate(LagAcc<T> acc, T value, int offset) throws Exception {
103+
if (offset < 0) {
104+
throw new TableException(String.format("Offset(%d) should be positive.", offset));
105+
}
106+
107+
acc.offset = offset;
108+
accumulate(acc, value);
109+
}
110+
111+
public void accumulate(LagAcc<T> acc, T value, int offset, T defaultValue) throws Exception {
112+
acc.defaultValue = defaultValue;
113+
accumulate(acc, value, offset);
114+
}
115+
116+
public void resetAccumulator(LagAcc<T> acc) throws Exception {
117+
acc.offset = 1;
118+
acc.defaultValue = null;
119+
acc.buffer.clear();
120+
}
121+
122+
@Override
123+
public T getValue(LagAcc<T> acc) {
124+
if (acc.buffer.size() < acc.offset + 1) {
125+
return acc.defaultValue;
126+
} else if (acc.buffer.size() == acc.offset + 1) {
127+
return acc.buffer.getFirst();
128+
} else {
129+
throw new TableException("Too more elements: " + acc);
130+
}
131+
}
132+
133+
@Override
134+
public LagAcc<T> createAccumulator() {
135+
return new LagAcc<>();
136+
}
137+
138+
/** Accumulator for LAG. */
139+
public static class LagAcc<T> {
140+
public int offset = 1;
141+
public T defaultValue = null;
142+
public LinkedList<T> buffer = new LinkedList<>();
143+
144+
@Override
145+
public boolean equals(Object o) {
146+
if (this == o) {
147+
return true;
148+
}
149+
if (o == null || getClass() != o.getClass()) {
150+
return false;
151+
}
152+
LagAcc<?> lagAcc = (LagAcc<?>) o;
153+
return offset == lagAcc.offset
154+
&& Objects.equals(defaultValue, lagAcc.defaultValue)
155+
&& Objects.equals(buffer, lagAcc.buffer);
156+
}
157+
158+
@Override
159+
public int hashCode() {
160+
return Objects.hash(offset, defaultValue, buffer);
161+
}
162+
}
163+
}

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecGlobalWindowAggregate.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,14 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
145145
final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);
146146

147147
final AggregateInfoList localAggInfoList =
148-
AggregateUtil.deriveWindowAggregateInfoList(
148+
AggregateUtil.deriveStreamWindowAggregateInfoList(
149149
localAggInputRowType, // should use original input here
150150
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
151151
windowing.getWindow(),
152152
false); // isStateBackendDataViews
153153

154154
final AggregateInfoList globalAggInfoList =
155-
AggregateUtil.deriveWindowAggregateInfoList(
155+
AggregateUtil.deriveStreamWindowAggregateInfoList(
156156
localAggInputRowType, // should use original input here
157157
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
158158
windowing.getWindow(),

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecLocalWindowAggregate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
122122
final SliceAssigner sliceAssigner = createSliceAssigner(windowing, shiftTimeZone);
123123

124124
final AggregateInfoList aggInfoList =
125-
AggregateUtil.deriveWindowAggregateInfoList(
125+
AggregateUtil.deriveStreamWindowAggregateInfoList(
126126
inputRowType,
127127
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
128128
windowing.getWindow(),

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecWindowAggregate.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ protected Transformation<RowData> translateToPlanInternal(PlannerBase planner) {
143143
// Hopping window requires additional COUNT(*) to determine whether to register next timer
144144
// through whether the current fired window is empty, see SliceSharedWindowAggProcessor.
145145
final AggregateInfoList aggInfoList =
146-
AggregateUtil.deriveWindowAggregateInfoList(
146+
AggregateUtil.deriveStreamWindowAggregateInfoList(
147147
inputRowType,
148148
JavaScalaConversionUtil.toScala(Arrays.asList(aggCalls)),
149149
windowing.getWindow(),

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -562,9 +562,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
562562
def getAggCallFromLocalAgg(
563563
index: Int,
564564
aggCalls: Seq[AggregateCall],
565-
inputType: RelDataType): AggregateCall = {
565+
inputType: RelDataType,
566+
isBounded: Boolean): AggregateCall = {
566567
val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
567-
aggCalls, inputType)
568+
aggCalls, inputType, isBounded)
568569
if (outputIndexToAggCallIndexMap.containsKey(index)) {
569570
val realIndex = outputIndexToAggCallIndexMap.get(index)
570571
aggCalls(realIndex)
@@ -576,9 +577,10 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
576577
def getAggCallIndexInLocalAgg(
577578
index: Int,
578579
globalAggCalls: Seq[AggregateCall],
579-
inputRowType: RelDataType): Integer = {
580+
inputRowType: RelDataType,
581+
isBounded: Boolean): Integer = {
580582
val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
581-
globalAggCalls, inputRowType)
583+
globalAggCalls, inputRowType, isBounded)
582584

583585
outputIndexToAggCallIndexMap.foreach {
584586
case (k, v) => if (v == index) {
@@ -600,34 +602,37 @@ class FlinkRelMdColumnInterval private extends MetadataHandler[ColumnInterval] {
600602
case agg: StreamPhysicalGlobalGroupAggregate
601603
if agg.aggCalls.length > aggCallIndex =>
602604
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
603-
aggCallIndex, agg.aggCalls, agg.localAggInputRowType)
605+
aggCallIndex, agg.aggCalls, agg.localAggInputRowType, isBounded = false)
604606
if (aggCallIndexInLocalAgg != null) {
605607
return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
606608
} else {
607609
null
608610
}
609611
case agg: StreamPhysicalLocalGroupAggregate =>
610-
getAggCallFromLocalAgg(aggCallIndex, agg.aggCalls, agg.getInput.getRowType)
612+
getAggCallFromLocalAgg(
613+
aggCallIndex, agg.aggCalls, agg.getInput.getRowType, isBounded = false)
611614
case agg: StreamPhysicalIncrementalGroupAggregate
612615
if agg.partialAggCalls.length > aggCallIndex =>
613616
agg.partialAggCalls(aggCallIndex)
614617
case agg: StreamPhysicalGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
615618
agg.aggCalls(aggCallIndex)
616619
case agg: BatchPhysicalLocalHashAggregate =>
617-
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
620+
getAggCallFromLocalAgg(
621+
aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
618622
case agg: BatchPhysicalHashAggregate if agg.isMerge =>
619623
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
620-
aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
624+
aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
621625
if (aggCallIndexInLocalAgg != null) {
622626
return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
623627
} else {
624628
null
625629
}
626630
case agg: BatchPhysicalLocalSortAggregate =>
627-
getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
631+
getAggCallFromLocalAgg(
632+
aggCallIndex, agg.getAggCallList, agg.getInput.getRowType, isBounded = true)
628633
case agg: BatchPhysicalSortAggregate if agg.isMerge =>
629634
val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
630-
aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
635+
aggCallIndex, agg.getAggCallList, agg.aggInputRowType, isBounded = true)
631636
if (aggCallIndexInLocalAgg != null) {
632637
return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
633638
} else {

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalGlobalWindowAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class StreamPhysicalGlobalWindowAggregate(
6363
extends SingleRel(cluster, traitSet, inputRel)
6464
with StreamPhysicalRel {
6565

66-
private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
66+
private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
6767
FlinkTypeFactory.toLogicalRowType(inputRowTypeOfLocalAgg),
6868
aggCalls,
6969
windowing.getWindow,

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalLocalWindowAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class StreamPhysicalLocalWindowAggregate(
5656
extends SingleRel(cluster, traitSet, inputRel)
5757
with StreamPhysicalRel {
5858

59-
private lazy val aggInfoList = AggregateUtil.deriveWindowAggregateInfoList(
59+
private lazy val aggInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
6060
FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
6161
aggCalls,
6262
windowing.getWindow,

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalWindowAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class StreamPhysicalWindowAggregate(
5656
extends SingleRel(cluster, traitSet, inputRel)
5757
with StreamPhysicalRel {
5858

59-
lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveWindowAggregateInfoList(
59+
lazy val aggInfoList: AggregateInfoList = AggregateUtil.deriveStreamWindowAggregateInfoList(
6060
FlinkTypeFactory.toLogicalRowType(inputRel.getRowType),
6161
aggCalls,
6262
windowing.getWindow,

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,16 @@ import scala.collection.JavaConversions._
4545
* as subclasses of [[SqlAggFunction]] in Calcite but not as [[BridgingSqlAggFunction]]. The factory
4646
* returns [[DeclarativeAggregateFunction]] or [[BuiltInAggregateFunction]].
4747
*
48-
* @param inputType the input rel data type
49-
* @param orderKeyIdx the indexes of order key (null when is not over agg)
50-
* @param needRetraction true if need retraction
48+
* @param inputRowType the input row type
49+
* @param orderKeyIndexes the indexes of order key (null when is not over agg)
50+
* @param aggCallNeedRetractions true if need retraction
51+
* @param isBounded true if the source is bounded source
5152
*/
5253
class AggFunctionFactory(
5354
inputRowType: RowType,
5455
orderKeyIndexes: Array[Int],
55-
aggCallNeedRetractions: Array[Boolean]) {
56+
aggCallNeedRetractions: Array[Boolean],
57+
isBounded: Boolean) {
5658

5759
/**
5860
* The entry point to create an aggregate function from the given [[AggregateCall]].
@@ -94,8 +96,12 @@ class AggFunctionFactory(
9496
case a: SqlRankFunction if a.getKind == SqlKind.DENSE_RANK =>
9597
createDenseRankAggFunction(argTypes)
9698

97-
case _: SqlLeadLagAggFunction =>
98-
createLeadLagAggFunction(argTypes, index)
99+
case func: SqlLeadLagAggFunction =>
100+
if (isBounded) {
101+
createBatchLeadLagAggFunction(argTypes, index)
102+
} else {
103+
createStreamLeadLagAggFunction(func, argTypes, index)
104+
}
99105

100106
case _: SqlSingleValueAggFunction =>
101107
createSingleValueAggFunction(argTypes)
@@ -328,7 +334,22 @@ class AggFunctionFactory(
328334
}
329335
}
330336

331-
private def createLeadLagAggFunction(
337+
private def createStreamLeadLagAggFunction(
338+
func: SqlLeadLagAggFunction,
339+
argTypes: Array[LogicalType],
340+
index: Int): UserDefinedFunction = {
341+
if (func.getKind == SqlKind.LEAD) {
342+
throw new TableException("LEAD Function is not supported in stream mode.")
343+
}
344+
345+
if (aggCallNeedRetractions(index)) {
346+
throw new TableException("LAG Function with retraction is not supported in stream mode.")
347+
}
348+
349+
new LagAggFunction(argTypes)
350+
}
351+
352+
private def createBatchLeadLagAggFunction(
332353
argTypes: Array[LogicalType], index: Int): UserDefinedFunction = {
333354
argTypes(0).getTypeRoot match {
334355
case TINYINT =>

0 commit comments

Comments
 (0)