Skip to content

Commit 368edd1

Browse files
committed
Fix source watermark return type for Table API
1 parent fd91126 commit 368edd1

File tree

4 files changed

+91
-114
lines changed

4 files changed

+91
-114
lines changed

flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/resolver/rules/ResolveCallByArgumentsRule.java

Lines changed: 18 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -81,8 +81,11 @@ final class ResolveCallByArgumentsRule implements ResolverRule {
8181

8282
@Override
8383
public List<Expression> apply(List<Expression> expression, ResolutionContext context) {
84+
// only the top-level expressions may access the output data type
85+
final SurroundingInfo surroundingInfo =
86+
context.getOutputDataType().map(SurroundingInfo::of).orElse(null);
8487
return expression.stream()
85-
.flatMap(expr -> expr.accept(new ResolvingCallVisitor(context, null)).stream())
88+
.flatMap(e -> e.accept(new ResolvingCallVisitor(context, surroundingInfo)).stream())
8689
.collect(Collectors.toList());
8790
}
8891

@@ -120,23 +123,23 @@ public List<ResolvedExpression> visit(UnresolvedCallExpression unresolvedCall) {
120123
// resolve the children with information from the current call
121124
final List<ResolvedExpression> resolvedArgs = new ArrayList<>();
122125
final int argCount = unresolvedCall.getChildren().size();
126+
123127
for (int i = 0; i < argCount; i++) {
124128
final int currentPos = i;
129+
final SurroundingInfo surroundingInfo =
130+
typeInference
131+
.map(
132+
inference ->
133+
SurroundingInfo.of(
134+
name,
135+
definition,
136+
inference,
137+
argCount,
138+
currentPos,
139+
resolutionContext.isGroupedAggregation()))
140+
.orElse(null);
125141
final ResolvingCallVisitor childResolver =
126-
new ResolvingCallVisitor(
127-
resolutionContext,
128-
typeInference
129-
.map(
130-
inference ->
131-
new SurroundingInfo(
132-
name,
133-
definition,
134-
inference,
135-
argCount,
136-
currentPos,
137-
resolutionContext
138-
.isGroupedAggregation()))
139-
.orElse(null));
142+
new ResolvingCallVisitor(resolutionContext, surroundingInfo);
140143
resolvedArgs.addAll(unresolvedCall.getChildren().get(i).accept(childResolver));
141144
}
142145

flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/catalog/SchemaResolutionTest.java

Lines changed: 38 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,11 @@
2121
import org.apache.flink.core.testutils.FlinkMatchers;
2222
import org.apache.flink.table.api.DataTypes;
2323
import org.apache.flink.table.api.Schema;
24+
import org.apache.flink.table.expressions.CallExpression;
2425
import org.apache.flink.table.expressions.ResolvedExpression;
2526
import org.apache.flink.table.expressions.utils.ResolvedExpressionMock;
27+
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
28+
import org.apache.flink.table.functions.FunctionIdentifier;
2629
import org.apache.flink.table.types.DataType;
2730
import org.apache.flink.table.types.logical.LocalZonedTimestampType;
2831
import org.apache.flink.table.types.logical.LogicalType;
@@ -39,6 +42,7 @@
3942
import java.util.Collections;
4043

4144
import static org.apache.flink.table.api.Expressions.callSql;
45+
import static org.apache.flink.table.api.Expressions.sourceWatermark;
4246
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isProctimeAttribute;
4347
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isRowtimeAttribute;
4448
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isTimeAttribute;
@@ -90,23 +94,21 @@ public class SchemaResolutionTest {
9094

9195
// the type of ts_ltz is TIMESTAMP_LTZ
9296
private static final String COMPUTED_SQL_WITH_TS_LTZ = "ts_ltz - INTERVAL '60' MINUTE";
97+
9398
private static final ResolvedExpression COMPUTED_COLUMN_RESOLVED_WITH_TS_LTZ =
9499
new ResolvedExpressionMock(DataTypes.TIMESTAMP_LTZ(3), () -> COMPUTED_SQL_WITH_TS_LTZ);
100+
95101
private static final String WATERMARK_SQL_WITH_TS_LTZ = "ts1 - INTERVAL '5' SECOND";
102+
96103
private static final ResolvedExpression WATERMARK_RESOLVED_WITH_TS_LTZ =
97104
new ResolvedExpressionMock(DataTypes.TIMESTAMP_LTZ(3), () -> WATERMARK_SQL_WITH_TS_LTZ);
105+
98106
private static final Schema SCHEMA_WITH_TS_LTZ =
99107
Schema.newBuilder()
100-
.primaryKeyNamed("primary_constraint", "id") // out of order
101108
.column("id", DataTypes.INT().notNull())
102-
.column("counter", DataTypes.INT().notNull())
103-
.column("payload", "ROW<name STRING, age INT, flag BOOLEAN>")
104-
.columnByMetadata("topic", DataTypes.STRING(), true)
105-
.columnByExpression(
106-
"ts1", callSql(COMPUTED_SQL_WITH_TS_LTZ)) // out of order API expression
109+
.columnByExpression("ts1", callSql(COMPUTED_SQL_WITH_TS_LTZ))
107110
.columnByMetadata("ts_ltz", DataTypes.TIMESTAMP_LTZ(3), "timestamp")
108111
.watermark("ts1", WATERMARK_SQL_WITH_TS_LTZ)
109-
.columnByExpression("proctime", PROCTIME_SQL)
110112
.build();
111113

112114
@Test
@@ -152,38 +154,53 @@ public void testSchemaResolutionWithTimestampLtzRowtime() {
152154
new ResolvedSchema(
153155
Arrays.asList(
154156
Column.physical("id", DataTypes.INT().notNull()),
155-
Column.physical("counter", DataTypes.INT().notNull()),
156-
Column.physical(
157-
"payload",
158-
DataTypes.ROW(
159-
DataTypes.FIELD("name", DataTypes.STRING()),
160-
DataTypes.FIELD("age", DataTypes.INT()),
161-
DataTypes.FIELD("flag", DataTypes.BOOLEAN()))),
162-
Column.metadata("topic", DataTypes.STRING(), null, true),
163157
Column.computed("ts1", COMPUTED_COLUMN_RESOLVED_WITH_TS_LTZ),
164158
Column.metadata(
165-
"ts_ltz", DataTypes.TIMESTAMP_LTZ(3), "timestamp", false),
166-
Column.computed("proctime", PROCTIME_RESOLVED)),
159+
"ts_ltz", DataTypes.TIMESTAMP_LTZ(3), "timestamp", false)),
167160
Collections.singletonList(
168161
WatermarkSpec.of("ts1", WATERMARK_RESOLVED_WITH_TS_LTZ)),
169-
UniqueConstraint.primaryKey(
170-
"primary_constraint", Collections.singletonList("id")));
162+
null);
171163

172164
final ResolvedSchema actualStreamSchema = resolveSchema(SCHEMA_WITH_TS_LTZ, true);
173165
{
174166
assertThat(actualStreamSchema, equalTo(expectedSchema));
175167
assertTrue(isRowtimeAttribute(getType(actualStreamSchema, "ts1")));
176-
assertTrue(isProctimeAttribute(getType(actualStreamSchema, "proctime")));
177168
}
178169

179170
final ResolvedSchema actualBatchSchema = resolveSchema(SCHEMA_WITH_TS_LTZ, false);
180171
{
181172
assertThat(actualBatchSchema, equalTo(expectedSchema));
182173
assertFalse(isRowtimeAttribute(getType(actualBatchSchema, "ts1")));
183-
assertTrue(isProctimeAttribute(getType(actualBatchSchema, "proctime")));
184174
}
185175
}
186176

177+
@Test
178+
public void testSchemaResolutionWithSourceWatermark() {
179+
final ResolvedSchema expectedSchema =
180+
new ResolvedSchema(
181+
Collections.singletonList(
182+
Column.physical("ts_ltz", DataTypes.TIMESTAMP_LTZ(1))),
183+
Collections.singletonList(
184+
WatermarkSpec.of(
185+
"ts_ltz",
186+
new CallExpression(
187+
FunctionIdentifier.of(
188+
BuiltInFunctionDefinitions.SOURCE_WATERMARK
189+
.getName()),
190+
BuiltInFunctionDefinitions.SOURCE_WATERMARK,
191+
Collections.emptyList(),
192+
DataTypes.TIMESTAMP_LTZ(1)))),
193+
null);
194+
final ResolvedSchema resolvedSchema =
195+
resolveSchema(
196+
Schema.newBuilder()
197+
.column("ts_ltz", DataTypes.TIMESTAMP_LTZ(1))
198+
.watermark("ts_ltz", sourceWatermark())
199+
.build());
200+
201+
assertThat(resolvedSchema, equalTo(expectedSchema));
202+
}
203+
187204
@Test
188205
public void testSchemaResolutionErrors() {
189206

@@ -282,20 +299,6 @@ public void testUnresolvedSchemaString() {
282299
+ " WATERMARK FOR `ts` AS [ts - INTERVAL '5' SECOND],\n"
283300
+ " CONSTRAINT `primary_constraint` PRIMARY KEY (`id`) NOT ENFORCED\n"
284301
+ ")"));
285-
assertThat(
286-
SCHEMA_WITH_TS_LTZ.toString(),
287-
equalTo(
288-
"(\n"
289-
+ " `id` INT NOT NULL,\n"
290-
+ " `counter` INT NOT NULL,\n"
291-
+ " `payload` [ROW<name STRING, age INT, flag BOOLEAN>],\n"
292-
+ " `topic` METADATA VIRTUAL,\n"
293-
+ " `ts1` AS [ts_ltz - INTERVAL '60' MINUTE],\n"
294-
+ " `ts_ltz` METADATA FROM 'timestamp',\n"
295-
+ " `proctime` AS [PROCTIME()],\n"
296-
+ " WATERMARK FOR `ts1` AS [ts1 - INTERVAL '5' SECOND],\n"
297-
+ " CONSTRAINT `primary_constraint` PRIMARY KEY (`id`) NOT ENFORCED\n"
298-
+ ")"));
299302
}
300303

301304
@Test
@@ -315,22 +318,6 @@ public void testResolvedSchemaString() {
315318
+ " WATERMARK FOR `ts`: TIMESTAMP(3) AS ts - INTERVAL '5' SECOND,\n"
316319
+ " CONSTRAINT `primary_constraint` PRIMARY KEY (`id`) NOT ENFORCED\n"
317320
+ ")"));
318-
319-
final ResolvedSchema resolvedSchemaWithTsLtz = resolveSchema(SCHEMA_WITH_TS_LTZ);
320-
assertThat(
321-
resolvedSchemaWithTsLtz.toString(),
322-
equalTo(
323-
"(\n"
324-
+ " `id` INT NOT NULL,\n"
325-
+ " `counter` INT NOT NULL,\n"
326-
+ " `payload` ROW<`name` STRING, `age` INT, `flag` BOOLEAN>,\n"
327-
+ " `topic` STRING METADATA VIRTUAL,\n"
328-
+ " `ts1` TIMESTAMP_LTZ(3) *ROWTIME* AS ts_ltz - INTERVAL '60' MINUTE,\n"
329-
+ " `ts_ltz` TIMESTAMP_LTZ(3) METADATA FROM 'timestamp',\n"
330-
+ " `proctime` TIMESTAMP_LTZ(3) NOT NULL *PROCTIME* AS PROCTIME(),\n"
331-
+ " WATERMARK FOR `ts1`: TIMESTAMP_LTZ(3) AS ts1 - INTERVAL '5' SECOND,\n"
332-
+ " CONSTRAINT `primary_constraint` PRIMARY KEY (`id`) NOT ENFORCED\n"
333-
+ ")"));
334321
}
335322

336323
@Test

flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/TypeInferenceUtil.java

Lines changed: 34 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -216,63 +216,50 @@ public static TableException createUnexpectedException(
216216
*
217217
* @see CallContext#getOutputDataType()
218218
*/
219-
public static final class SurroundingInfo {
219+
public interface SurroundingInfo {
220220

221-
private final String name;
222-
223-
private final FunctionDefinition functionDefinition;
224-
225-
private final TypeInference typeInference;
226-
227-
private final int argumentCount;
228-
229-
private final int innerCallPosition;
230-
231-
private final boolean isGroupedAggregation;
232-
233-
public SurroundingInfo(
221+
static SurroundingInfo of(
234222
String name,
235223
FunctionDefinition functionDefinition,
236224
TypeInference typeInference,
237225
int argumentCount,
238226
int innerCallPosition,
239227
boolean isGroupedAggregation) {
240-
this.name = name;
241-
this.functionDefinition = functionDefinition;
242-
this.typeInference = typeInference;
243-
this.argumentCount = argumentCount;
244-
this.innerCallPosition = innerCallPosition;
245-
this.isGroupedAggregation = isGroupedAggregation;
228+
return typeFactory -> {
229+
final boolean isValidCount =
230+
validateArgumentCount(
231+
typeInference.getInputTypeStrategy().getArgumentCount(),
232+
argumentCount,
233+
false);
234+
if (!isValidCount) {
235+
return Optional.empty();
236+
}
237+
// for "takes_string(this_function(NULL))" simulate "takes_string(NULL)"
238+
// for retrieving the output type of "this_function(NULL)"
239+
final CallContext callContext =
240+
new UnknownCallContext(
241+
typeFactory,
242+
name,
243+
functionDefinition,
244+
argumentCount,
245+
isGroupedAggregation);
246+
247+
// We might not be able to infer the input types at this moment, if the surrounding
248+
// function does not provide an explicit input type strategy.
249+
final CallContext adaptedContext =
250+
adaptArguments(typeInference, callContext, null, false);
251+
return typeInference
252+
.getInputTypeStrategy()
253+
.inferInputTypes(adaptedContext, false)
254+
.map(dataTypes -> dataTypes.get(innerCallPosition));
255+
};
246256
}
247257

248-
private Optional<DataType> inferOutputType(DataTypeFactory typeFactory) {
249-
final boolean isValidCount =
250-
validateArgumentCount(
251-
typeInference.getInputTypeStrategy().getArgumentCount(),
252-
argumentCount,
253-
false);
254-
if (!isValidCount) {
255-
return Optional.empty();
256-
}
257-
// for "takes_string(this_function(NULL))" simulate "takes_string(NULL)"
258-
// for retrieving the output type of "this_function(NULL)"
259-
final CallContext callContext =
260-
new UnknownCallContext(
261-
typeFactory,
262-
name,
263-
functionDefinition,
264-
argumentCount,
265-
isGroupedAggregation);
266-
267-
// We might not be able to infer the input types at this moment, if the surrounding
268-
// function does not provide an explicit input type strategy.
269-
final CallContext adaptedContext =
270-
adaptArguments(typeInference, callContext, null, false);
271-
return typeInference
272-
.getInputTypeStrategy()
273-
.inferInputTypes(adaptedContext, false)
274-
.map(dataTypes -> dataTypes.get(innerCallPosition));
258+
static SurroundingInfo of(DataType dataType) {
259+
return typeFactory -> Optional.of(dataType);
275260
}
261+
262+
Optional<DataType> inferOutputType(DataTypeFactory typeFactory);
276263
}
277264

278265
/**

flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/inference/InputTypeStrategiesTestBase.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,7 @@ private TypeInferenceUtil.Result runTypeInference(List<DataType> actualArgumentT
117117
.outputTypeStrategy(TypeStrategies.MISSING)
118118
.build();
119119
surroundingInfo =
120-
new TypeInferenceUtil.SurroundingInfo(
120+
TypeInferenceUtil.SurroundingInfo.of(
121121
"f_outer",
122122
functionDefinitionMock,
123123
outerTypeInference,

0 commit comments

Comments
 (0)