Skip to content

Commit 9546f5c

Browse files
authored
branch-4.0: [feature](nereids) Support dereference expression #57532 (#58546)
picked from #57532
1 parent 31557ff commit 9546f5c

File tree

8 files changed

+360
-27
lines changed

8 files changed

+360
-27
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@
522522
import org.apache.doris.nereids.trees.expressions.CaseWhen;
523523
import org.apache.doris.nereids.trees.expressions.Cast;
524524
import org.apache.doris.nereids.trees.expressions.DefaultValueSlot;
525+
import org.apache.doris.nereids.trees.expressions.DereferenceExpression;
525526
import org.apache.doris.nereids.trees.expressions.Divide;
526527
import org.apache.doris.nereids.trees.expressions.EqualTo;
527528
import org.apache.doris.nereids.trees.expressions.Exists;
@@ -3418,8 +3419,7 @@ public Expression visitDereference(DereferenceContext ctx) {
34183419
UnboundSlot slot = new UnboundSlot(nameParts, Optional.empty());
34193420
return slot;
34203421
} else {
3421-
// todo: base is an expression, may be not a table name.
3422-
throw new ParseException("Unsupported dereference expression: " + ctx.getText(), ctx);
3422+
return new DereferenceExpression(e, new StringLiteral(ctx.identifier().getText()));
34233423
}
34243424
});
34253425
}

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -469,12 +469,12 @@ private LogicalHaving<Plan> bindHavingAggregate(
469469
Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots.build());
470470

471471
return (analyzer, unboundSlot) -> {
472-
List<Slot> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
472+
List<Expression> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
473473
if (!boundInGroupBy.isEmpty()) {
474474
return ImmutableList.of(boundInGroupBy.get(0));
475475
}
476476

477-
List<Slot> boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOutputScope);
477+
List<Expression> boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOutputScope);
478478
if (!boundInAggOutput.isEmpty()) {
479479
return ImmutableList.of(boundInAggOutput.get(0));
480480
}
@@ -553,7 +553,7 @@ private LogicalHaving<Plan> bindHavingByScopes(
553553
SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer(
554554
having, cascadesContext, defaultScope, false, true,
555555
(self, unboundSlot) -> {
556-
List<Slot> slots = self.bindSlotByScope(unboundSlot, defaultScope);
556+
List<Expression> slots = self.bindSlotByScope(unboundSlot, defaultScope);
557557
if (!slots.isEmpty()) {
558558
return slots;
559559
}
@@ -1006,7 +1006,7 @@ private void bindQualifyByProject(LogicalProject<? extends Plan> project, Cascad
10061006
SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer(
10071007
qualify, cascadesContext, defaultScope.get(), true, true,
10081008
(self, unboundSlot) -> {
1009-
List<Slot> slots = self.bindSlotByScope(unboundSlot, defaultScope.get());
1009+
List<Expression> slots = self.bindSlotByScope(unboundSlot, defaultScope.get());
10101010
if (!slots.isEmpty()) {
10111011
return slots;
10121012
}
@@ -1044,11 +1044,11 @@ private void bindQualifyByAggregate(Aggregate<? extends Plan> aggregate, Cascade
10441044
Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots.build());
10451045

10461046
return (analyzer, unboundSlot) -> {
1047-
List<Slot> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
1047+
List<Expression> boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope);
10481048
if (!boundInGroupBy.isEmpty()) {
10491049
return ImmutableList.of(boundInGroupBy.get(0));
10501050
}
1051-
List<Slot> boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOutputScope);
1051+
List<Expression> boundInAggOutput = analyzer.bindSlotByScope(unboundSlot, aggOutputScope);
10521052
if (!boundInAggOutput.isEmpty()) {
10531053
return ImmutableList.of(boundInAggOutput.get(0));
10541054
}
@@ -1368,15 +1368,15 @@ private List<Expression> bindGroupBy(
13681368
// see: https://github.com/apache/doris/pull/15240
13691369
//
13701370
// first, try to bind by agg.child.output
1371-
List<Slot> slotsInChildren = self.bindExactSlotsByThisScope(unboundSlot, childOutputScope);
1371+
List<Expression> slotsInChildren = self.bindExactSlotsByThisScope(unboundSlot, childOutputScope);
13721372
if (slotsInChildren.size() == 1) {
13731373
// bind succeed
13741374
return slotsInChildren;
13751375
}
13761376
// second, bind failed:
13771377
// if the slot not found, or more than one candidate slots found in agg.child.output,
13781378
// then try to bind by agg.output
1379-
List<Slot> slotsInOutput = self.bindExactSlotsByThisScope(
1379+
List<Expression> slotsInOutput = self.bindExactSlotsByThisScope(
13801380
unboundSlot, aggOutputScopeWithoutAggFun.get());
13811381
if (slotsInOutput.isEmpty()) {
13821382
// if slotsInChildren.size() > 1 && slotsInOutput.isEmpty(),
@@ -1385,7 +1385,7 @@ private List<Expression> bindGroupBy(
13851385
}
13861386

13871387
Builder<Expression> useOutputExpr = ImmutableList.builderWithExpectedSize(slotsInOutput.size());
1388-
for (Slot slotInOutput : slotsInOutput) {
1388+
for (Expression slotInOutput : slotsInOutput) {
13891389
// mappingSlot is provided by aggOutputScopeWithoutAggFun
13901390
// and no non-MappingSlot slot exist in the Scope, so we
13911391
// can direct cast it safely
@@ -1476,7 +1476,7 @@ private Plan bindSortWithoutSetOperation(MatchingContext<LogicalSort<Plan>> ctx)
14761476
sort, cascadesContext, inputScope, true, false,
14771477
(self, unboundSlot) -> {
14781478
// first, try to bind slot in Scope(input.output)
1479-
List<Slot> slotsInInput = self.bindExactSlotsByThisScope(unboundSlot, inputScope);
1479+
List<Expression> slotsInInput = self.bindExactSlotsByThisScope(unboundSlot, inputScope);
14801480
if (!slotsInInput.isEmpty()) {
14811481
// bind succeed
14821482
return ImmutableList.of(slotsInInput.get(0));
@@ -1678,7 +1678,7 @@ private SimpleExprAnalyzer getAnalyzerForOrderByAggFunc(Plan finalInput, Cascade
16781678
sort, cascadesContext, inputScope, true, false,
16791679
(analyzer, unboundSlot) -> {
16801680
if (finalInput instanceof LogicalAggregate) {
1681-
List<Slot> boundInOutputWithoutAggFunc = analyzer.bindSlotByScope(unboundSlot,
1681+
List<Expression> boundInOutputWithoutAggFunc = analyzer.bindSlotByScope(unboundSlot,
16821682
outputWithoutAggFunc);
16831683
if (!boundInOutputWithoutAggFunc.isEmpty()) {
16841684
return ImmutableList.of(boundInOutputWithoutAggFunc.get(0));

fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java

Lines changed: 142 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import org.apache.doris.nereids.trees.expressions.CaseWhen;
5050
import org.apache.doris.nereids.trees.expressions.Cast;
5151
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
52+
import org.apache.doris.nereids.trees.expressions.DereferenceExpression;
5253
import org.apache.doris.nereids.trees.expressions.Divide;
5354
import org.apache.doris.nereids.trees.expressions.EqualTo;
5455
import org.apache.doris.nereids.trees.expressions.ExprId;
@@ -74,7 +75,9 @@
7475
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
7576
import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
7677
import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
78+
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
7779
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
80+
import org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
7881
import org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdfBuilder;
7982
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
8083
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf;
@@ -93,6 +96,8 @@
9396
import org.apache.doris.nereids.types.BooleanType;
9497
import org.apache.doris.nereids.types.DataType;
9598
import org.apache.doris.nereids.types.StringType;
99+
import org.apache.doris.nereids.types.StructField;
100+
import org.apache.doris.nereids.types.StructType;
96101
import org.apache.doris.nereids.types.TinyIntType;
97102
import org.apache.doris.nereids.util.ExpressionUtils;
98103
import org.apache.doris.nereids.util.TypeCoercionUtils;
@@ -240,6 +245,25 @@ public Expression visitUnboundAlias(UnboundAlias unboundAlias, ExpressionRewrite
240245
}
241246
}
242247

248+
@Override
249+
public Expression visitDereferenceExpression(DereferenceExpression dereferenceExpression,
250+
ExpressionRewriteContext context) {
251+
Expression expression = dereferenceExpression.child(0).accept(this, context);
252+
DataType dataType = expression.getDataType();
253+
if (dataType.isStructType()) {
254+
StructType structType = (StructType) dataType;
255+
StructField field = structType.getField(dereferenceExpression.fieldName);
256+
if (field != null) {
257+
return new StructElement(expression, dereferenceExpression.child(1));
258+
}
259+
} else if (dataType.isMapType()) {
260+
return new ElementAt(expression, dereferenceExpression.child(1));
261+
} else if (dataType.isVariantType()) {
262+
return new ElementAt(expression, dereferenceExpression.child(1));
263+
}
264+
throw new AnalysisException("Can not dereference field: " + dereferenceExpression.fieldName);
265+
}
266+
243267
@Override
244268
public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) {
245269
Optional<Scope> outerScope = getScope().getOuterScope();
@@ -913,13 +937,13 @@ protected List<? extends Expression> bindSlotByThisScope(UnboundSlot unboundSlot
913937
return bindSlotByScope(unboundSlot, getScope());
914938
}
915939

916-
protected List<Slot> bindExactSlotsByThisScope(UnboundSlot unboundSlot, Scope scope) {
917-
List<Slot> candidates = bindSlotByScope(unboundSlot, scope);
940+
protected List<Expression> bindExactSlotsByThisScope(UnboundSlot unboundSlot, Scope scope) {
941+
List<Expression> candidates = bindSlotByScope(unboundSlot, scope);
918942
if (candidates.size() == 1) {
919943
return candidates;
920944
}
921-
List<Slot> extractSlots = Utils.filterImmutableList(candidates, bound ->
922-
unboundSlot.getNameParts().size() == bound.getQualifier().size() + 1
945+
List<Expression> extractSlots = Utils.filterImmutableList(candidates, bound ->
946+
bound instanceof Slot && unboundSlot.getNameParts().size() == ((Slot) bound).getQualifier().size() + 1
923947
);
924948
// we should return origin candidates slots if extract slots is empty,
925949
// and then throw an ambiguous exception
@@ -938,33 +962,137 @@ private List<Slot> addSqlIndexInfo(List<Slot> slots, Optional<Pair<Integer, Inte
938962
}
939963

940964
/** bindSlotByScope */
941-
public List<Slot> bindSlotByScope(UnboundSlot unboundSlot, Scope scope) {
965+
public List<Expression> bindSlotByScope(UnboundSlot unboundSlot, Scope scope) {
942966
List<String> nameParts = unboundSlot.getNameParts();
943967
Optional<Pair<Integer, Integer>> idxInSql = unboundSlot.getIndexInSqlString();
944968
int namePartSize = nameParts.size();
945969
switch (namePartSize) {
946970
// column
947971
case 1: {
948-
return addSqlIndexInfo(bindSingleSlotByName(nameParts.get(0), scope), idxInSql);
972+
return (List<Expression>) bindExpressionByColumn(unboundSlot, nameParts, idxInSql, scope);
949973
}
950974
// table.column
951975
case 2: {
952-
return addSqlIndexInfo(bindSingleSlotByTable(nameParts.get(0), nameParts.get(1), scope), idxInSql);
976+
return (List<Expression>) bindExpressionByTableColumn(unboundSlot, nameParts, idxInSql, scope);
953977
}
954978
// db.table.column
955979
case 3: {
956-
return addSqlIndexInfo(bindSingleSlotByDb(nameParts.get(0), nameParts.get(1), nameParts.get(2), scope),
957-
idxInSql);
980+
return (List<Expression>) bindExpressionByDbTableColumn(unboundSlot, nameParts, idxInSql, scope);
958981
}
959982
// catalog.db.table.column
960-
case 4: {
961-
return addSqlIndexInfo(bindSingleSlotByCatalog(
962-
nameParts.get(0), nameParts.get(1), nameParts.get(2), nameParts.get(3), scope), idxInSql);
963-
}
964983
default: {
965-
throw new AnalysisException("Not supported name: " + StringUtils.join(nameParts, "."));
984+
return (List<Expression>) bindExpressionByCatalogDbTableColumn(unboundSlot, nameParts, idxInSql, scope);
985+
}
986+
}
987+
}
988+
989+
private List<? extends Expression> bindExpressionByCatalogDbTableColumn(
990+
UnboundSlot unboundSlot, List<String> nameParts, Optional<Pair<Integer, Integer>> idxInSql, Scope scope) {
991+
List<Slot> slots = addSqlIndexInfo(bindSingleSlotByCatalog(
992+
nameParts.get(0), nameParts.get(1), nameParts.get(2), nameParts.get(3), scope), idxInSql);
993+
if (slots.isEmpty()) {
994+
return bindExpressionByDbTableColumn(unboundSlot, nameParts, idxInSql, scope);
995+
} else if (slots.size() > 1) {
996+
return slots;
997+
}
998+
if (nameParts.size() == 4) {
999+
return slots;
1000+
}
1001+
1002+
Optional<Expression> expression = bindNestedFields(
1003+
unboundSlot, slots.get(0), nameParts.subList(4, nameParts.size())
1004+
);
1005+
if (!expression.isPresent()) {
1006+
return slots;
1007+
}
1008+
return ImmutableList.of(expression.get());
1009+
}
1010+
1011+
private List<? extends Expression> bindExpressionByDbTableColumn(
1012+
UnboundSlot unboundSlot, List<String> nameParts, Optional<Pair<Integer, Integer>> idxInSql, Scope scope) {
1013+
List<Slot> slots = addSqlIndexInfo(
1014+
bindSingleSlotByDb(nameParts.get(0), nameParts.get(1), nameParts.get(2), scope), idxInSql);
1015+
if (slots.isEmpty()) {
1016+
return bindExpressionByTableColumn(unboundSlot, nameParts, idxInSql, scope);
1017+
} else if (slots.size() > 1) {
1018+
return slots;
1019+
}
1020+
if (nameParts.size() == 3) {
1021+
return slots;
1022+
}
1023+
1024+
Optional<Expression> expression = bindNestedFields(
1025+
unboundSlot, slots.get(0), nameParts.subList(3, nameParts.size())
1026+
);
1027+
if (!expression.isPresent()) {
1028+
return slots;
1029+
}
1030+
return ImmutableList.of(expression.get());
1031+
}
1032+
1033+
private List<? extends Expression> bindExpressionByTableColumn(
1034+
UnboundSlot unboundSlot, List<String> nameParts, Optional<Pair<Integer, Integer>> idxInSql, Scope scope) {
1035+
List<Slot> slots = addSqlIndexInfo(bindSingleSlotByTable(nameParts.get(0), nameParts.get(1), scope), idxInSql);
1036+
if (slots.isEmpty()) {
1037+
return bindExpressionByColumn(unboundSlot, nameParts, idxInSql, scope);
1038+
} else if (slots.size() > 1) {
1039+
return slots;
1040+
}
1041+
if (nameParts.size() == 2) {
1042+
return slots;
1043+
}
1044+
1045+
Optional<Expression> expression = bindNestedFields(
1046+
unboundSlot, slots.get(0), nameParts.subList(2, nameParts.size())
1047+
);
1048+
if (!expression.isPresent()) {
1049+
return slots;
1050+
}
1051+
return ImmutableList.of(expression.get());
1052+
}
1053+
1054+
private List<? extends Expression> bindExpressionByColumn(
1055+
UnboundSlot unboundSlot, List<String> nameParts, Optional<Pair<Integer, Integer>> idxInSql, Scope scope) {
1056+
List<Slot> slots = addSqlIndexInfo(bindSingleSlotByName(nameParts.get(0), scope), idxInSql);
1057+
if (slots.size() != 1) {
1058+
return slots;
1059+
}
1060+
if (nameParts.size() == 1) {
1061+
return slots;
1062+
}
1063+
Optional<Expression> expression = bindNestedFields(
1064+
unboundSlot, slots.get(0), nameParts.subList(1, nameParts.size())
1065+
);
1066+
if (!expression.isPresent()) {
1067+
return slots;
1068+
}
1069+
return ImmutableList.of(expression.get());
1070+
}
1071+
1072+
private Optional<Expression> bindNestedFields(UnboundSlot unboundSlot, Slot slot, List<String> fieldNames) {
1073+
Expression expression = slot;
1074+
String lastFieldName = slot.getName();
1075+
for (String fieldName : fieldNames) {
1076+
DataType dataType = expression.getDataType();
1077+
if (dataType.isStructType()) {
1078+
StructType structType = (StructType) dataType;
1079+
StructField field = structType.getField(fieldName);
1080+
if (field == null) {
1081+
throw new AnalysisException("No such struct field '" + fieldName + "' in '" + lastFieldName + "'");
1082+
}
1083+
lastFieldName = fieldName;
1084+
expression = new StructElement(expression, new StringLiteral(fieldName));
1085+
continue;
1086+
} else if (dataType.isMapType()) {
1087+
expression = new ElementAt(expression, new StringLiteral(fieldName));
1088+
continue;
1089+
} else if (dataType.isVariantType()) {
1090+
expression = new ElementAt(expression, new StringLiteral(fieldName));
1091+
continue;
9661092
}
1093+
throw new AnalysisException("No such field '" + fieldName + "' in '" + lastFieldName + "'");
9671094
}
1095+
return Optional.of(new Alias(expression));
9681096
}
9691097

9701098
public static boolean sameTableName(String boundSlot, String unboundSlot) {
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/DereferenceExpression.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.trees.expressions;
19+
20+
import org.apache.doris.nereids.analyzer.Unbound;
21+
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
22+
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
23+
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
24+
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
25+
26+
import com.google.common.collect.ImmutableList;
27+
28+
/** DereferenceExpression*/
29+
public class DereferenceExpression extends Expression implements BinaryExpression, PropagateNullable, Unbound {
30+
public final String fieldName;
31+
32+
public DereferenceExpression(Expression expression, StringLiteral fieldName) {
33+
super(ImmutableList.of(expression, fieldName));
34+
this.fieldName = fieldName.getValue();
35+
}
36+
37+
@Override
38+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
39+
return visitor.visitDereferenceExpression(this, context);
40+
}
41+
}

0 commit comments

Comments
 (0)