@@ -25,6 +25,7 @@
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
import org.apache.flink.table.expressions.LookupCallExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.SqlCallExpression;
import org.apache.flink.table.expressions.TableReferenceExpression;
@@ -117,4 +118,9 @@ public T visit(SqlCallExpression sqlCall) {
public T visitNonApiExpression(Expression other) {
return defaultMethod(other);
}

@Override
public T visit(NestedFieldReferenceExpression nestedFieldReference) {
return defaultMethod(nestedFieldReference);
}
}
@@ -22,6 +22,7 @@
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.ResolvedExpressionVisitor;
import org.apache.flink.table.expressions.TableReferenceExpression;
@@ -70,5 +71,10 @@ public T visit(ResolvedExpression other) {
return defaultMethod(other);
}

@Override
public T visit(NestedFieldReferenceExpression nestedFieldReference) {
return defaultMethod(nestedFieldReference);
}

protected abstract T defaultMethod(ResolvedExpression expression);
}
@@ -52,5 +52,10 @@ public T visit(Expression other) {
return defaultMethod(other);
}

@Override
public T visit(NestedFieldReferenceExpression nestedFieldReference) {
return defaultMethod(nestedFieldReference);
}

protected abstract T defaultMethod(Expression expression);
}
@@ -48,4 +48,8 @@ public interface ExpressionVisitor<R> {
// --------------------------------------------------------------------------------------------

R visit(Expression other);

default R visit(NestedFieldReferenceExpression nestedFieldReference) {
throw new UnsupportedOperationException("NestedFieldReferenceExpression is not supported.");
}
}
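
The interface default throws, so a visitor that needs to handle nested references must override the new method explicitly. Below is a minimal sketch of such a visitor, implementing ExpressionVisitor directly to collect every nested field path in an expression tree; the collector class itself is illustrative and not part of this change.

import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionVisitor;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.TypeLiteralExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;

import java.util.ArrayList;
import java.util.List;

/** Illustrative visitor that collects nested field paths referenced in an expression tree. */
public class NestedFieldPathCollector implements ExpressionVisitor<Void> {

    private final List<String> paths = new ArrayList<>();

    @Override
    public Void visit(NestedFieldReferenceExpression nestedFieldReference) {
        // Without this override the interface default above would throw.
        paths.add(nestedFieldReference.getName());
        return null;
    }

    @Override
    public Void visit(CallExpression call) {
        // Recurse into arguments so references nested inside calls are found.
        call.getChildren().forEach(child -> child.accept(this));
        return null;
    }

    @Override
    public Void visit(ValueLiteralExpression valueLiteral) {
        return null;
    }

    @Override
    public Void visit(FieldReferenceExpression fieldReference) {
        return null;
    }

    @Override
    public Void visit(TypeLiteralExpression typeLiteral) {
        return null;
    }

    @Override
    public Void visit(Expression other) {
        other.getChildren().forEach(child -> child.accept(this));
        return null;
    }

    public List<String> getPaths() {
        return paths;
    }
}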
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.expressions;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.table.types.DataType;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
* A reference to a nested field in an input. The reference contains:
*
* <ul>
* <li>nested field names to traverse from the top level column to the nested leaf column.
* <li>nested field indices to traverse from the top level column to the nested leaf column.
* <li>type
* </ul>
*/
@PublicEvolving
public class NestedFieldReferenceExpression implements ResolvedExpression {

/** Nested field names to traverse from the top level column to the nested leaf column. */
private final String[] fieldNames;

    /** Nested field indices to traverse from the top level column to the nested leaf column. */
private final int[] fieldIndices;

private final DataType dataType;

public NestedFieldReferenceExpression(
String[] fieldNames, int[] fieldIndices, DataType dataType) {
this.fieldNames = fieldNames;
this.fieldIndices = fieldIndices;
this.dataType = dataType;
}

public String[] getFieldNames() {
return fieldNames;
}

public int[] getFieldIndices() {
return fieldIndices;
}

public String getName() {
return String.format(
"`%s`",
String.join(
".",
Arrays.stream(fieldNames)
.map(this::quoteIdentifier)
.toArray(String[]::new)));
}

@Override
public DataType getOutputDataType() {
return dataType;
}

@Override
public List<ResolvedExpression> getResolvedChildren() {
return Collections.emptyList();
}

@Override
public String asSummaryString() {
return getName();
}

@Override
public List<Expression> getChildren() {
return Collections.emptyList();
}

@Override
public <R> R accept(ExpressionVisitor<R> visitor) {
return visitor.visit(this);
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
NestedFieldReferenceExpression that = (NestedFieldReferenceExpression) o;
return Arrays.equals(fieldNames, that.fieldNames)
&& Arrays.equals(fieldIndices, that.fieldIndices)
&& dataType.equals(that.dataType);
}

@Override
public int hashCode() {
return Objects.hash(Arrays.hashCode(fieldNames), Arrays.hashCode(fieldIndices), dataType);
}

@Override
public String toString() {
return asSummaryString();
}

private String quoteIdentifier(String identifier) {
return identifier.replace("`", "``");
}
}
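
A construction sketch for the new expression, assuming a source whose first column is user ROW&lt;address ROW&lt;city STRING, zip INT&gt;&gt;; the schema and class name are made up for illustration.

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;

public class NestedFieldReferenceExample {

    public static void main(String[] args) {
        // Hypothetical schema: column 0 is user ROW<address ROW<city STRING, zip INT>>.
        // Both the name path and the index path run from the top-level column to the leaf.
        NestedFieldReferenceExpression cityRef =
                new NestedFieldReferenceExpression(
                        new String[] {"user", "address", "city"},
                        new int[] {0, 0, 0},
                        DataTypes.STRING());

        // Prints the quoted traversal path: `user.address.city`
        System.out.println(cityRef.asSummaryString());
    }
}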
@@ -26,6 +26,7 @@
import org.apache.flink.table.expressions.ExpressionVisitor;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.TimeIntervalUnit;
import org.apache.flink.table.expressions.TimePointUnit;
import org.apache.flink.table.expressions.TypeLiteralExpression;
@@ -202,6 +203,17 @@ public RexNode visit(FieldReferenceExpression fieldReference) {
return relBuilder.field(fieldReference.getName());
}

@Override
public RexNode visit(NestedFieldReferenceExpression nestedFieldReference) {
String[] fieldNames = nestedFieldReference.getFieldNames();
RexNode fieldAccess = relBuilder.field(fieldNames[0]);
for (int i = 1; i < fieldNames.length; i++) {
fieldAccess =
relBuilder.getRexBuilder().makeFieldAccess(fieldAccess, fieldNames[i], true);
}
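        // For a name path like ["user", "address", "city"] this chains field accesses
        // on top of the input reference for `user`, mirroring the SQL expression
        // user.address.city.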
return fieldAccess;
}

@Override
public RexNode visit(TypeLiteralExpression typeLiteral) {
throw new UnsupportedOperationException();
@@ -45,6 +45,8 @@
import java.util.TimeZone;
import java.util.stream.Collectors;

import scala.Option;

import static org.apache.flink.util.Preconditions.checkNotNull;

/**
@@ -104,7 +106,10 @@ public static SupportsFilterPushDown.Result apply(
context.getFunctionCatalog(),
context.getCatalogManager(),
TimeZone.getTimeZone(
TableConfigUtils.getLocalTimeZone(context.getTableConfig())));
TableConfigUtils.getLocalTimeZone(context.getTableConfig())),
Option.apply(
context.getTypeFactory()
.buildRelNodeRowType(context.getSourceRowType())));
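        // The source row type passed above lets the converter resolve RexFieldAccess
        // chains into NestedFieldReferenceExpressions.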
List<Expression> filters =
predicates.stream()
.map(
@@ -644,7 +644,8 @@ object FlinkRexUtil {
inputNames,
context.getFunctionCatalog,
context.getCatalogManager,
TimeZone.getTimeZone(TableConfigUtils.getLocalTimeZone(context.getTableConfig)));
TimeZone.getTimeZone(TableConfigUtils.getLocalTimeZone(context.getTableConfig)),
Some(rel.getRowType));

RexNodeExtractor.extractConjunctiveConditions(
filterExpression,
@@ -34,20 +34,23 @@ import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLog
import org.apache.flink.table.types.DataType
import org.apache.flink.table.types.logical.LogicalTypeRoot._
import org.apache.flink.table.types.logical.YearMonthIntervalType
import org.apache.flink.table.types.utils.TypeConversions
import org.apache.flink.util.Preconditions

import org.apache.calcite.plan.RelOptUtil
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rex._
import org.apache.calcite.sql.{SqlFunction, SqlKind, SqlPostfixOperator}
import org.apache.calcite.sql.fun.{SqlStdOperatorTable, SqlTrimFunction}
import org.apache.calcite.util.{TimestampString, Util}

import java.util
import java.util.{List => JList, TimeZone}
import java.util.{Collections, List => JList, TimeZone}

import scala.collection.{mutable, JavaConverters}
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}

object RexNodeExtractor extends Logging {
@@ -395,9 +398,19 @@ class RexNodeToExpressionConverter(
inputNames: Array[String],
functionCatalog: FunctionCatalog,
catalogManager: CatalogManager,
timeZone: TimeZone)
timeZone: TimeZone,
relDataType: Option[RelDataType] = None)
extends RexVisitor[Option[ResolvedExpression]] {

def this(
rexBuilder: RexBuilder,
inputNames: Array[String],
functionCatalog: FunctionCatalog,
catalogManager: CatalogManager,
timeZone: TimeZone) = {
this(rexBuilder, inputNames, functionCatalog, catalogManager, timeZone, None)
}

override def visitInputRef(inputRef: RexInputRef): Option[ResolvedExpression] = {
Preconditions.checkArgument(inputRef.getIndex < inputNames.length)
Some(
@@ -538,8 +551,35 @@
}
}

override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[ResolvedExpression] = None
override def visitFieldAccess(fieldAccess: RexFieldAccess): Option[ResolvedExpression] = {
fieldAccess.getReferenceExpr match {
      // push down of a nested field inside a composite type like map or array is not supported
case _: RexCall => return None
case _ => // do nothing
}

relDataType match {
case Some(dataType) =>
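        // Resolve the field access against the source row type: NestedProjectionUtil
        // yields the index path, and walking the nested schema below collects the
        // matching name path from the top-level column down to the leaf.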
val schema = NestedProjectionUtil.build(Collections.singletonList(fieldAccess), dataType)
val fieldIndices = NestedProjectionUtil.convertToIndexArray(schema)
var (topLevelColumnName, nestedColumn) = schema.columns.head
val fieldNames = new ArrayBuffer[String]()

while (!nestedColumn.isLeaf) {
fieldNames.add(topLevelColumnName)
topLevelColumnName = nestedColumn.children.head._1
nestedColumn = nestedColumn.children.head._2
}
fieldNames.add(topLevelColumnName)

Some(
new NestedFieldReferenceExpression(
fieldNames.toArray,
fieldIndices(0),
TypeConversions.fromLogicalToDataType(
FlinkTypeFactory.toLogicalType(fieldAccess.getType))))
      case None => None
    }
  }

override def visitCorrelVariable(correlVariable: RexCorrelVariable): Option[ResolvedExpression] =
None

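
With visitFieldAccess now producing NestedFieldReferenceExpression, nested predicates reach connectors through the regular filter push-down path. A hedged sketch of the consuming side follows; the source class and the accepted predicate shape are illustrative, not part of this change.

import org.apache.flink.table.connector.source.abilities.SupportsFilterPushDown;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.NestedFieldReferenceExpression;
import org.apache.flink.table.expressions.ResolvedExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
import org.apache.flink.table.functions.BuiltInFunctionDefinitions;

import java.util.ArrayList;
import java.util.List;

/** Sketch of a source that accepts equality filters on nested fields. */
public class NestedFilterAwareSource implements SupportsFilterPushDown {

    @Override
    public Result applyFilters(List<ResolvedExpression> filters) {
        List<ResolvedExpression> accepted = new ArrayList<>();
        List<ResolvedExpression> remaining = new ArrayList<>();
        for (ResolvedExpression filter : filters) {
            // Keep `nested.field = literal` predicates, hand everything else back.
            if (isNestedFieldEquality(filter)) {
                accepted.add(filter);
            } else {
                remaining.add(filter);
            }
        }
        return Result.of(accepted, remaining);
    }

    private static boolean isNestedFieldEquality(ResolvedExpression filter) {
        if (!(filter instanceof CallExpression)) {
            return false;
        }
        CallExpression call = (CallExpression) filter;
        List<ResolvedExpression> args = call.getResolvedChildren();
        return call.getFunctionDefinition() == BuiltInFunctionDefinitions.EQUALS
                && args.size() == 2
                && args.get(0) instanceof NestedFieldReferenceExpression
                && args.get(1) instanceof ValueLiteralExpression;
    }
}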
@@ -92,7 +92,7 @@ public List<CatalogPartitionSpec> listPartitionsByFilter(
Function<String, Comparable<?>> getter =
getValueGetter(partition.getPartitionSpec(), schema);
return FilterUtils.isRetainedAfterApplyingFilterPredicates(
resolvedExpressions, getter);
resolvedExpressions, getter, null);
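                                // null: no nested-value getter; partition specs are flat,
                                // so nested field predicates are not evaluated here.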
})
.collect(Collectors.toList());
}
@@ -1007,6 +1007,17 @@ private Function<String, Comparable<?>> getValueGetter(Row row) {
};
}

private Function<int[], Comparable<?>> getNestedValueGetter(Row row) {
return fieldIndices -> {
Object current = row;
for (int i = 0; i < fieldIndices.length - 1; i++) {
current = ((Row) current).getField(fieldIndices[i]);
}
return (Comparable<?>)
((Row) current).getField(fieldIndices[fieldIndices.length - 1]);
};
}

@Override
public DynamicTableSource copy() {
return new TestValuesScanTableSourceWithoutProjectionPushDown(
@@ -1183,7 +1194,9 @@ private Map<Map<String, String>, Collection<Row>> filterAllData(
for (Row row : allData.get(partition)) {
boolean isRetained =
FilterUtils.isRetainedAfterApplyingFilterPredicates(
filterPredicates, getValueGetter(row));
filterPredicates,
getValueGetter(row),
getNestedValueGetter(row));
if (isRetained) {
remainData.add(row);
}
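
The test source evaluates nested predicates with an index-path getter over nested Rows. Below is a standalone sketch of the same traversal with a made-up row layout.

import org.apache.flink.types.Row;

import java.util.function.Function;

public class NestedValueGetterExample {

    // Same traversal as the test source's getNestedValueGetter: follow the index
    // path through nested rows and return the leaf value as a Comparable.
    static Function<int[], Comparable<?>> nestedValueGetter(Row row) {
        return fieldIndices -> {
            Object current = row;
            for (int i = 0; i < fieldIndices.length - 1; i++) {
                current = ((Row) current).getField(fieldIndices[i]);
            }
            return (Comparable<?>) ((Row) current).getField(fieldIndices[fieldIndices.length - 1]);
        };
    }

    public static void main(String[] args) {
        // Hypothetical row: f0 INT, f1 ROW<a STRING, b ROW<c INT>>
        Row row = Row.of(7, Row.of("x", Row.of(42)));

        // Index path [1, 1, 0] addresses row.f1.b.c and prints 42.
        System.out.println(nestedValueGetter(row).apply(new int[] {1, 1, 0}));
    }
}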