From b5a287c79fcd00a0f394c4609a16859e5d3b084a Mon Sep 17 00:00:00 2001 From: forestmvey Date: Tue, 27 Sep 2022 13:11:15 -0700 Subject: [PATCH] Adding highlight support in PPL with optional arguments and wildcard support in SQL and PPL. Signed-off-by: forestmvey --- .../org/opensearch/sql/analysis/Analyzer.java | 1 - .../sql/analysis/ExpressionAnalyzer.java | 2 +- .../sql/analysis/HighlightAnalyzer.java | 7 +- .../sql/ast/AbstractNodeVisitor.java | 2 +- .../org/opensearch/sql/ast/dsl/AstDSL.java | 5 +- .../sql/ast/expression/HighlightFunction.java | 4 +- .../sql/expression/HighlightExpression.java | 53 ++++++++- .../function/OpenSearchFunctions.java | 10 -- .../sql/planner/logical/LogicalHighlight.java | 12 +- .../sql/planner/logical/LogicalPlanDSL.java | 5 +- .../opensearch/sql/analysis/AnalyzerTest.java | 33 +++++- .../sql/analysis/ExpressionAnalyzerTest.java | 9 -- .../analysis/NamedExpressionAnalyzerTest.java | 11 +- .../expression/HighlightExpressionTest.java | 67 +++++++++--- .../logical/LogicalPlanNodeVisitorTest.java | 4 +- .../sql/sql/HighlightFunctionIT.java | 103 ++++++++++++++---- .../request/OpenSearchRequestBuilder.java | 27 ++++- .../opensearch/storage/OpenSearchIndex.java | 3 +- .../OpenSearchDefaultImplementorTest.java | 2 +- .../storage/OpenSearchIndexScanTest.java | 49 ++++++++- sql/src/main/antlr/OpenSearchSQLLexer.g4 | 2 + sql/src/main/antlr/OpenSearchSQLParser.g4 | 14 ++- .../sql/sql/parser/AstExpressionBuilder.java | 11 +- .../sql/sql/parser/AstBuilderTest.java | 26 ++++- .../sql/parser/AstExpressionBuilderTest.java | 9 +- 25 files changed, 375 insertions(+), 96 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 043d299b40..0bc1008152 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -313,7 +313,6 @@ public LogicalPlan visitProject(Project node, AnalysisContext context) { for (UnresolvedExpression expr : node.getProjectList()) { HighlightAnalyzer highlightAnalyzer = new HighlightAnalyzer(expressionAnalyzer, child); child = highlightAnalyzer.analyze(expr, context); - } List namedExpressions = diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index e1dbedebb2..b877fcf673 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -205,7 +205,7 @@ public Expression visitWindowFunction(WindowFunction node, AnalysisContext conte } @Override - public Expression visitHighlight(HighlightFunction node, AnalysisContext context) { + public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) { Expression expr = node.getHighlightField().accept(this, context); return new HighlightExpression(expr); } diff --git a/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java index 06a601327c..0a15c6bac8 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java @@ -30,12 +30,13 @@ public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext con @Override public LogicalPlan visitAlias(Alias node, AnalysisContext context) { - if (!(node.getDelegated() instanceof HighlightFunction)) { + UnresolvedExpression delegated = node.getDelegated(); + if (!(delegated instanceof HighlightFunction)) { return null; } - HighlightFunction unresolved = (HighlightFunction) node.getDelegated(); + HighlightFunction unresolved = (HighlightFunction) delegated; Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context); - return new LogicalHighlight(child, field); + return new LogicalHighlight(child, field, unresolved.getArguments()); } } diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index e75f8f4ce5..5aeedcc58d 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -261,7 +261,7 @@ public T visitAD(AD node, C context) { return visitChildren(node, context); } - public T visitHighlight(HighlightFunction node, C context) { + public T visitHighlightFunction(HighlightFunction node, C context) { return visitChildren(node, context); } } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 6a757ccab8..4fa0b89f6c 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -272,8 +272,9 @@ public When when(UnresolvedExpression condition, UnresolvedExpression result) { return new When(condition, result); } - public UnresolvedExpression highlight(UnresolvedExpression fieldName) { - return new HighlightFunction(fieldName); + public UnresolvedExpression highlight(UnresolvedExpression fieldName, + java.util.Map arguments) { + return new HighlightFunction(fieldName, arguments); } public UnresolvedExpression window(UnresolvedExpression function, diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java index 5f1bb652d9..0d4e57a78c 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java @@ -6,6 +6,7 @@ package org.opensearch.sql.ast.expression; import java.util.List; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -21,10 +22,11 @@ @ToString public class HighlightFunction extends UnresolvedExpression { private final UnresolvedExpression highlightField; + private final Map arguments; @Override public T accept(AbstractNodeVisitor nodeVisitor, C context) { - return nodeVisitor.visitHighlight(this, context); + return nodeVisitor.visitHighlightFunction(this, context); } @Override diff --git a/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java b/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java index 9745696111..804c38a6f7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java @@ -5,10 +5,16 @@ package org.opensearch.sql.expression; +import java.util.LinkedHashMap; import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; import lombok.Getter; import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.env.Environment; @@ -20,6 +26,7 @@ @Getter public class HighlightExpression extends FunctionExpression { private final Expression highlightField; + private final ExprType type; /** * HighlightExpression Constructor. @@ -28,6 +35,8 @@ public class HighlightExpression extends FunctionExpression { public HighlightExpression(Expression highlightField) { super(BuiltinFunctionName.HIGHLIGHT.getName(), List.of(highlightField)); this.highlightField = highlightField; + this.type = this.highlightField.toString().contains("*") + ? ExprCoreType.STRUCT : ExprCoreType.ARRAY; } /** @@ -37,21 +46,57 @@ public HighlightExpression(Expression highlightField) { */ @Override public ExprValue valueOf(Environment valueEnv) { - String refName = "_highlight" + "." + StringUtils.unquoteText(getHighlightField().toString()); - return valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING)); + String refName = "_highlight"; + // Not a wilcard expression + if (this.type == ExprCoreType.ARRAY) { + refName += "." + StringUtils.unquoteText(getHighlightField().toString()); + } + ExprValue value = valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING)); + + // In the event of multiple returned highlights and wildcard being + // used in conjunction with other highlight calls, we need to ensure + // only wildcard regex matching is mapped to wildcard call. + if (this.type == ExprCoreType.STRUCT && value.type() == ExprCoreType.STRUCT) { + value = new ExprTupleValue( + new LinkedHashMap(value.tupleValue() + .entrySet() + .stream() + .filter(s -> matchesHighlightRegex(s.getKey(), + StringUtils.unquoteText(highlightField.toString()))) + .collect(Collectors.toMap( + e -> e.getKey(), + e -> e.getValue())))); + if (value.tupleValue().isEmpty()) { + value = ExprValueUtils.missingValue(); + } + } + + return value; } /** * Get type for HighlightExpression. - * @return : String type. + * @return : Expression type. */ @Override public ExprType type() { - return ExprCoreType.ARRAY; + return this.type; } @Override public T accept(ExpressionNodeVisitor visitor, C context) { return visitor.visitHighlight(this, context); } + + /** + * Check if field matches the wildcard pattern used in highlight query. + * @param field Highlight selected field for query + * @param pattern Wildcard regex to match field against + * @return True if field matches wildcard pattern + */ + private boolean matchesHighlightRegex(String field, String pattern) { + Pattern p = Pattern.compile(pattern.replace("*", ".*")); + Matcher matcher = p.matcher(field); + return matcher.matches(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index bb3eb7008b..43a722b838 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -17,7 +17,6 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.NamedArgumentExpression; import org.opensearch.sql.expression.env.Environment; @@ -37,15 +36,6 @@ public void register(BuiltinFunctionRepository repository) { repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE)); repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE)); repository.register(match_phrase_prefix()); - repository.register(highlight()); - } - - private static FunctionResolver highlight() { - FunctionName functionName = BuiltinFunctionName.HIGHLIGHT.getName(); - FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING)); - FunctionBuilder functionBuilder = arguments -> new HighlightExpression(arguments.get(0)); - return new DefaultFunctionResolver(functionName, - ImmutableMap.of(functionSignature, functionBuilder)); } private static FunctionResolver match_bool_prefix() { diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java index 986a545486..c1e873a00d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java @@ -6,9 +6,11 @@ package org.opensearch.sql.planner.logical; import java.util.Collections; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.expression.Expression; @EqualsAndHashCode(callSuper = true) @@ -16,10 +18,16 @@ @ToString public class LogicalHighlight extends LogicalPlan { private final Expression highlightField; + private final Map arguments; - public LogicalHighlight(LogicalPlan childPlan, Expression field) { + /** + * Constructor of LogicalHighlight. + */ + public LogicalHighlight(LogicalPlan childPlan, Expression highlightField, + Map arguments) { super(Collections.singletonList(childPlan)); - highlightField = field; + this.highlightField = highlightField; + this.arguments = arguments; } @Override diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 005a5d84fd..9e07e702de 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -63,8 +63,9 @@ public LogicalPlan window(LogicalPlan input, return new LogicalWindow(input, windowFunction, windowDefinition); } - public LogicalPlan highlight(LogicalPlan input, Expression field) { - return new LogicalHighlight(input, field); + public LogicalPlan highlight(LogicalPlan input, Expression field, + Map arguments) { + return new LogicalHighlight(input, field, arguments); } public static LogicalPlan remove(LogicalPlan input, ReferenceExpression... fields) { diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 19bfcabfec..31309a1953 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -27,10 +27,12 @@ import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.LONG; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -270,16 +272,41 @@ public void project_source() { @Test public void project_highlight() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), + DSL.literal("fieldA"), args), + DSL.named("highlight(fieldA, pre_tags='', post_tags='')", + new HighlightExpression(DSL.literal("fieldA"))) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("highlight(fieldA, pre_tags='', post_tags='')", + new HighlightFunction(AstDSL.stringLiteral("fieldA"), args)) + ) + ); + } + + @Test + public void project_highlight_wildcard() { + Map args = new HashMap<>(); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), - DSL.literal("fieldA")), - DSL.named("highlight(fieldA)", new HighlightExpression(DSL.literal("fieldA"))) + DSL.literal("*"), args), + DSL.named("highlight(*)", + new HighlightExpression(DSL.literal("*"))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("highlight(fieldA)", new HighlightFunction(AstDSL.stringLiteral("fieldA"))) + AstDSL.alias("highlight(*)", + new HighlightFunction(AstDSL.stringLiteral("*"), args)) ) ); } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 5787d08f3c..c76f449357 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -8,7 +8,6 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.ast.dsl.AstDSL.field; @@ -38,7 +37,6 @@ import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.DataType; -import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedExpression; @@ -50,7 +48,6 @@ import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction; @@ -592,12 +589,6 @@ public void now_as_a_function_not_cached() { assertTrue(values.stream().noneMatch(v -> v.valueOf(null) == referenceValue)); } - @Test - void highlight() { - assertAnalyzeEqual(new HighlightExpression(DSL.literal("fieldA")), - new HighlightFunction(stringLiteral("fieldA"))); - } - protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java index 2293d125aa..b944115a48 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java @@ -8,15 +8,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.HighlightFunction; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.springframework.context.annotation.Configuration; @@ -40,8 +39,10 @@ void visit_named_select_item() { @Test void visit_highlight() { + Map args = new HashMap<>(); Alias alias = AstDSL.alias("highlight(fieldA)", - new HighlightFunction(AstDSL.stringLiteral("fieldA"))); + new HighlightFunction( + AstDSL.stringLiteral("fieldA"), args)); NamedExpressionAnalyzer analyzer = new NamedExpressionAnalyzer(expressionAnalyzer); diff --git a/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java b/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java index c6e2dccf69..41f3bad030 100644 --- a/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java @@ -6,18 +6,19 @@ package org.opensearch.sql.expression; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; import com.google.common.collect.ImmutableMap; -import com.google.errorprone.annotations.DoNotCall; import org.junit.jupiter.api.Test; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.env.Environment; + public class HighlightExpressionTest extends ExpressionTestBase { @Test @@ -35,33 +36,73 @@ public void single_highlight_test() { public void missing_highlight_test() { Environment hlTuple = ExprValueUtils.tupleValue( ImmutableMap.of("_highlight.Title", "result value")).bindingTuples(); + HighlightExpression expr = new HighlightExpression(DSL.literal("invalid")); ExprValue resultVal = expr.valueOf(hlTuple); assertTrue(resultVal.isMissing()); } - /** - * Enable me when '*' is supported in highlight. - */ - @DoNotCall + @Test + public void missing_highlight_wildcard_test() { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + var hlBuilder = ImmutableMap.builder(); + hlBuilder.put("Title", ExprValueUtils.stringValue("first result value")); + hlBuilder.put("Body", ExprValueUtils.stringValue("secondary result value")); + builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); + + HighlightExpression hlExpr = new HighlightExpression(DSL.literal("invalid*")); + ExprValue resultVal = hlExpr.valueOf( + ExprTupleValue.fromExprValueMap(builder.build()).bindingTuples()); + + assertTrue(resultVal.isMissing()); + } + + @Test public void highlight_all_test() { ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); var hlBuilder = ImmutableMap.builder(); hlBuilder.put("Title", ExprValueUtils.stringValue("correct result value")); - hlBuilder.put("Body", ExprValueUtils.stringValue("incorrect result value")); + hlBuilder.put("Body", ExprValueUtils.stringValue("secondary correct result value")); builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); - HighlightExpression hlExpr = new HighlightExpression(DSL.literal("*")); + HighlightExpression hlExpr = new HighlightExpression(DSL.literal("T*")); ExprValue resultVal = hlExpr.valueOf( ExprTupleValue.fromExprValueMap(builder.build()).bindingTuples()); - assertEquals(ARRAY, resultVal.type()); - for (var field : resultVal.tupleValue().entrySet()) { - assertTrue(field.toString().contains(hlExpr.getHighlightField().toString())); - } + + assertEquals(STRUCT, resultVal.type()); assertTrue(resultVal.tupleValue().containsValue( - ExprValueUtils.stringValue("\"correct result value\""))); + ExprValueUtils.stringValue("correct result value"))); + assertFalse(resultVal.tupleValue().containsValue( + ExprValueUtils.stringValue("secondary correct result value"))); + } + + @Test + public void do_nothing_with_missing_value() { + Environment hlTuple = ExprValueUtils.tupleValue( + ImmutableMap.of("NonHighlightField", "ResultValue")).bindingTuples(); + HighlightExpression expr = new HighlightExpression(DSL.literal("*")); + ExprValue resultVal = expr.valueOf(hlTuple); + + assertTrue(resultVal.isMissing()); + } + + @Test + public void highlight_wildcard_test() { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + var hlBuilder = ImmutableMap.builder(); + hlBuilder.put("Title", ExprValueUtils.stringValue("correct result value")); + hlBuilder.put("Body", ExprValueUtils.stringValue("incorrect result value")); + builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); + + HighlightExpression hlExpr = new HighlightExpression(DSL.literal("T*")); + ExprValue resultVal = hlExpr.valueOf( + ExprTupleValue.fromExprValueMap(builder.build()).bindingTuples()); + + assertEquals(STRUCT, resultVal.type()); assertTrue(resultVal.tupleValue().containsValue( - ExprValueUtils.stringValue("\"correct result value\""))); + ExprValueUtils.stringValue("correct result value"))); + assertFalse(resultVal.tupleValue().containsValue( + ExprValueUtils.stringValue("incorrect result value"))); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index c90ea365d2..329708b7d8 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.HashMap; +import java.util.Map; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; @@ -117,8 +118,9 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { assertNull(rareTopN.accept(new LogicalPlanNodeVisitor() { }, null)); + Map args = new HashMap<>(); LogicalPlan highlight = new LogicalHighlight(filter, - new LiteralExpression(ExprValueUtils.stringValue("fieldA"))); + new LiteralExpression(ExprValueUtils.stringValue("fieldA")), args); assertNull(highlight.accept(new LogicalPlanNodeVisitor() { }, null)); diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java index 422f71968f..809e2dc7c5 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java @@ -5,14 +5,18 @@ package org.opensearch.sql.sql; +import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; import static org.opensearch.sql.util.MatcherUtils.verifySchema; -import com.google.errorprone.annotations.DoNotCall; +import com.google.common.collect.ImmutableMap; +import org.json.JSONArray; import org.json.JSONObject; import org.junit.Test; import org.opensearch.sql.legacy.SQLIntegTestCase; import org.opensearch.sql.legacy.TestsConstants; +import java.util.List; public class HighlightFunctionIT extends SQLIntegTestCase { @@ -25,57 +29,104 @@ protected void init() throws Exception { public void single_highlight_test() { String query = "SELECT Tags, highlight('Tags') FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("Tags", null, "text"), schema("highlight('Tags')", null, "nested")); assertEquals(1, response.getInt("total")); + + verifyDataRows(response, + rows("alcohol-level yeast home-brew champagne", + new JSONArray(List.of("alcohol-level yeast home-brew champagne")))); } @Test - public void accepts_unquoted_test() { - String query = "SELECT Tags, highlight(Tags) FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; + public void highlight_optional_arguments_test() { + String query = "SELECT highlight('Tags', pre_tags='', post_tags='') " + + "FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); - verifySchema(response, schema("Tags", null, "text"), - schema("highlight(Tags)", null, "nested")); + + verifySchema(response, schema("highlight('Tags', pre_tags='', post_tags='')", + null, "nested")); + assertEquals(1, response.getInt("total")); + + verifyDataRows(response, + rows(new JSONArray(List.of("alcohol-level yeast home-brew champagne")))); + } + + @Test + public void highlight_multiple_optional_arguments_test() { + String query = "SELECT highlight(Title), highlight(Body, pre_tags='', post_tags='') FROM %s WHERE multi_match([Title, Body], 'IPA') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + + verifySchema(response, schema("highlight(Title)", null, "nested"), + schema("highlight(Body, pre_tags='', " + + "post_tags='')", null, "nested")); + + assertEquals(1, response.getInt("total")); + + verifyDataRows(response, rows(new JSONArray(List.of("What are the differences between an IPA" + + " and its variants?")), + new JSONArray(List.of("

I know what makes an IPA" + + " an IPA, but what are the unique characteristics of it's" + + " common variants?", + "To be specific, the ones I'm interested in are Double IPA " + + "and Black IPA, but general differences" + + " between")))); } @Test public void multiple_highlight_test() { - String query = "SELECT highlight(Title), highlight(Body) FROM %s WHERE MULTI_MATCH([Title, Body], 'hops') LIMIT 1"; + String query = "SELECT highlight(Title), highlight(Tags) FROM %s WHERE MULTI_MATCH([Title, Tags], 'hops') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); verifySchema(response, schema("highlight(Title)", null, "nested"), - schema("highlight(Body)", null, "nested")); + schema("highlight(Tags)", null, "nested")); assertEquals(1, response.getInt("total")); + + verifyDataRows(response, + rows( new JSONArray(List.of("What uses do hops have outside of brewing?")), + new JSONArray(List.of("hops history")))); } - // Enable me when * is supported - @DoNotCall + @Test public void wildcard_highlight_test() { - String query = "SELECT highlight('*itle') FROM %s WHERE MULTI_MATCH([Title, Body], 'hops') LIMIT 1"; + String query = "SELECT highlight('*itle') FROM %s WHERE MULTI_MATCH([Title, Tags], 'hops') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); - verifySchema(response, schema("highlight('*itle')", null, "nested")); + + verifySchema(response, schema("highlight('*itle')", null, "object")); assertEquals(1, response.getInt("total")); + + verifyDataRows(response, rows(new JSONObject(ImmutableMap.of( + "Title", new JSONArray(List.of("What uses do hops have outside of brewing?")))))); } - // Enable me when * is supported - @DoNotCall + @Test public void wildcard_multi_field_highlight_test() { String query = "SELECT highlight('T*') FROM %s WHERE MULTI_MATCH([Title, Tags], 'hops') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); - verifySchema(response, schema("highlight('T*')", null, "nested")); - var resultMap = response.getJSONArray("datarows").getJSONArray(0).getJSONObject(0); + + verifySchema(response, schema("highlight('T*')", null, "object")); assertEquals(1, response.getInt("total")); - assertTrue(resultMap.has("highlight(\"T*\").Title")); - assertTrue(resultMap.has("highlight(\"T*\").Tags")); + + verifyDataRows(response, rows(new JSONObject(ImmutableMap.of( + "Title", new JSONArray(List.of("What uses do hops have outside of brewing?")), + "Tags", new JSONArray(List.of("hops history")))))); } - // Enable me when * is supported - @DoNotCall + @Test public void highlight_all_test() { - String query = "SELECT highlight('*') FROM %s WHERE MULTI_MATCH([Title, Body], 'hops') LIMIT 1"; + String query = "SELECT highlight('*') FROM %s WHERE MULTI_MATCH([Title, Tags], 'hops') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); - verifySchema(response, schema("highlight('*')", null, "nested")); + + verifySchema(response, schema("highlight('*')", null, "object")); assertEquals(1, response.getInt("total")); + + verifyDataRows(response, rows(new JSONObject(ImmutableMap.of( + "Title", new JSONArray(List.of("What uses do hops have outside of brewing?")), + "Tags", new JSONArray(List.of("hops history")))))); } @Test @@ -84,5 +135,15 @@ public void highlight_no_limit_test() { JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); verifySchema(response, schema("highlight(Body)", null, "nested")); assertEquals(2, response.getInt("total")); + + verifyDataRows(response, rows(new JSONArray(List.of("Boiling affects hops, by boiling" + + " off the aroma and extracting more of the organic acids that provide"))), + + rows(new JSONArray(List.of("

Do hops have (or had in the past) any use outside of brewing beer?", + "when-was-the-first-beer-ever-brewed\">dating first modern beers we have the first record" + + " of cultivating hops", + "predating the first record of use of hops in beer by nearly a century.", + "Could the hops have been cultivated for any other purpose than brewing, " + + "or can we safely assume if they")))); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index 646395d790..c26413c622 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -25,9 +25,11 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.search.sort.SortBuilder; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -173,14 +175,33 @@ public void pushDownLimit(Integer limit, Integer offset) { * Add highlight to DSL requests. * @param field name of the field to highlight */ - public void pushDownHighlight(String field) { + public void pushDownHighlight(String field, Map arguments) { + String unquotedField = StringUtils.unquoteText(field); if (sourceBuilder.highlighter() != null) { - sourceBuilder.highlighter().field(StringUtils.unquoteText(field)); + // OS does not allow duplicates of highlight fields + if (sourceBuilder.highlighter().fields().stream() + .anyMatch(f -> f.name().equals(unquotedField))) { + throw new SemanticCheckException(String.format( + "Duplicate field %s in highlight", field)); + } + + sourceBuilder.highlighter().field(unquotedField); } else { HighlightBuilder highlightBuilder = - new HighlightBuilder().field(StringUtils.unquoteText(field)); + new HighlightBuilder().field(unquotedField); sourceBuilder.highlighter(highlightBuilder); } + + // lastFieldIndex denotes previously set highlighter with field parameter + int lastFieldIndex = sourceBuilder.highlighter().fields().size() - 1; + if (arguments.containsKey("pre_tags")) { + sourceBuilder.highlighter().fields().get(lastFieldIndex) + .preTags(arguments.get("pre_tags").toString()); + } + if (arguments.containsKey("post_tags")) { + sourceBuilder.highlighter().fields().get(lastFieldIndex) + .postTags(arguments.get("post_tags").toString()); + } } /** diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index ef6159020f..2849fbbec9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -207,7 +207,8 @@ public PhysicalPlan visitAD(LogicalAD node, OpenSearchIndexScan context) { @Override public PhysicalPlan visitHighlight(LogicalHighlight node, OpenSearchIndexScan context) { - context.getRequestBuilder().pushDownHighlight(node.getHighlightField().toString()); + context.getRequestBuilder().pushDownHighlight( + StringUtils.unquoteText(node.getHighlightField().toString()), node.getArguments()); return visitChild(node, context); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index 64b87aa2c5..ced87a7d31 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -88,6 +88,6 @@ public void visitHighlight() { new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); implementor.visitHighlight(node, indexScan); - verify(requestBuilder).pushDownHighlight(any()); + verify(requestBuilder).pushDownHighlight(any(), any()); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java index a1f2869ca5..9a606750a3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java @@ -9,7 +9,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -20,6 +19,8 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -32,9 +33,12 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; @@ -190,14 +194,49 @@ void pushDownFilters() { @Test void pushDownHighlight() { + Map args = new HashMap<>(); assertThat() .pushDown(QueryBuilders.termQuery("name", "John")) - .pushDownHighlight("Title") - .pushDownHighlight("Body") + .pushDownHighlight("Title", args) + .pushDownHighlight("Body", args) .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), new HighlightBuilder().field("Title").field("Body")); } + @Test + void pushDownHighlightWithArguments() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + HighlightBuilder highlightBuilder = new HighlightBuilder() + .field("Title"); + highlightBuilder.fields().get(0).preTags("").postTags(""); + assertThat() + .pushDown(QueryBuilders.termQuery("name", "John")) + .pushDownHighlight("Title", args) + .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), + highlightBuilder); + } + + @Test + void pushDownHighlightWithRepeatingFields() { + mockResponse( + new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, + new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); + + try (OpenSearchIndexScan indexScan = + new OpenSearchIndexScan(client, settings, "test", 2, exprValueFactory)) { + indexScan.getRequestBuilder().pushDownLimit(3, 0); + indexScan.open(); + Map args = new HashMap<>(); + indexScan.getRequestBuilder().pushDownHighlight("name", args); + indexScan.getRequestBuilder().pushDownHighlight("name", args); + } catch (SemanticCheckException e) { + assertTrue(e.getClass().equals(SemanticCheckException.class)); + } + verify(client).cleanup(any()); + } + private PushDownAssertion assertThat() { return new PushDownAssertion(client, exprValueFactory, settings); } @@ -223,8 +262,8 @@ PushDownAssertion pushDown(QueryBuilder query) { return this; } - PushDownAssertion pushDownHighlight(String query) { - indexScan.getRequestBuilder().pushDownHighlight(query); + PushDownAssertion pushDownHighlight(String query, Map arguments) { + indexScan.getRequestBuilder().pushDownHighlight(query, arguments); return this; } diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 5f2385bab3..42c302493b 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -361,6 +361,8 @@ TIME_ZONE: 'TIME_ZONE'; TYPE: 'TYPE'; ZERO_TERMS_QUERY: 'ZERO_TERMS_QUERY'; HIGHLIGHT: 'HIGHLIGHT'; +HIGHLIGHT_PRE_TAGS: 'PRE_TAGS'; +HIGHLIGHT_POST_TAGS: 'POST_TAGS'; // RELEVANCE FUNCTIONS MATCH_BOOL_PREFIX: 'MATCH_BOOL_PREFIX'; diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 23e2d9288d..42009bc5b8 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -315,7 +315,7 @@ constantFunction ; highlightFunction - : HIGHLIGHT LR_BRACKET relevanceField RR_BRACKET + : HIGHLIGHT LR_BRACKET relevanceField (COMMA highlightArg)* RR_BRACKET ; scalarFunctionName @@ -443,6 +443,10 @@ relevanceArg : relevanceArgName EQUAL_SYMBOL relevanceArgValue ; +highlightArg + : highlightArgName EQUAL_SYMBOL highlightArgValue + ; + relevanceArgName : ALLOW_LEADING_WILDCARD | ANALYZER | ANALYZE_WILDCARD | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY | BOOST | CUTOFF_FREQUENCY | DEFAULT_FIELD | DEFAULT_OPERATOR | ENABLE_POSITION_INCREMENTS @@ -453,6 +457,10 @@ relevanceArgName | ZERO_TERMS_QUERY ; +highlightArgName + : HIGHLIGHT_POST_TAGS | HIGHLIGHT_PRE_TAGS + ; + relevanceFieldAndWeight : field=relevanceField | field=relevanceField weight=relevanceFieldWeight @@ -478,3 +486,7 @@ relevanceArgValue | constant ; +highlightArgValue + : stringLiteral + ; + diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 006ed5fba2..ebfafeec23 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -50,6 +50,7 @@ import static org.opensearch.sql.sql.parser.ParserUtils.createSortOption; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -137,7 +138,15 @@ public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ct @Override public UnresolvedExpression visitHighlightFunctionCall( OpenSearchSQLParser.HighlightFunctionCallContext ctx) { - return new HighlightFunction(visit(ctx.highlightFunction().relevanceField())); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.highlightFunction().highlightArg().forEach(v -> builder.put( + v.highlightArgName().getText().toLowerCase(), + new Literal(StringUtils.unquoteText(v.highlightArgValue().getText()), + DataType.STRING)) + ); + + return new HighlightFunction(visit(ctx.highlightFunction().relevanceField()), + builder.build()); } @Override diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index c3b9ed245a..0c06754261 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -33,7 +33,9 @@ import static org.opensearch.sql.utils.SystemIndexUtils.mappingTable; import com.google.common.collect.ImmutableList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Stream; import org.antlr.v4.runtime.tree.ParseTree; import org.junit.jupiter.api.Test; @@ -42,6 +44,8 @@ import org.junit.jupiter.params.provider.MethodSource; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.sql.antlr.SQLSyntaxParser; @@ -731,18 +735,36 @@ public void test_now_like_functions(String name, Boolean hasFsp, Boolean hasShor @Test public void can_build_qualified_name_highlight() { + Map args = new HashMap<>(); assertEquals( project(relation("test"), - alias("highlight(fieldA)", highlight(AstDSL.qualifiedName("fieldA")))), + alias("highlight(fieldA)", + highlight(AstDSL.qualifiedName("fieldA"), args))), buildAST("SELECT highlight(fieldA) FROM test") ); } + @Test + public void can_build_qualified_highlight_with_arguments() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + assertEquals( + project(relation("test"), + alias("highlight(fieldA, pre_tags='', post_tags='')", + highlight(AstDSL.qualifiedName("fieldA"), args))), + buildAST("SELECT highlight(fieldA, pre_tags='', post_tags='') " + + "FROM test") + ); + } + @Test public void can_build_string_literal_highlight() { + Map args = new HashMap<>(); assertEquals( project(relation("test"), - alias("highlight(\"fieldA\")", highlight(AstDSL.stringLiteral("fieldA")))), + alias("highlight(\"fieldA\")", + highlight(AstDSL.stringLiteral("fieldA"), args))), buildAST("SELECT highlight(\"fieldA\") FROM test") ); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index ef881275e5..ec0a0dd0d3 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -13,7 +13,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.caseWhen; import static org.opensearch.sql.ast.dsl.AstDSL.dateLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral; -import static org.opensearch.sql.ast.dsl.AstDSL.floatLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.highlight; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; @@ -35,12 +34,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import org.antlr.v4.runtime.CommonTokenStream; import org.apache.commons.lang3.tuple.ImmutablePair; import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; @@ -311,16 +312,18 @@ public void canBuildWindowFunctionWithNullOrderSpecified() { @Test public void canBuildStringLiteralHighlightFunction() { + HashMap args = new HashMap<>(); assertEquals( - highlight(AstDSL.stringLiteral("fieldA")), + highlight(AstDSL.stringLiteral("fieldA"), args), buildExprAst("highlight(\"fieldA\")") ); } @Test public void canBuildQualifiedNameHighlightFunction() { + HashMap args = new HashMap<>(); assertEquals( - highlight(AstDSL.qualifiedName("fieldA")), + highlight(AstDSL.qualifiedName("fieldA"), args), buildExprAst("highlight(fieldA)") ); }