diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java index ce499a90446e5..9797f53a331c2 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ApiExpressionDefaultVisitor.java @@ -25,6 +25,7 @@ import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.LocalReferenceExpression; import org.apache.flink.table.expressions.LookupCallExpression; +import org.apache.flink.table.expressions.NestedFieldReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.expressions.SqlCallExpression; import org.apache.flink.table.expressions.TableReferenceExpression; @@ -117,4 +118,9 @@ public T visit(SqlCallExpression sqlCall) { public T visitNonApiExpression(Expression other) { return defaultMethod(other); } + + @Override + public T visit(NestedFieldReferenceExpression nestedFieldReference) { + return defaultMethod(nestedFieldReference); + } } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java index 1370175cb8e73..3bf93880d7d56 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/expressions/utils/ResolvedExpressionDefaultVisitor.java @@ -22,6 +22,7 @@ import org.apache.flink.table.expressions.CallExpression; import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.LocalReferenceExpression; +import org.apache.flink.table.expressions.NestedFieldReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.expressions.ResolvedExpressionVisitor; import org.apache.flink.table.expressions.TableReferenceExpression; @@ -70,5 +71,10 @@ public T visit(ResolvedExpression other) { return defaultMethod(other); } + @Override + public T visit(NestedFieldReferenceExpression nestedFieldReference) { + return defaultMethod(nestedFieldReference); + } + protected abstract T defaultMethod(ResolvedExpression expression); } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionDefaultVisitor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionDefaultVisitor.java index 95af8c185028a..cc38f09258271 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionDefaultVisitor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionDefaultVisitor.java @@ -52,5 +52,10 @@ public T visit(Expression other) { return defaultMethod(other); } + @Override + public T visit(NestedFieldReferenceExpression nestedFieldReference) { + return defaultMethod(nestedFieldReference); + } + protected abstract T defaultMethod(Expression expression); } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionVisitor.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionVisitor.java index 41e9b852b24fa..62f9ff7bbca83 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionVisitor.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/ExpressionVisitor.java @@ -48,4 +48,8 @@ public interface ExpressionVisitor { // -------------------------------------------------------------------------------------------- R visit(Expression other); + + default R visit(NestedFieldReferenceExpression nestedFieldReference) { + throw new UnsupportedOperationException("NestedFieldReferenceExpression is not supported."); + } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/NestedFieldReferenceExpression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/NestedFieldReferenceExpression.java new file mode 100644 index 0000000000000..70575ddf2a01d --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/NestedFieldReferenceExpression.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.expressions; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.table.types.DataType; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * A reference to a nested field in an input. The reference contains: + * + * + */ +@PublicEvolving +public class NestedFieldReferenceExpression implements ResolvedExpression { + + /** Nested field names to traverse from the top level column to the nested leaf column. */ + private final String[] fieldNames; + + /** Nested field index to traverse from the top level column to the nested leaf column. */ + private final int[] fieldIndices; + + private final DataType dataType; + + public NestedFieldReferenceExpression( + String[] fieldNames, int[] fieldIndices, DataType dataType) { + this.fieldNames = fieldNames; + this.fieldIndices = fieldIndices; + this.dataType = dataType; + } + + public String[] getFieldNames() { + return fieldNames; + } + + public int[] getFieldIndices() { + return fieldIndices; + } + + public String getName() { + return String.format( + "`%s`", + String.join( + ".", + Arrays.stream(fieldNames) + .map(this::quoteIdentifier) + .toArray(String[]::new))); + } + + @Override + public DataType getOutputDataType() { + return dataType; + } + + @Override + public List getResolvedChildren() { + return Collections.emptyList(); + } + + @Override + public String asSummaryString() { + return getName(); + } + + @Override + public List getChildren() { + return Collections.emptyList(); + } + + @Override + public R accept(ExpressionVisitor visitor) { + return visitor.visit(this); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + NestedFieldReferenceExpression that = (NestedFieldReferenceExpression) o; + return Arrays.equals(fieldNames, that.fieldNames) + && Arrays.equals(fieldIndices, that.fieldIndices) + && dataType.equals(that.dataType); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(fieldNames), Arrays.hashCode(fieldIndices), dataType); + } + + @Override + public String toString() { + return asSummaryString(); + } + + private String quoteIdentifier(String identifier) { + return identifier.replace("`", "``"); + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java index ab5e0cf09c3a4..66c90e344edfd 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/converter/ExpressionConverter.java @@ -26,6 +26,7 @@ import org.apache.flink.table.expressions.ExpressionVisitor; import org.apache.flink.table.expressions.FieldReferenceExpression; import org.apache.flink.table.expressions.LocalReferenceExpression; +import org.apache.flink.table.expressions.NestedFieldReferenceExpression; import org.apache.flink.table.expressions.TimeIntervalUnit; import org.apache.flink.table.expressions.TimePointUnit; import org.apache.flink.table.expressions.TypeLiteralExpression; @@ -202,6 +203,17 @@ public RexNode visit(FieldReferenceExpression fieldReference) { return relBuilder.field(fieldReference.getName()); } + @Override + public RexNode visit(NestedFieldReferenceExpression nestedFieldReference) { + String[] fieldNames = nestedFieldReference.getFieldNames(); + RexNode fieldAccess = relBuilder.field(fieldNames[0]); + for (int i = 1; i < fieldNames.length; i++) { + fieldAccess = + relBuilder.getRexBuilder().makeFieldAccess(fieldAccess, fieldNames[i], true); + } + return fieldAccess; + } + @Override public RexNode visit(TypeLiteralExpression typeLiteral) { throw new UnsupportedOperationException(); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/FilterPushDownSpec.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/FilterPushDownSpec.java index c7bd36ceb8f10..007fa0a4054df 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/FilterPushDownSpec.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/abilities/source/FilterPushDownSpec.java @@ -45,6 +45,8 @@ import java.util.TimeZone; import java.util.stream.Collectors; +import scala.Option; + import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -104,7 +106,10 @@ public static SupportsFilterPushDown.Result apply( context.getFunctionCatalog(), context.getCatalogManager(), TimeZone.getTimeZone( - TableConfigUtils.getLocalTimeZone(context.getTableConfig()))); + TableConfigUtils.getLocalTimeZone(context.getTableConfig())), + Option.apply( + context.getTypeFactory() + .buildRelNodeRowType(context.getSourceRowType()))); List filters = predicates.stream() .map( diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala index 057fa50ed9c47..82590106330bf 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala @@ -644,7 +644,8 @@ object FlinkRexUtil { inputNames, context.getFunctionCatalog, context.getCatalogManager, - TimeZone.getTimeZone(TableConfigUtils.getLocalTimeZone(context.getTableConfig))); + TimeZone.getTimeZone(TableConfigUtils.getLocalTimeZone(context.getTableConfig)), + Some(rel.getRowType)); RexNodeExtractor.extractConjunctiveConditions( filterExpression, diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala index a7cbb4a9ffc01..482ce56dc6389 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala @@ -34,20 +34,23 @@ import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLog import org.apache.flink.table.types.DataType import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.table.types.logical.YearMonthIntervalType +import org.apache.flink.table.types.utils.TypeConversions import org.apache.flink.util.Preconditions import org.apache.calcite.plan.RelOptUtil +import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rex._ import org.apache.calcite.sql.{SqlFunction, SqlKind, SqlPostfixOperator} import org.apache.calcite.sql.fun.{SqlStdOperatorTable, SqlTrimFunction} import org.apache.calcite.util.{TimestampString, Util} import java.util -import java.util.{List => JList, TimeZone} +import java.util.{Collections, List => JList, TimeZone} +import scala.collection.{mutable, JavaConverters} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Success, Try} object RexNodeExtractor extends Logging { @@ -395,9 +398,19 @@ class RexNodeToExpressionConverter( inputNames: Array[String], functionCatalog: FunctionCatalog, catalogManager: CatalogManager, - timeZone: TimeZone) + timeZone: TimeZone, + relDataType: Option[RelDataType] = None) extends RexVisitor[Option[ResolvedExpression]] { + def this( + rexBuilder: RexBuilder, + inputNames: Array[String], + functionCatalog: FunctionCatalog, + catalogManager: CatalogManager, + timeZone: TimeZone) = { + this(rexBuilder, inputNames, functionCatalog, catalogManager, timeZone, None) + } + override def visitInputRef(inputRef: RexInputRef): Option[ResolvedExpression] = { Preconditions.checkArgument(inputRef.getIndex < inputNames.length) Some( @@ -538,8 +551,35 @@ class RexNodeToExpressionConverter( } } - override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[ResolvedExpression] = None + override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[ResolvedExpression] = { + fieldAccess.getReferenceExpr match { + // push down on nested field inside a composite type like map or array is not supported + case _: RexCall => return None + case _ => // do nothing + } + relDataType match { + case Some(dataType) => + val schema = NestedProjectionUtil.build(Collections.singletonList(fieldAccess), dataType) + val fieldIndices = NestedProjectionUtil.convertToIndexArray(schema) + var (topLevelColumnName, nestedColumn) = schema.columns.head + val fieldNames = new ArrayBuffer[String]() + + while (!nestedColumn.isLeaf) { + fieldNames.add(topLevelColumnName) + topLevelColumnName = nestedColumn.children.head._1 + nestedColumn = nestedColumn.children.head._2 + } + fieldNames.add(topLevelColumnName) + + Some( + new NestedFieldReferenceExpression( + fieldNames.toArray, + fieldIndices(0), + TypeConversions.fromLogicalToDataType( + FlinkTypeFactory.toLogicalType(fieldAccess.getType)))) + } + } override def visitCorrelVariable(correlVariable: RexCorrelVariable): Option[ResolvedExpression] = None diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesCatalog.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesCatalog.java index badd3ecb004b0..2fc45dc743a0a 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesCatalog.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesCatalog.java @@ -92,7 +92,7 @@ public List listPartitionsByFilter( Function> getter = getValueGetter(partition.getPartitionSpec(), schema); return FilterUtils.isRetainedAfterApplyingFilterPredicates( - resolvedExpressions, getter); + resolvedExpressions, getter, null); }) .collect(Collectors.toList()); } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java index f739fa558a4e6..8c5f34445c97e 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesTableFactory.java @@ -1007,6 +1007,17 @@ private Function> getValueGetter(Row row) { }; } + private Function> getNestedValueGetter(Row row) { + return fieldIndices -> { + Object current = row; + for (int i = 0; i < fieldIndices.length - 1; i++) { + current = ((Row) current).getField(fieldIndices[i]); + } + return (Comparable) + ((Row) current).getField(fieldIndices[fieldIndices.length - 1]); + }; + } + @Override public DynamicTableSource copy() { return new TestValuesScanTableSourceWithoutProjectionPushDown( @@ -1183,7 +1194,9 @@ private Map, Collection> filterAllData( for (Row row : allData.get(partition)) { boolean isRetained = FilterUtils.isRetainedAfterApplyingFilterPredicates( - filterPredicates, getValueGetter(row)); + filterPredicates, + getValueGetter(row), + getNestedValueGetter(row)); if (isRetained) { remainData.add(row); } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/PushFilterIntoTableSourceScanRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/PushFilterIntoTableSourceScanRuleTest.java index a2c2a8447a7aa..47e03c3beb409 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/PushFilterIntoTableSourceScanRuleTest.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/PushFilterIntoTableSourceScanRuleTest.java @@ -86,6 +86,39 @@ public void setup() { + ")"; util.tableEnv().executeSql(ddl2); + + String ddl3 = + "CREATE TABLE NestedTable (\n" + + " id int,\n" + + " deepNested row, nested2 row>,\n" + + " nested row,\n" + + " `deepNestedWith.` row<`.value` int, nested row<```name` string, `.value` int>>,\n" + + " name string,\n" + + " testMap Map\n" + + ") WITH (\n" + + " 'connector' = 'values',\n" + + " 'filterable-fields' = '`deepNested.nested1.value`;`deepNestedWith..nested..value`;`deepNestedWith..nested.``name`;'," + + " 'bounded' = 'true'\n" + + ")"; + util.tableEnv().executeSql(ddl3); + + String ddl4 = + "CREATE TABLE NestedItemTable (\n" + + " `ID` INT,\n" + + " `Timestamp` TIMESTAMP(3),\n" + + " `Result` ROW<\n" + + " `Mid` ROW<" + + " `data_arr` ROW<`value` BIGINT> ARRAY,\n" + + " `data_map` MAP>" + + " >" + + " >,\n" + + " WATERMARK FOR `Timestamp` AS `Timestamp`\n" + + ") WITH (\n" + + " 'connector' = 'values',\n" + + " 'filterable-fields' = 'Result_Mid_data_map;'," + + " 'bounded' = 'true'\n" + + ")"; + util.tableEnv().executeSql(ddl4); } @Test @@ -118,4 +151,34 @@ public void testWithInterval() { util.tableEnv().executeSql(ddl); super.testWithInterval(); } + + @Test + public void testBasicNestedFilter() { + util.verifyRelPlan("SELECT * FROM NestedTable WHERE deepNested.nested1.`value` > 2"); + } + + @Test + public void testNestedFilterWithDotInTheName() { + util.verifyRelPlan( + "SELECT id FROM NestedTable WHERE `deepNestedWith.`.nested.`.value` > 5"); + } + + @Test + public void testNestedFilterWithBacktickInTheName() { + util.verifyRelPlan( + "SELECT id FROM NestedTable WHERE `deepNestedWith.`.nested.```name` = 'foo'"); + } + + @Test + public void testNestedFilterOnMapKey() { + util.verifyRelPlan( + "SELECT * FROM NestedItemTable WHERE" + + " `Result`.`Mid`.data_map['item'].`value` = 3"); + } + + @Test + public void testNestedFilterOnArrayField() { + util.verifyRelPlan( + "SELECT * FROM NestedItemTable WHERE `Result`.`Mid`.data_arr[2].`value` = 3"); + } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/utils/FilterUtils.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/utils/FilterUtils.java index 0209193266244..32d3afa233a9d 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/utils/FilterUtils.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/utils/FilterUtils.java @@ -22,12 +22,15 @@ import org.apache.flink.table.expressions.CallExpression; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.expressions.FieldReferenceExpression; +import org.apache.flink.table.expressions.NestedFieldReferenceExpression; import org.apache.flink.table.expressions.ResolvedExpression; import org.apache.flink.table.expressions.ValueLiteralExpression; import org.apache.flink.table.functions.BuiltInFunctionDefinitions; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.util.Preconditions; +import javax.annotation.Nullable; + import java.util.List; import java.util.Optional; import java.util.Set; @@ -50,7 +53,9 @@ && shouldPushDownUnaryExpression( } public static boolean isRetainedAfterApplyingFilterPredicates( - List predicates, Function> getter) { + List predicates, + Function> getter, + @Nullable Function> nestedFieldGetter) { for (ResolvedExpression predicate : predicates) { if (predicate instanceof CallExpression) { FunctionDefinition definition = @@ -62,13 +67,17 @@ public static boolean isRetainedAfterApplyingFilterPredicates( if (!(expr instanceof CallExpression && expr.getChildren().size() == 2)) { throw new TableException(expr + " not supported!"); } - result = binaryFilterApplies((CallExpression) expr, getter); + result = + binaryFilterApplies( + (CallExpression) expr, getter, nestedFieldGetter); if (result) { break; } } } else if (predicate.getChildren().size() == 2) { - result = binaryFilterApplies((CallExpression) predicate, getter); + result = + binaryFilterApplies( + (CallExpression) predicate, getter, nestedFieldGetter); } else { throw new UnsupportedOperationException( String.format("Unsupported expr: %s.", predicate)); @@ -96,6 +105,12 @@ private static boolean shouldPushDownUnaryExpression( } } + if (expr instanceof NestedFieldReferenceExpression) { + if (filterableFields.contains(((NestedFieldReferenceExpression) expr).getName())) { + return true; + } + } + if (expr instanceof ValueLiteralExpression) { return true; } @@ -113,12 +128,14 @@ private static boolean shouldPushDownUnaryExpression( @SuppressWarnings({"unchecked", "rawtypes"}) private static boolean binaryFilterApplies( - CallExpression binExpr, Function> getter) { + CallExpression binExpr, + Function> getter, + Function> nestedFieldGetter) { List children = binExpr.getChildren(); Preconditions.checkArgument(children.size() == 2); - Comparable lhsValue = getValue(children.get(0), getter); - Comparable rhsValue = getValue(children.get(1), getter); + Comparable lhsValue = getValue(children.get(0), getter, nestedFieldGetter); + Comparable rhsValue = getValue(children.get(1), getter, nestedFieldGetter); FunctionDefinition functionDefinition = binExpr.getFunctionDefinition(); if (BuiltInFunctionDefinitions.GREATER_THAN.equals(functionDefinition)) { return lhsValue.compareTo(rhsValue) > 0; @@ -141,7 +158,10 @@ private static boolean isComparable(Class clazz) { return Comparable.class.isAssignableFrom(clazz); } - private static Comparable getValue(Expression expr, Function> getter) { + private static Comparable getValue( + Expression expr, + Function> getter, + Function> nestedFieldGetter) { if (expr instanceof ValueLiteralExpression) { Optional value = ((ValueLiteralExpression) expr) @@ -156,8 +176,17 @@ private static Comparable getValue(Expression expr, Function + + + 2]]> + + + ($1.nested1.value, 2)]) + +- LogicalTableScan(table=[[default_catalog, default_database, NestedTable]]) +]]> + + + (deepNested.nested1.value, 2)]]]) +]]> + + + + + 5]]> + + + ($3.nested..value, 5)]) + +- LogicalTableScan(table=[[default_catalog, default_database, NestedTable]]) +]]> + + + (deepNestedWith..nested..value, 5)]]]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/TableSourceITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/TableSourceITCase.scala index e3c987fd6a246..bae6b751ec215 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/TableSourceITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/TableSourceITCase.scala @@ -77,6 +77,7 @@ class TableSourceITCase extends BatchTestBase { |) WITH ( | 'connector' = 'values', | 'nested-projection-supported' = 'true', + | 'filterable-fields' = '`nested.value`;`nestedItem.deepMap`;`nestedItem.deepArray`', | 'data-id' = '$nestedTableDataId', | 'bounded' = 'true' |) @@ -427,4 +428,41 @@ class TableSourceITCase extends BatchTestBase { "3,2,Hello world") assertThat(expected.sorted).isEqualTo(result.sorted) } + + @Test + def testSimpleNestedFilter(): Unit = { + checkResult( + """ + |SELECT id, deepNested.nested1.name AS nestedName FROM NestedTable + | WHERE nested.`value` > 20000 + """.stripMargin, + Seq(row(3, "Mike")) + ) + } + + @Test + def testNestedFilterOnArray(): Unit = { + checkResult( + """ + |SELECT id, + | deepNested.nested1.name AS nestedName, + | nestedItem.deepArray[2].`value` FROM NestedTable + |WHERE nestedItem.deepArray[2].`value` > 1 + """.stripMargin, + Seq(row(1, "Sarah", 2), row(2, "Rob", 2), row(3, "Mike", 2)) + ) + } + + @Test + def testNestedFilterOnMap(): Unit = { + checkResult( + """ + |SELECT id, + | deepNested.nested1.name AS nestedName, + | nestedItem.deepMap['Monday'] FROM NestedTable + |WHERE nestedItem.deepMap['Monday'] = 1 + """.stripMargin, + Seq(row(1, "Sarah", 1), row(2, "Rob", 1), row(3, "Mike", 1)) + ) + } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala index 206ae68d04064..6deb1fd61c7f4 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/TableSourceITCase.scala @@ -101,6 +101,7 @@ class TableSourceITCase extends StreamingTestBase { |) WITH ( | 'connector' = 'values', | 'nested-projection-supported' = 'true', + | 'filterable-fields' = '`nested.value`;`nestedItem.deepMap`;`nestedItem.deepArray`', | 'data-id' = '$nestedTableDataId', | 'bounded' = 'true' |) @@ -376,4 +377,57 @@ class TableSourceITCase extends StreamingTestBase { assertThat(t, containsCause(new TableException(SourceWatermarkFunction.ERROR_MESSAGE))) } } + + @Test + def testSimpleNestedFilter(): Unit = { + val query = + """ + |SELECT id, deepNested.nested1.name AS nestedName FROM NestedTable + | WHERE nested.`value` > 20000 + """.stripMargin + val result = tEnv.sqlQuery(query).toAppendStream[Row] + val sink = new TestingAppendSink + result.addSink(sink) + env.execute() + + val expected = Seq("3,Mike") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + } + + @Test + def testNestedFilterOnArray(): Unit = { + val query = + """ + |SELECT id, + | deepNested.nested1.name AS nestedName, + | nestedItem.deepArray[2].`value` FROM NestedTable + |WHERE nestedItem.deepArray[2].`value` > 1 + """.stripMargin + val result = tEnv.sqlQuery(query).toAppendStream[Row] + val sink = new TestingAppendSink + result.addSink(sink) + env.execute() + + val expected = Seq("1,Sarah,2", "2,Rob,2", "3,Mike,2") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + } + + @Test + def testNestedFilterOnMap(): Unit = { + val query = + """ + |SELECT id, + | deepNested.nested1.name AS nestedName, + | nestedItem.deepMap['Monday'] FROM NestedTable + |WHERE nestedItem.deepMap['Monday'] = 1 + """.stripMargin + + val result = tEnv.sqlQuery(query).toAppendStream[Row] + val sink = new TestingAppendSink + result.addSink(sink) + env.execute() + + val expected = Seq("1,Sarah,1", "2,Rob,1", "3,Mike,1") + assertEquals(expected.sorted, sink.getAppendResults.sorted) + } }