Skip to content

Commit

Permalink
[feat](Nereids) support nereids hint position detection (#39113)
Browse files Browse the repository at this point in the history
When a hint is used in the wrong position, or an unsupported hint is used,
send it to channel(2) so that it is filtered out
  • Loading branch information
LiBinfeng-01 committed Aug 15, 2024
1 parent 2168f09 commit f514a7e
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 46 deletions.
18 changes: 18 additions & 0 deletions fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,19 @@ lexer grammar DorisLexer;
public void markUnclosedComment() {
has_unclosed_bracketed_comment = true;
}

/** Flag supplied from outside the lexer; read by {@code isChannel2()}. */
private boolean channel2;

/**
 * Stores the externally supplied flag.
 *
 * @param value new value for the flag
 */
public void setChannel2(boolean value) {
    channel2 = value;
}

/**
 * Reports the flag last stored through {@code setChannel2}.
 *
 * @return the current flag value ({@code false} until it has been set)
 */
private boolean isChannel2() {
    return channel2;
}
}

SEMICOLON: ';';
Expand Down Expand Up @@ -654,6 +667,11 @@ BRACKETED_COMMENT
: '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? ('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN)
;

// Enabled only while isChannel2() returns true. Matches everything from HINT_START up to
// the nearest HINT_END (non-greedy) as a single token and emits it on channel 2.
HINT_WITH_CHANNEL
: {isChannel2()}? HINT_START .*? HINT_END -> channel(2)
;


FROM_DUAL
: 'FROM' WS+ 'DUAL' -> channel(HIDDEN);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,12 @@
@SuppressWarnings({"OptionalUsedAsFieldOrParameterType", "OptionalGetWithoutIsPresent"})
public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {

// Select hint contexts keyed by an Integer position that is compared against the token
// indexes of a select clause. May be null or empty, in which case no select hint is attached.
private final Map<Integer, ParserRuleContext> selectHintMap;

/**
 * Creates a plan builder that looks up select hints in the given map.
 *
 * @param selectHintMap hint contexts keyed by position; stored as-is (no copy is made),
 *                      and {@code null} is tolerated by the lookup code
 */
public LogicalPlanBuilder(Map<Integer, ParserRuleContext> selectHintMap) {
this.selectHintMap = selectHintMap;
}

@SuppressWarnings("unchecked")
protected <T> T typedVisit(ParseTree ctx) {
return (T) ctx.accept(this);
Expand Down Expand Up @@ -604,7 +610,16 @@ public LogicalPlan visitRegularQuerySpecification(RegularQuerySpecificationConte
Optional.ofNullable(ctx.aggClause()),
Optional.ofNullable(ctx.havingClause()));
selectPlan = withQueryOrganization(selectPlan, ctx.queryOrganization());
return withSelectHint(selectPlan, selectCtx.selectHint());
if ((selectHintMap == null) || selectHintMap.isEmpty()) {
return selectPlan;
}
List<ParserRuleContext> selectHintContexts = Lists.newArrayList();
for (Integer key : selectHintMap.keySet()) {
if (key > selectCtx.getStart().getStopIndex() && key < selectCtx.getStop().getStartIndex()) {
selectHintContexts.add(selectHintMap.get(key));
}
}
return withSelectHint(selectPlan, selectHintContexts);
});
}

Expand Down Expand Up @@ -1785,47 +1800,70 @@ private LogicalPlan withJoinRelations(LogicalPlan input, RelationContext ctx) {
return last;
}

private LogicalPlan withSelectHint(LogicalPlan logicalPlan, SelectHintContext hintContext) {
if (hintContext == null) {
private LogicalPlan withSelectHint(LogicalPlan logicalPlan, List<ParserRuleContext> hintContexts) {
if (hintContexts.isEmpty()) {
return logicalPlan;
}
Map<String, SelectHint> hints = Maps.newLinkedHashMap();
for (HintStatementContext hintStatement : hintContext.hintStatements) {
String hintName = hintStatement.hintName.getText().toLowerCase(Locale.ROOT);
switch (hintName) {
case "set_var":
Map<String, Optional<String>> parameters = Maps.newLinkedHashMap();
for (HintAssignmentContext kv : hintStatement.parameters) {
if (kv.key != null) {
String parameterName = visitIdentifierOrText(kv.key);
Optional<String> value = Optional.empty();
if (kv.constantValue != null) {
Literal literal = (Literal) visit(kv.constantValue);
value = Optional.ofNullable(literal.toLegacyLiteral().getStringValue());
} else if (kv.identifierValue != null) {
// maybe we should throw exception when the identifierValue is quoted identifier
value = Optional.ofNullable(kv.identifierValue.getText());
for (ParserRuleContext hintContext : hintContexts) {
SelectHintContext selectHintContext = (SelectHintContext) hintContext;
for (HintStatementContext hintStatement : selectHintContext.hintStatements) {
String hintName = hintStatement.hintName.getText().toLowerCase(Locale.ROOT);
switch (hintName) {
case "set_var":
Map<String, Optional<String>> parameters = Maps.newLinkedHashMap();
for (HintAssignmentContext kv : hintStatement.parameters) {
if (kv.key != null) {
String parameterName = visitIdentifierOrText(kv.key);
Optional<String> value = Optional.empty();
if (kv.constantValue != null) {
Literal literal = (Literal) visit(kv.constantValue);
value = Optional.ofNullable(literal.toLegacyLiteral().getStringValue());
} else if (kv.identifierValue != null) {
// maybe we should throw exception when the identifierValue is quoted identifier
value = Optional.ofNullable(kv.identifierValue.getText());
}
parameters.put(parameterName, value);
}
parameters.put(parameterName, value);
}
}
hints.put(hintName, new SelectHintSetVar(hintName, parameters));
break;
case "leading":
List<String> leadingParameters = new ArrayList<String>();
for (HintAssignmentContext kv : hintStatement.parameters) {
if (kv.key != null) {
hints.put(hintName, new SelectHintSetVar(hintName, parameters));
break;
case "leading":
List<String> leadingParameters = new ArrayList<String>();
for (HintAssignmentContext kv : hintStatement.parameters) {
if (kv.key != null) {
String parameterName = visitIdentifierOrText(kv.key);
leadingParameters.add(parameterName);
}
}
hints.put(hintName, new SelectHintLeading(hintName, leadingParameters));
break;
case "ordered":
hints.put(hintName, new SelectHintOrdered(hintName));
break;
case "use_cbo_rule":
List<String> useRuleParameters = new ArrayList<String>();
for (HintAssignmentContext kv : hintStatement.parameters) {
if (kv.key != null) {
String parameterName = visitIdentifierOrText(kv.key);
useRuleParameters.add(parameterName);
}
}
hints.put(hintName, new SelectHintUseCboRule(hintName, useRuleParameters, false));
break;
case "no_use_cbo_rule":
List<String> noUseRuleParameters = new ArrayList<String>();
for (HintAssignmentContext kv : hintStatement.parameters) {
String parameterName = visitIdentifierOrText(kv.key);
leadingParameters.add(parameterName);
if (kv.key != null) {
noUseRuleParameters.add(parameterName);
}
}
}
hints.put(hintName, new SelectHintLeading(hintName, leadingParameters));
break;
case "ordered":
hints.put(hintName, new SelectHintOrdered(hintName));
break;
default:
break;
hints.put(hintName, new SelectHintUseCboRule(hintName, noUseRuleParameters, true));
break;
default:
break;
}
}
}
return new LogicalSelectHint<>(hints, logicalPlan);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.ParseCancellationException;

import java.util.List;
import java.util.Map;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -80,12 +83,41 @@ public List<String> parseDataType(String dataType) {

private <T> T parse(String sql, Function<DorisParser, ParserRuleContext> parseFunction) {
ParserRuleContext tree = toAst(sql, parseFunction);
LogicalPlanBuilder logicalPlanBuilder = new LogicalPlanBuilder();
LogicalPlanBuilder logicalPlanBuilder = new LogicalPlanBuilder(getHintMap(sql, DorisParser::selectHint));
return (T) logicalPlanBuilder.visit(tree);
}

/**
 * Collects every hint comment found in {@code sql} and parses each one on its own.
 *
 * <p>The statement is lexed once with the channel-2 switch turned on, so that a hint comment is
 * produced as one {@code HINT_WITH_CHANNEL} token. The text of each such token is then lexed
 * again with the switch turned off and handed to {@code parseFunction}.
 *
 * @param sql           statement text to scan
 * @param parseFunction parser entry point applied to the text of each hint
 * @return map from the start index reported by a hint token to the context parsed from that
 *         hint; empty when the statement contains no hint token
 */
public static Map<Integer, ParserRuleContext> getHintMap(String sql,
        Function<DorisParser, ParserRuleContext> parseFunction) {
    // first round: lex the whole statement only to locate the hint tokens
    DorisLexer hintLexer = new DorisLexer(new CaseInsensitiveStream(CharStreams.fromString(sql)));
    hintLexer.setChannel2(true);

    Map<Integer, ParserRuleContext> selectHintMap = Maps.newHashMap();

    // The lexer is the token source, so it is pulled directly; the end of input is signalled
    // by an EOF token.
    for (Token hintToken = hintLexer.nextToken();
            hintToken.getType() != DorisLexer.EOF;
            hintToken = hintLexer.nextToken()) {
        if (hintToken.getType() != DorisLexer.HINT_WITH_CHANNEL) {
            continue;
        }
        // Take the text from the token rather than slicing sql with the token offsets:
        // those offsets index the lexer's character stream (code points for
        // CharStreams.fromString), which differ from String offsets as soon as the
        // statement contains a character outside the Basic Multilingual Plane.
        // NOTE(review): assumes CaseInsensitiveStream.getText returns the original text
        // unchanged - confirm.
        String hintSql = hintToken.getText();

        // second round: parse the hint text by itself, with hints lexed as regular tokens
        DorisLexer newHintLexer = new DorisLexer(new CaseInsensitiveStream(CharStreams.fromString(hintSql)));
        newHintLexer.setChannel2(false);
        DorisParser hintParser = new DorisParser(new CommonTokenStream(newHintLexer));
        ParserRuleContext hintContext = parseFunction.apply(hintParser);

        // Keyed by the token's own start index, i.e. the same index space as the tokens of
        // the main parse of this statement.
        selectHintMap.put(hintToken.getStartIndex(), hintContext);
    }
    return selectHintMap;
}

/** toAst */
private ParserRuleContext toAst(String sql, Function<DorisParser, ParserRuleContext> parseFunction) {
DorisLexer lexer = new DorisLexer(new CaseInsensitiveStream(CharStreams.fromString(sql)));
lexer.setChannel2(true);
CommonTokenStream tokenStream = new CommonTokenStream(lexer);
DorisParser parser = new DorisParser(tokenStream);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,16 +341,10 @@ public void testJoinHint() {
parsePlan("select * from t1 join [broadcast] t2 on t1.key1=t2.key1")
.matches(logicalJoin().when(j -> j.getHint() == JoinHint.BROADCAST_RIGHT));

parsePlan("select * from t1 join /*+ broadcast */ t2 on t1.key1=t2.key1")
.matches(logicalJoin().when(j -> j.getHint() == JoinHint.BROADCAST_RIGHT));

// invalid hint position
parsePlan("select * from [shuffle] t1 join t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class);

parsePlan("select * from /*+ shuffle */ t1 join t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class);

// invalid hint content
parsePlan("select * from t1 join [bucket] t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class)
Expand All @@ -361,8 +355,6 @@ public void testJoinHint() {
+ "----------------------^^^");

// invalid multiple hints
parsePlan("select * from t1 join /*+ shuffle , broadcast */ t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class);

parsePlan("select * from t1 join [shuffle,broadcast] t2 on t1.key1=t2.key1")
.assertThrowsExactly(ParseException.class);
Expand Down
78 changes: 78 additions & 0 deletions regression-test/data/nereids_p0/hint/test_hint.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select1_1 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------hashJoin[INNER_JOIN broadcast] hashCondition=((t1.c1 = t2.c2)) otherCondition=()
----------PhysicalOlapScan[t2]
----------PhysicalOlapScan[t1]

Hint log:
Used: leading(t2 broadcast t1 )
UnUsed:
SyntaxError:

-- !select1_2 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[t1]

-- !select1_3 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[t1]

-- !select1_4 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[t1]

-- !select1_5 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------hashJoin[INNER_JOIN broadcast] hashCondition=((t1.c1 = t2.c2)) otherCondition=()
----------PhysicalOlapScan[t2]
----------PhysicalOlapScan[t1]

Hint log:
Used: leading(t2 broadcast t1 )
UnUsed:
SyntaxError:

-- !select1_6 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------hashJoin[INNER_JOIN broadcast] hashCondition=((t1.c1 = t2.c2)) otherCondition=()
----------PhysicalOlapScan[t2]
----------PhysicalOlapScan[t1]

Hint log:
Used: leading(t2 broadcast t1 )
UnUsed:
SyntaxError:

-- !select1_7 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[t1]

-- !select1_8 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalStorageLayerAggregate[t1]

4 changes: 2 additions & 2 deletions regression-test/data/nereids_p0/hint/test_leading.out
Original file line number Diff line number Diff line change
Expand Up @@ -2538,8 +2538,8 @@ PhysicalResultSink
------------PhysicalOlapScan[t3]

Hint log:
Used: leading(t1 broadcast t2 t3 )
UnUsed:
Used: leading(t1 broadcast t2 broadcast t3 )
UnUsed:
SyntaxError:

-- !select95_4 --
Expand Down
Loading

0 comments on commit f514a7e

Please sign in to comment.