diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/WhereCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/WhereCommandIT.java index 4dfd75455dc..c7a47d42529 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/WhereCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/WhereCommandIT.java @@ -44,7 +44,7 @@ public void testWhereWithMultiLogicalExpr() throws IOException { executeQuery( String.format( "source=%s " - + "| where firstname='Amber' lastname='Duke' age=32 " + + "| where firstname='Amber' and lastname='Duke' and age=32 " + "| fields firstname, lastname, age", TEST_INDEX_ACCOUNT)); verifyDataRows(result, rows("Amber", "Duke", 32)); diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index dcc71c1d061..dfd655177ab 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -14,12 +14,8 @@ root // statement pplStatement - : dmlStatement - ; - -dmlStatement - : queryStatement - | explainStatement + : explainStatement + | queryStatement ; queryStatement @@ -43,9 +39,9 @@ subSearch // commands pplCommands - : searchCommand - | describeCommand + : describeCommand | showDataSourcesCommand + | searchCommand ; commands @@ -106,9 +102,7 @@ commandName ; searchCommand - : (SEARCH)? fromClause # searchFrom - | (SEARCH)? fromClause logicalExpression # searchFromFilter - | (SEARCH)? logicalExpression fromClause # searchFilterFrom + : (SEARCH)? (logicalExpression)* fromClause (logicalExpression)* # searchFrom ; describeCommand @@ -373,7 +367,7 @@ sortbyClause ; evalClause - : fieldExpression EQUAL expression + : fieldExpression EQUAL logicalExpression ; eventstatsAggTerm @@ -447,68 +441,52 @@ numericLiteral | floatLiteral ; -// expressions -expression - : logicalExpression - | comparisonExpression - | valueExpression - ; - // predicates logicalExpression - : LT_PRTHS logicalExpression RT_PRTHS # parentheticLogicalExpr - | NOT logicalExpression # logicalNot - | left = logicalExpression (AND)? right = logicalExpression # logicalAnd + : NOT logicalExpression # logicalNot + | left = logicalExpression AND right = logicalExpression # logicalAnd | left = logicalExpression XOR right = logicalExpression # logicalXor | left = logicalExpression OR right = logicalExpression # logicalOr - | comparisonExpression # comparsion - | booleanExpression # booleanExpr - | relevanceExpression # relevanceExpr - ; - -comparisonExpression - : left = valueExpression comparisonOperator right = valueExpression # compareExpr - | valueExpression NOT? IN valueList # inExpr - | valueExpression NOT? BETWEEN valueExpression AND valueExpression # between + | expression # logicalExpr ; -valueExpressionList - : valueExpression - | LT_PRTHS valueExpression (COMMA valueExpression)* RT_PRTHS +expression + : valueExpression # valueExpr + | relevanceExpression # relevanceExpr + | left = expression comparisonOperator right = expression # compareExpr + | expression NOT? IN valueList # inExpr + | expression NOT? BETWEEN expression AND expression # between ; valueExpression - : left = valueExpression binaryOperator = (STAR | DIVIDE | MODULE) right = valueExpression # binaryArithmetic - | left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic - | primaryExpression # valueExpressionDefault - | positionFunction # positionFunctionCall - | caseFunction # caseExpr - | extractFunction # extractFunctionCall - | getFormatFunction # getFormatFunctionCall - | timestampFunction # timestampFunctionCall - | LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr - | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr - | lambda # lambdaExpr - ; - -primaryExpression + : left = valueExpression binaryOperator = (STAR | DIVIDE | MODULE) right = valueExpression # binaryArithmetic + | left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic + | literalValue # literalValueExpr + | functionCall # functionCallExpr + | lambda # lambdaExpr + | LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr + | valueExpression NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr + | LT_PRTHS valueExpression (COMMA valueExpression)* RT_PRTHS NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr + | EXISTS LT_SQR_PRTHS subSearch RT_SQR_PRTHS # existsSubqueryExpr + | fieldExpression # fieldExpr + | LT_PRTHS logicalExpression RT_PRTHS # nestedValueExpr + ; + +functionCall : evalFunctionCall | dataTypeFunctionCall - | fieldExpression - | literalValue + | positionFunctionCall + | caseFunctionCall + | timestampFunctionCall + | extractFunctionCall + | getFormatFunctionCall ; -positionFunction +positionFunctionCall : positionFunctionName LT_PRTHS functionArg IN functionArg RT_PRTHS ; -booleanExpression - : booleanFunctionCall # booleanFunctionCallExpr - | valueExpressionList NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr - | EXISTS LT_SQR_PRTHS subSearch RT_SQR_PRTHS # existsSubqueryExpr - ; - -caseFunction +caseFunctionCall : CASE LT_PRTHS logicalExpression COMMA valueExpression (COMMA logicalExpression COMMA valueExpression)* (ELSE valueExpression)? RT_PRTHS ; @@ -573,12 +551,7 @@ evalFunctionCall // cast function dataTypeFunctionCall - : CAST LT_PRTHS expression AS convertedDataType RT_PRTHS - ; - -// boolean functions -booleanFunctionCall - : conditionFunctionName LT_PRTHS functionArgs RT_PRTHS + : CAST LT_PRTHS logicalExpression AS convertedDataType RT_PRTHS ; convertedDataType @@ -621,12 +594,12 @@ functionArg functionArgExpression : lambda - | expression + | logicalExpression ; lambda - : ident ARROW expression - | LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression + : ident ARROW logicalExpression + | LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW logicalExpression ; relevanceArg @@ -837,7 +810,7 @@ dateTimeFunctionName | YEARWEEK ; -getFormatFunction +getFormatFunctionCall : GET_FORMAT LT_PRTHS getFormatType COMMA functionArg RT_PRTHS ; @@ -848,7 +821,7 @@ getFormatType | TIMESTAMP ; -extractFunction +extractFunctionCall : EXTRACT LT_PRTHS datetimePart FROM functionArg RT_PRTHS ; @@ -883,7 +856,7 @@ datetimePart | complexDateTimePart ; -timestampFunction +timestampFunctionCall : timestampFunctionName LT_PRTHS simpleDateTimePart COMMA firstArg = functionArg COMMA secondArg = functionArg RT_PRTHS ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 1e91fa9df09..2c5907b19a7 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -15,9 +15,7 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldsCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.HeadCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.RenameCommandContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SearchFilterFromContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SearchFromContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SearchFromFilterContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.StatsCommandContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.TableFunctionContext; @@ -42,6 +40,7 @@ import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFieldsExcludeMeta; +import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Let; @@ -139,19 +138,16 @@ public UnresolvedPlan visitSubSearch(OpenSearchPPLParser.SubSearchContext ctx) { /** Search command. */ @Override public UnresolvedPlan visitSearchFrom(SearchFromContext ctx) { - return visitFromClause(ctx.fromClause()); - } - - @Override - public UnresolvedPlan visitSearchFromFilter(SearchFromFilterContext ctx) { - return new Filter(internalVisitExpression(ctx.logicalExpression())) - .attach(visit(ctx.fromClause())); - } - - @Override - public UnresolvedPlan visitSearchFilterFrom(SearchFilterFromContext ctx) { - return new Filter(internalVisitExpression(ctx.logicalExpression())) - .attach(visit(ctx.fromClause())); + if (ctx.logicalExpression().isEmpty()) { + return visitFromClause(ctx.fromClause()); + } else { + return new Filter( + ctx.logicalExpression().stream() + .map(this::internalVisitExpression) + .reduce(And::new) + .get()) + .attach(visit(ctx.fromClause())); + } } /** diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 006915e47f8..90e4742469b 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -9,7 +9,6 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.POSITION; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; @@ -34,7 +33,6 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalOrContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.MultiFieldRelevanceFunctionContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ParentheticValueExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SingleFieldRelevanceFunctionContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext; @@ -90,7 +88,7 @@ public AstExpressionBuilder(AstBuilder astBuilder) { /** Eval clause. */ @Override public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) { - return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression())); + return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.logicalExpression())); } /** Trendline clause. */ @@ -143,7 +141,7 @@ public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) { ctx.ident().stream() .map(x -> this.visitIdentifiers(Collections.singletonList(x))) .collect(Collectors.toList()); - UnresolvedExpression function = visitExpression(ctx.expression()); + UnresolvedExpression function = visit(ctx.logicalExpression()); return new LambdaFunction(function, arguments); } @@ -157,7 +155,7 @@ public UnresolvedExpression visitCompareExpr(CompareExprContext ctx) { public UnresolvedExpression visitInExpr(InExprContext ctx) { UnresolvedExpression expr = new In( - visit(ctx.valueExpression()), + visit(ctx.expression()), ctx.valueList().literalValue().stream() .map(this::visitLiteralValue) .collect(Collectors.toList())); @@ -167,18 +165,21 @@ public UnresolvedExpression visitInExpr(InExprContext ctx) { /** Value Expression. */ @Override public UnresolvedExpression visitBinaryArithmetic(BinaryArithmeticContext ctx) { - return new Function( - ctx.binaryOperator.getText(), Arrays.asList(visit(ctx.left), visit(ctx.right))); + return new Function(ctx.binaryOperator.getText(), buildArguments(ctx.left, ctx.right)); } - @Override - public UnresolvedExpression visitParentheticValueExpr(ParentheticValueExprContext ctx) { - return visit(ctx.valueExpression()); // Discard parenthesis around + private List buildArguments( + OpenSearchPPLParser.ValueExpressionContext... ctx) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (OpenSearchPPLParser.ValueExpressionContext value : ctx) { + UnresolvedExpression unresolvedExpression = visit(value); + if (unresolvedExpression != null) builder.add(unresolvedExpression); + } + return builder.build(); } @Override - public UnresolvedExpression visitParentheticLogicalExpr( - OpenSearchPPLParser.ParentheticLogicalExprContext ctx) { + public UnresolvedExpression visitNestedValueExpr(OpenSearchPPLParser.NestedValueExprContext ctx) { return visit(ctx.logicalExpression()); // Discard parenthesis around } @@ -246,32 +247,22 @@ public UnresolvedExpression visitTakeAggFunctionCall( "take", visit(ctx.takeAggFunction().fieldExpression()), builder.build()); } - /** Eval function. */ + /** Case function. */ @Override - public UnresolvedExpression visitBooleanFunctionCall(BooleanFunctionCallContext ctx) { - final String functionName = ctx.conditionFunctionName().getText().toLowerCase(Locale.ROOT); - return buildFunction( - FUNCTION_NAME_MAPPING.getOrDefault(functionName, functionName), - ctx.functionArgs().functionArg()); - } - - @Override - public UnresolvedExpression visitCaseExpr(OpenSearchPPLParser.CaseExprContext ctx) { + public UnresolvedExpression visitCaseFunctionCall( + OpenSearchPPLParser.CaseFunctionCallContext ctx) { List whens = - IntStream.range(0, ctx.caseFunction().logicalExpression().size()) + IntStream.range(0, ctx.logicalExpression().size()) .mapToObj( index -> { - UnresolvedExpression condition = - visit(ctx.caseFunction().logicalExpression(index)); - UnresolvedExpression result = visit(ctx.caseFunction().valueExpression(index)); + UnresolvedExpression condition = visit(ctx.logicalExpression(index)); + UnresolvedExpression result = visit(ctx.valueExpression(index)); return new When(condition, result); }) .collect(Collectors.toList()); UnresolvedExpression elseValue = null; - if (ctx.caseFunction().ELSE() != null) { - elseValue = - visit( - ctx.caseFunction().valueExpression(ctx.caseFunction().valueExpression().size() - 1)); + if (ctx.ELSE() != null) { + elseValue = visit(ctx.valueExpression(ctx.valueExpression().size() - 1)); } return new Case(null, whens, Optional.ofNullable(elseValue)); } @@ -279,13 +270,16 @@ public UnresolvedExpression visitCaseExpr(OpenSearchPPLParser.CaseExprContext ct /** Eval function. */ @Override public UnresolvedExpression visitEvalFunctionCall(EvalFunctionCallContext ctx) { - return buildFunction(ctx.evalFunctionName().getText(), ctx.functionArgs().functionArg()); + final String functionName = ctx.evalFunctionName().getText(); + return buildFunction( + FUNCTION_NAME_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), functionName), + ctx.functionArgs().functionArg()); } /** Cast function. */ @Override public UnresolvedExpression visitDataTypeFunctionCall(DataTypeFunctionCallContext ctx) { - return new Cast(visit(ctx.expression()), visit(ctx.convertedDataType())); + return new Cast(visit(ctx.logicalExpression()), visit(ctx.convertedDataType())); } @Override @@ -325,8 +319,8 @@ public UnresolvedExpression visitTableSource(TableSourceContext ctx) { } @Override - public UnresolvedExpression visitPositionFunction( - OpenSearchPPLParser.PositionFunctionContext ctx) { + public UnresolvedExpression visitPositionFunctionCall( + OpenSearchPPLParser.PositionFunctionCallContext ctx) { return new Function( POSITION.getName().getFunctionName(), Arrays.asList(visitFunctionArg(ctx.functionArg(0)), visitFunctionArg(ctx.functionArg(1)))); @@ -335,49 +329,46 @@ public UnresolvedExpression visitPositionFunction( @Override public UnresolvedExpression visitExtractFunctionCall( OpenSearchPPLParser.ExtractFunctionCallContext ctx) { - return new Function( - ctx.extractFunction().EXTRACT().toString(), getExtractFunctionArguments(ctx)); + return new Function(ctx.EXTRACT().toString(), getExtractFunctionArguments(ctx)); } private List getExtractFunctionArguments( OpenSearchPPLParser.ExtractFunctionCallContext ctx) { List args = Arrays.asList( - new Literal(ctx.extractFunction().datetimePart().getText(), DataType.STRING), - visitFunctionArg(ctx.extractFunction().functionArg())); + new Literal(ctx.datetimePart().getText(), DataType.STRING), + visitFunctionArg(ctx.functionArg())); return args; } @Override public UnresolvedExpression visitGetFormatFunctionCall( OpenSearchPPLParser.GetFormatFunctionCallContext ctx) { - return new Function( - ctx.getFormatFunction().GET_FORMAT().toString(), getFormatFunctionArguments(ctx)); + return new Function(ctx.GET_FORMAT().toString(), getFormatFunctionArguments(ctx)); } private List getFormatFunctionArguments( OpenSearchPPLParser.GetFormatFunctionCallContext ctx) { List args = Arrays.asList( - new Literal(ctx.getFormatFunction().getFormatType().getText(), DataType.STRING), - visitFunctionArg(ctx.getFormatFunction().functionArg())); + new Literal(ctx.getFormatType().getText(), DataType.STRING), + visitFunctionArg(ctx.functionArg())); return args; } @Override public UnresolvedExpression visitTimestampFunctionCall( OpenSearchPPLParser.TimestampFunctionCallContext ctx) { - return new Function( - ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx)); + return new Function(ctx.timestampFunctionName().getText(), timestampFunctionArguments(ctx)); } private List timestampFunctionArguments( OpenSearchPPLParser.TimestampFunctionCallContext ctx) { List args = Arrays.asList( - new Literal(ctx.timestampFunction().simpleDateTimePart().getText(), DataType.STRING), - visitFunctionArg(ctx.timestampFunction().firstArg), - visitFunctionArg(ctx.timestampFunction().secondArg)); + new Literal(ctx.simpleDateTimePart().getText(), DataType.STRING), + visitFunctionArg(ctx.firstArg), + visitFunctionArg(ctx.secondArg)); return args; } @@ -478,12 +469,9 @@ public UnresolvedExpression visitRightHint(OpenSearchPPLParser.RightHintContext @Override public UnresolvedExpression visitInSubqueryExpr(OpenSearchPPLParser.InSubqueryExprContext ctx) { - UnresolvedExpression expr = - new InSubquery( - ctx.valueExpressionList().valueExpression().stream() - .map(this::visit) - .collect(Collectors.toList()), - astBuilder.visitSubSearch(ctx.subSearch())); + List s = + ctx.valueExpression().stream().map(this::visit).collect(Collectors.toList()); + UnresolvedExpression expr = new InSubquery(s, astBuilder.visitSubSearch(ctx.subSearch())); return ctx.NOT() != null ? new Not(expr) : expr; } @@ -502,10 +490,7 @@ public UnresolvedExpression visitExistsSubqueryExpr( @Override public UnresolvedExpression visitBetween(OpenSearchPPLParser.BetweenContext ctx) { UnresolvedExpression betweenExpr = - new Between( - visit(ctx.valueExpression(0)), - visit(ctx.valueExpression(1)), - visit(ctx.valueExpression(2))); + new Between(visit(ctx.expression(0)), visit(ctx.expression(1)), visit(ctx.expression(2))); return ctx.NOT() != null ? new Not(betweenExpr) : betweenExpr; } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java index 7c2b1ab2ff0..1a5b198b42c 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java @@ -32,7 +32,7 @@ public class AstStatementBuilder extends OpenSearchPPLParserBaseVisitor= 0)", + filter(relation("test"), compare(">=", intLiteral(1), intLiteral(0)))); + } + + @Test + public void canBuildMultiPartParenthesizedExpression() { + assertEqual( + "source = test | where (day_of_week_i < 2) OR (day_of_week_i > 5)", + filter( + relation("test"), + or( + compare("<", field("day_of_week_i"), intLiteral(2)), + compare(">", field("day_of_week_i"), intLiteral(5))))); + } +}