diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 043d299b40..0bc1008152 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -313,7 +313,6 @@ public LogicalPlan visitProject(Project node, AnalysisContext context) { for (UnresolvedExpression expr : node.getProjectList()) { HighlightAnalyzer highlightAnalyzer = new HighlightAnalyzer(expressionAnalyzer, child); child = highlightAnalyzer.analyze(expr, context); - } List namedExpressions = diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index e1dbedebb2..606602ccb1 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -19,6 +19,7 @@ import org.opensearch.sql.analysis.symbol.Symbol; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.expression.AggregateFunction; +import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Case; @@ -44,7 +45,9 @@ import org.opensearch.sql.ast.expression.WindowFunction; import org.opensearch.sql.ast.expression.Xor; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; @@ -91,6 +94,14 @@ public Expression analyze(UnresolvedExpression unresolved, AnalysisContext conte return unresolved.accept(this, context); } + @Override + public Expression visitAlias(Alias node, AnalysisContext context) { + // Only purpose for this override currently is to avoid null pointer exception when using + // '-' flag with a highlight call in a fields command. + throw new SemanticCheckException(String.format("can't resolve Symbol %s in type env", + node.getName())); + } + @Override public Expression visitUnresolvedAttribute(UnresolvedAttribute node, AnalysisContext context) { return visitIdentifier(node.getAttr(), context); @@ -205,7 +216,7 @@ public Expression visitWindowFunction(WindowFunction node, AnalysisContext conte } @Override - public Expression visitHighlight(HighlightFunction node, AnalysisContext context) { + public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) { Expression expr = node.getHighlightField().accept(this, context); return new HighlightExpression(expr); } diff --git a/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java index 06a601327c..0a15c6bac8 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java @@ -30,12 +30,13 @@ public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext con @Override public LogicalPlan visitAlias(Alias node, AnalysisContext context) { - if (!(node.getDelegated() instanceof HighlightFunction)) { + UnresolvedExpression delegated = node.getDelegated(); + if (!(delegated instanceof HighlightFunction)) { return null; } - HighlightFunction unresolved = (HighlightFunction) node.getDelegated(); + HighlightFunction unresolved = (HighlightFunction) delegated; Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context); - return new LogicalHighlight(child, field); + return new LogicalHighlight(child, field, unresolved.getArguments()); } } diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index e75f8f4ce5..5aeedcc58d 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -261,7 +261,7 @@ public T visitAD(AD node, C context) { return visitChildren(node, context); } - public T visitHighlight(HighlightFunction node, C context) { + public T visitHighlightFunction(HighlightFunction node, C context) { return visitChildren(node, context); } } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 6a757ccab8..4fa0b89f6c 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -272,8 +272,9 @@ public When when(UnresolvedExpression condition, UnresolvedExpression result) { return new When(condition, result); } - public UnresolvedExpression highlight(UnresolvedExpression fieldName) { - return new HighlightFunction(fieldName); + public UnresolvedExpression highlight(UnresolvedExpression fieldName, + java.util.Map arguments) { + return new HighlightFunction(fieldName, arguments); } public UnresolvedExpression window(UnresolvedExpression function, diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java index 5f1bb652d9..0d4e57a78c 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java @@ -6,6 +6,7 @@ package org.opensearch.sql.ast.expression; import java.util.List; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; @@ -21,10 +22,11 @@ @ToString public class HighlightFunction extends UnresolvedExpression { private final UnresolvedExpression highlightField; + private final Map arguments; @Override public T accept(AbstractNodeVisitor nodeVisitor, C context) { - return nodeVisitor.visitHighlight(this, context); + return nodeVisitor.visitHighlightFunction(this, context); } @Override diff --git a/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java b/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java index 9745696111..804c38a6f7 100644 --- a/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java +++ b/core/src/main/java/org/opensearch/sql/expression/HighlightExpression.java @@ -5,10 +5,16 @@ package org.opensearch.sql.expression; +import java.util.LinkedHashMap; import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; import lombok.Getter; import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.env.Environment; @@ -20,6 +26,7 @@ @Getter public class HighlightExpression extends FunctionExpression { private final Expression highlightField; + private final ExprType type; /** * HighlightExpression Constructor. @@ -28,6 +35,8 @@ public class HighlightExpression extends FunctionExpression { public HighlightExpression(Expression highlightField) { super(BuiltinFunctionName.HIGHLIGHT.getName(), List.of(highlightField)); this.highlightField = highlightField; + this.type = this.highlightField.toString().contains("*") + ? ExprCoreType.STRUCT : ExprCoreType.ARRAY; } /** @@ -37,21 +46,57 @@ public HighlightExpression(Expression highlightField) { */ @Override public ExprValue valueOf(Environment valueEnv) { - String refName = "_highlight" + "." + StringUtils.unquoteText(getHighlightField().toString()); - return valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING)); + String refName = "_highlight"; + // Not a wilcard expression + if (this.type == ExprCoreType.ARRAY) { + refName += "." + StringUtils.unquoteText(getHighlightField().toString()); + } + ExprValue value = valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING)); + + // In the event of multiple returned highlights and wildcard being + // used in conjunction with other highlight calls, we need to ensure + // only wildcard regex matching is mapped to wildcard call. + if (this.type == ExprCoreType.STRUCT && value.type() == ExprCoreType.STRUCT) { + value = new ExprTupleValue( + new LinkedHashMap(value.tupleValue() + .entrySet() + .stream() + .filter(s -> matchesHighlightRegex(s.getKey(), + StringUtils.unquoteText(highlightField.toString()))) + .collect(Collectors.toMap( + e -> e.getKey(), + e -> e.getValue())))); + if (value.tupleValue().isEmpty()) { + value = ExprValueUtils.missingValue(); + } + } + + return value; } /** * Get type for HighlightExpression. - * @return : String type. + * @return : Expression type. */ @Override public ExprType type() { - return ExprCoreType.ARRAY; + return this.type; } @Override public T accept(ExpressionNodeVisitor visitor, C context) { return visitor.visitHighlight(this, context); } + + /** + * Check if field matches the wildcard pattern used in highlight query. + * @param field Highlight selected field for query + * @param pattern Wildcard regex to match field against + * @return True if field matches wildcard pattern + */ + private boolean matchesHighlightRegex(String field, String pattern) { + Pattern p = Pattern.compile(pattern.replace("*", ".*")); + Matcher matcher = p.matcher(field); + return matcher.matches(); + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index bb3eb7008b..43a722b838 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -17,7 +17,6 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.NamedArgumentExpression; import org.opensearch.sql.expression.env.Environment; @@ -37,15 +36,6 @@ public void register(BuiltinFunctionRepository repository) { repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE)); repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE)); repository.register(match_phrase_prefix()); - repository.register(highlight()); - } - - private static FunctionResolver highlight() { - FunctionName functionName = BuiltinFunctionName.HIGHLIGHT.getName(); - FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING)); - FunctionBuilder functionBuilder = arguments -> new HighlightExpression(arguments.get(0)); - return new DefaultFunctionResolver(functionName, - ImmutableMap.of(functionSignature, functionBuilder)); } private static FunctionResolver match_bool_prefix() { diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java index 986a545486..c1e873a00d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java @@ -6,9 +6,11 @@ package org.opensearch.sql.planner.logical; import java.util.Collections; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.expression.Expression; @EqualsAndHashCode(callSuper = true) @@ -16,10 +18,16 @@ @ToString public class LogicalHighlight extends LogicalPlan { private final Expression highlightField; + private final Map arguments; - public LogicalHighlight(LogicalPlan childPlan, Expression field) { + /** + * Constructor of LogicalHighlight. + */ + public LogicalHighlight(LogicalPlan childPlan, Expression highlightField, + Map arguments) { super(Collections.singletonList(childPlan)); - highlightField = field; + this.highlightField = highlightField; + this.arguments = arguments; } @Override diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 005a5d84fd..9e07e702de 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -63,8 +63,9 @@ public LogicalPlan window(LogicalPlan input, return new LogicalWindow(input, windowFunction, windowDefinition); } - public LogicalPlan highlight(LogicalPlan input, Expression field) { - return new LogicalHighlight(input, field); + public LogicalPlan highlight(LogicalPlan input, Expression field, + Map arguments) { + return new LogicalHighlight(input, field, arguments); } public static LogicalPlan remove(LogicalPlan input, ReferenceExpression... fields) { diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 19bfcabfec..31309a1953 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -27,10 +27,12 @@ import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.LONG; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -270,16 +272,41 @@ public void project_source() { @Test public void project_highlight() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), + DSL.literal("fieldA"), args), + DSL.named("highlight(fieldA, pre_tags='', post_tags='')", + new HighlightExpression(DSL.literal("fieldA"))) + ), + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("highlight(fieldA, pre_tags='', post_tags='')", + new HighlightFunction(AstDSL.stringLiteral("fieldA"), args)) + ) + ); + } + + @Test + public void project_highlight_wildcard() { + Map args = new HashMap<>(); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), - DSL.literal("fieldA")), - DSL.named("highlight(fieldA)", new HighlightExpression(DSL.literal("fieldA"))) + DSL.literal("*"), args), + DSL.named("highlight(*)", + new HighlightExpression(DSL.literal("*"))) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("highlight(fieldA)", new HighlightFunction(AstDSL.stringLiteral("fieldA"))) + AstDSL.alias("highlight(*)", + new HighlightFunction(AstDSL.stringLiteral("*"), args)) ) ); } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 5787d08f3c..6af6a27d04 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -8,7 +8,6 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.ast.dsl.AstDSL.field; @@ -28,6 +27,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -36,9 +36,11 @@ import org.opensearch.sql.analysis.symbol.Namespace; import org.opensearch.sql.analysis.symbol.Symbol; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.HighlightFunction; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.expression.SpanUnit; import org.opensearch.sql.ast.expression.UnresolvedExpression; @@ -50,7 +52,6 @@ import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.HighlightExpression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction; @@ -163,6 +164,17 @@ public void castAnalyzer() { "boolean_value"), AstDSL.stringLiteral("INTERVAL")))); } + @Test + public void highlight_throws_semantic_check_exception() { + Map args = new HashMap<>(); + HighlightFunction highlightFunction = new HighlightFunction( + AstDSL.stringLiteral("invalid_field"), args); + Alias alias = AstDSL.alias("highlight(invalid_field)", + highlightFunction); + + assertThrows(SemanticCheckException.class, () -> analyze(alias)); + } + @Test public void case_with_default_result_type_different() { UnresolvedExpression caseWhen = AstDSL.caseWhen( @@ -592,12 +604,6 @@ public void now_as_a_function_not_cached() { assertTrue(values.stream().noneMatch(v -> v.valueOf(null) == referenceValue)); } - @Test - void highlight() { - assertAnalyzeEqual(new HighlightExpression(DSL.literal("fieldA")), - new HighlightFunction(stringLiteral("fieldA"))); - } - protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } diff --git a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java index 2293d125aa..b944115a48 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java @@ -8,15 +8,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.HighlightFunction; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.springframework.context.annotation.Configuration; @@ -40,8 +39,10 @@ void visit_named_select_item() { @Test void visit_highlight() { + Map args = new HashMap<>(); Alias alias = AstDSL.alias("highlight(fieldA)", - new HighlightFunction(AstDSL.stringLiteral("fieldA"))); + new HighlightFunction( + AstDSL.stringLiteral("fieldA"), args)); NamedExpressionAnalyzer analyzer = new NamedExpressionAnalyzer(expressionAnalyzer); diff --git a/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java b/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java index c6e2dccf69..41f3bad030 100644 --- a/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/HighlightExpressionTest.java @@ -6,18 +6,19 @@ package org.opensearch.sql.expression; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; import com.google.common.collect.ImmutableMap; -import com.google.errorprone.annotations.DoNotCall; import org.junit.jupiter.api.Test; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.env.Environment; + public class HighlightExpressionTest extends ExpressionTestBase { @Test @@ -35,33 +36,73 @@ public void single_highlight_test() { public void missing_highlight_test() { Environment hlTuple = ExprValueUtils.tupleValue( ImmutableMap.of("_highlight.Title", "result value")).bindingTuples(); + HighlightExpression expr = new HighlightExpression(DSL.literal("invalid")); ExprValue resultVal = expr.valueOf(hlTuple); assertTrue(resultVal.isMissing()); } - /** - * Enable me when '*' is supported in highlight. - */ - @DoNotCall + @Test + public void missing_highlight_wildcard_test() { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + var hlBuilder = ImmutableMap.builder(); + hlBuilder.put("Title", ExprValueUtils.stringValue("first result value")); + hlBuilder.put("Body", ExprValueUtils.stringValue("secondary result value")); + builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); + + HighlightExpression hlExpr = new HighlightExpression(DSL.literal("invalid*")); + ExprValue resultVal = hlExpr.valueOf( + ExprTupleValue.fromExprValueMap(builder.build()).bindingTuples()); + + assertTrue(resultVal.isMissing()); + } + + @Test public void highlight_all_test() { ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); var hlBuilder = ImmutableMap.builder(); hlBuilder.put("Title", ExprValueUtils.stringValue("correct result value")); - hlBuilder.put("Body", ExprValueUtils.stringValue("incorrect result value")); + hlBuilder.put("Body", ExprValueUtils.stringValue("secondary correct result value")); builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); - HighlightExpression hlExpr = new HighlightExpression(DSL.literal("*")); + HighlightExpression hlExpr = new HighlightExpression(DSL.literal("T*")); ExprValue resultVal = hlExpr.valueOf( ExprTupleValue.fromExprValueMap(builder.build()).bindingTuples()); - assertEquals(ARRAY, resultVal.type()); - for (var field : resultVal.tupleValue().entrySet()) { - assertTrue(field.toString().contains(hlExpr.getHighlightField().toString())); - } + + assertEquals(STRUCT, resultVal.type()); assertTrue(resultVal.tupleValue().containsValue( - ExprValueUtils.stringValue("\"correct result value\""))); + ExprValueUtils.stringValue("correct result value"))); + assertFalse(resultVal.tupleValue().containsValue( + ExprValueUtils.stringValue("secondary correct result value"))); + } + + @Test + public void do_nothing_with_missing_value() { + Environment hlTuple = ExprValueUtils.tupleValue( + ImmutableMap.of("NonHighlightField", "ResultValue")).bindingTuples(); + HighlightExpression expr = new HighlightExpression(DSL.literal("*")); + ExprValue resultVal = expr.valueOf(hlTuple); + + assertTrue(resultVal.isMissing()); + } + + @Test + public void highlight_wildcard_test() { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + var hlBuilder = ImmutableMap.builder(); + hlBuilder.put("Title", ExprValueUtils.stringValue("correct result value")); + hlBuilder.put("Body", ExprValueUtils.stringValue("incorrect result value")); + builder.put("_highlight", ExprTupleValue.fromExprValueMap(hlBuilder.build())); + + HighlightExpression hlExpr = new HighlightExpression(DSL.literal("T*")); + ExprValue resultVal = hlExpr.valueOf( + ExprTupleValue.fromExprValueMap(builder.build()).bindingTuples()); + + assertEquals(STRUCT, resultVal.type()); assertTrue(resultVal.tupleValue().containsValue( - ExprValueUtils.stringValue("\"correct result value\""))); + ExprValueUtils.stringValue("correct result value"))); + assertFalse(resultVal.tupleValue().containsValue( + ExprValueUtils.stringValue("incorrect result value"))); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index c90ea365d2..329708b7d8 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.HashMap; +import java.util.Map; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; @@ -117,8 +118,9 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { assertNull(rareTopN.accept(new LogicalPlanNodeVisitor() { }, null)); + Map args = new HashMap<>(); LogicalPlan highlight = new LogicalHighlight(filter, - new LiteralExpression(ExprValueUtils.stringValue("fieldA"))); + new LiteralExpression(ExprValueUtils.stringValue("fieldA")), args); assertNull(highlight.accept(new LogicalPlanNodeVisitor() { }, null)); diff --git a/docs/user/ppl/functions/relevance.rst b/docs/user/ppl/functions/relevance.rst index 7f7cf50964..6f34ae4e2d 100644 --- a/docs/user/ppl/functions/relevance.rst +++ b/docs/user/ppl/functions/relevance.rst @@ -352,6 +352,32 @@ Another example to show how to set custom values for the optional parameters:: | 1 | The House at Pooh Corner | Alan Alexander Milne | +------+--------------------------+----------------------+ + +HIGHLIGHT +------------ + +Description +>>>>>>>>>>> + +``highlight(field_expression)`` + +The highlight function maps to the highlight function used in search engine to return highlight fields for the given search. +The syntax allows to specify the field in double quotes or single quotes or without any wrap. +Please refer to examples below: + +| ``highlight(title)`` + +Example searching for field Tags:: + + os> source=books | where query_string(['title'], 'Pooh House') | fields highlight(title); + fetched rows / total rows = 2/2 + +----------------------------------------------+ + | highlight(title) | + |----------------------------------------------| + | [The House at Pooh Corner] | + | [Winnie-the-Pooh] | + +----------------------------------------------+ + Limitations >>>>>>>>>>> diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/HighlightFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/HighlightFunctionIT.java new file mode 100644 index 0000000000..79e0a865db --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/HighlightFunctionIT.java @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import java.io.IOException; +import java.util.List; +import com.google.common.collect.ImmutableMap; +import org.json.JSONArray; +import org.json.JSONObject; +import org.junit.Test; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.legacy.TestsConstants; + +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +public class HighlightFunctionIT extends PPLIntegTestCase { + @Override + public void init() throws IOException { + loadIndex(Index.BEER); + } + + @Test + public void single_highlight_test() throws IOException { + JSONObject result = + executeQuery( + String.format( + "SOURCE=%s | WHERE match(Title, 'Cicerone') | fields highlight(Title)", TestsConstants.TEST_INDEX_BEER)); + + assertEquals(1, result.getInt("total")); + + verifySchema(result, schema("highlight(Title)", null, "array")); + verifyDataRows(result, rows(new JSONArray(List.of("What exactly is a Cicerone? What do they do?")))); + + } + + @Test + public void highlight_optional_arguments_test() throws IOException { + JSONObject result = + executeQuery( + String.format( + "SOURCE=%s | WHERE match(Title, 'Cicerone') | fields " + + "highlight(Title, pre_tags='', post_tags='')", TestsConstants.TEST_INDEX_BEER)); + + assertEquals(1, result.getInt("total")); + + verifyDataRows(result, rows(new JSONArray(List.of("What exactly is a Cicerone? What do they do?")))); + verifySchema(result, schema("highlight(Title, pre_tags='', post_tags='')", + null, "array")); + } + + @Test + public void highlight_multiple_optional_arguments_test() throws IOException { + JSONObject result = + executeQuery( + String.format( + "SOURCE=%s | WHERE multi_match([Title, Body], 'IPA') | fields highlight(Title), highlight(Body, " + + "pre_tags='', post_tags='') | head 1", + TestsConstants.TEST_INDEX_BEER)); + + verifyDataRows(result, rows(new JSONArray(List.of("What are the differences between an IPA and its variants?")), + new JSONArray(List.of("

I know what makes an IPA an IPA, but what are the unique characteristics of it's common variants?", + "To be specific, the ones I'm interested in are Double IPA and Black IPA, but general differences between")))); + } + + @Test + public void quoted_highlight_test() throws IOException { + JSONObject result = + executeQuery( + String.format( + "SOURCE=%s | WHERE match(Title, 'Cicerone') | fields highlight('Title')", TestsConstants.TEST_INDEX_BEER)); + + assertEquals(1, result.getInt("total")); + verifyDataRows(result, rows(new JSONArray(List.of("What exactly is a Cicerone? What do they do?")))); + } + + @Test + public void multiple_highlights_test() throws IOException { + JSONObject result = + executeQuery( + String.format( + "SOURCE=%s | WHERE multi_match([Title, Body], 'IPA') | fields highlight('Title'), highlight(Body) | HEAD 1", + TestsConstants.TEST_INDEX_BEER)); + + assertEquals(1, result.getInt("total")); + verifyDataRows(result, rows(new JSONArray(List.of("What are the differences between an IPA and its variants?")), + new JSONArray(List.of("

I know what makes an IPA an IPA, but what are the unique characteristics of it's common variants?", + "To be specific, the ones I'm interested in are Double IPA and Black IPA, but general differences between")))); + } + + @Test + public void highlight_wildcard_test() throws IOException { + JSONObject result = + executeQuery( + String.format( + "SOURCE=%s | WHERE multi_match([Title, Body], 'Cicerone') | fields highlight('T*')", + TestsConstants.TEST_INDEX_BEER)); + + assertEquals(1, result.getInt("total")); + verifyDataRows(result, rows(new JSONObject(ImmutableMap.of( + "Title", new JSONArray(List.of("What exactly is a Cicerone? What do they do?")))))); + verifySchema(result, schema("highlight('T*')", null, "struct")); + } + + @Test + public void highlight_all_test() throws IOException { + JSONObject result = + executeQuery( + String.format( + "SOURCE=%s | WHERE multi_match([Title, Body], 'Cicerone') | fields highlight('*')", + TestsConstants.TEST_INDEX_BEER)); + + assertEquals(1, result.getInt("total")); + verifySchema(result, schema("highlight('*')", null, "struct")); + + verifyDataRows(result, rows(new JSONObject(ImmutableMap.of( + "Title", new JSONArray(List.of("What exactly is a Cicerone? What do they do?")), + "Body", new JSONArray(List.of("

Recently I've started seeing references to the term 'Cicerone' " + + "pop up around the internet; generally", "What exactly does a cicerone do?")) + )))); + } + + @Test + public void highlight_semantic_check_test() throws SemanticCheckException { + String query = String.format("SOURCE=%s | WHERE multi_match([Title, Body], 'Cicerone') | fields - highlight('*')", + TestsConstants.TEST_INDEX_BEER); + queryShouldThrowSemanticException(query, "can't resolve Symbol highlight('*') in type env"); + } + + private void queryShouldThrowSemanticException(String query, String... messages) { + try { + executeQuery(query); + fail("Expected to throw SemanticCheckException, but none was thrown for query: " + query); + } catch (ResponseException e) { + String errorMsg = e.getMessage(); + assertTrue(errorMsg.contains("SemanticCheckException")); + for (String msg : messages) { + assertTrue(errorMsg.contains(msg)); + } + } catch (IOException e) { + throw new IllegalStateException("Unexpected exception raised for query: " + query); + } + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java index 422f71968f..809e2dc7c5 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java @@ -5,14 +5,18 @@ package org.opensearch.sql.sql; +import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; import static org.opensearch.sql.util.MatcherUtils.verifySchema; -import com.google.errorprone.annotations.DoNotCall; +import com.google.common.collect.ImmutableMap; +import org.json.JSONArray; import org.json.JSONObject; import org.junit.Test; import org.opensearch.sql.legacy.SQLIntegTestCase; import org.opensearch.sql.legacy.TestsConstants; +import java.util.List; public class HighlightFunctionIT extends SQLIntegTestCase { @@ -25,57 +29,104 @@ protected void init() throws Exception { public void single_highlight_test() { String query = "SELECT Tags, highlight('Tags') FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("Tags", null, "text"), schema("highlight('Tags')", null, "nested")); assertEquals(1, response.getInt("total")); + + verifyDataRows(response, + rows("alcohol-level yeast home-brew champagne", + new JSONArray(List.of("alcohol-level yeast home-brew champagne")))); } @Test - public void accepts_unquoted_test() { - String query = "SELECT Tags, highlight(Tags) FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; + public void highlight_optional_arguments_test() { + String query = "SELECT highlight('Tags', pre_tags='', post_tags='') " + + "FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); - verifySchema(response, schema("Tags", null, "text"), - schema("highlight(Tags)", null, "nested")); + + verifySchema(response, schema("highlight('Tags', pre_tags='', post_tags='')", + null, "nested")); + assertEquals(1, response.getInt("total")); + + verifyDataRows(response, + rows(new JSONArray(List.of("alcohol-level yeast home-brew champagne")))); + } + + @Test + public void highlight_multiple_optional_arguments_test() { + String query = "SELECT highlight(Title), highlight(Body, pre_tags='', post_tags='') FROM %s WHERE multi_match([Title, Body], 'IPA') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + + verifySchema(response, schema("highlight(Title)", null, "nested"), + schema("highlight(Body, pre_tags='', " + + "post_tags='')", null, "nested")); + + assertEquals(1, response.getInt("total")); + + verifyDataRows(response, rows(new JSONArray(List.of("What are the differences between an IPA" + + " and its variants?")), + new JSONArray(List.of("

I know what makes an IPA" + + " an IPA, but what are the unique characteristics of it's" + + " common variants?", + "To be specific, the ones I'm interested in are Double IPA " + + "and Black IPA, but general differences" + + " between")))); } @Test public void multiple_highlight_test() { - String query = "SELECT highlight(Title), highlight(Body) FROM %s WHERE MULTI_MATCH([Title, Body], 'hops') LIMIT 1"; + String query = "SELECT highlight(Title), highlight(Tags) FROM %s WHERE MULTI_MATCH([Title, Tags], 'hops') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); verifySchema(response, schema("highlight(Title)", null, "nested"), - schema("highlight(Body)", null, "nested")); + schema("highlight(Tags)", null, "nested")); assertEquals(1, response.getInt("total")); + + verifyDataRows(response, + rows( new JSONArray(List.of("What uses do hops have outside of brewing?")), + new JSONArray(List.of("hops history")))); } - // Enable me when * is supported - @DoNotCall + @Test public void wildcard_highlight_test() { - String query = "SELECT highlight('*itle') FROM %s WHERE MULTI_MATCH([Title, Body], 'hops') LIMIT 1"; + String query = "SELECT highlight('*itle') FROM %s WHERE MULTI_MATCH([Title, Tags], 'hops') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); - verifySchema(response, schema("highlight('*itle')", null, "nested")); + + verifySchema(response, schema("highlight('*itle')", null, "object")); assertEquals(1, response.getInt("total")); + + verifyDataRows(response, rows(new JSONObject(ImmutableMap.of( + "Title", new JSONArray(List.of("What uses do hops have outside of brewing?")))))); } - // Enable me when * is supported - @DoNotCall + @Test public void wildcard_multi_field_highlight_test() { String query = "SELECT highlight('T*') FROM %s WHERE MULTI_MATCH([Title, Tags], 'hops') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); - verifySchema(response, schema("highlight('T*')", null, "nested")); - var resultMap = response.getJSONArray("datarows").getJSONArray(0).getJSONObject(0); + + verifySchema(response, schema("highlight('T*')", null, "object")); assertEquals(1, response.getInt("total")); - assertTrue(resultMap.has("highlight(\"T*\").Title")); - assertTrue(resultMap.has("highlight(\"T*\").Tags")); + + verifyDataRows(response, rows(new JSONObject(ImmutableMap.of( + "Title", new JSONArray(List.of("What uses do hops have outside of brewing?")), + "Tags", new JSONArray(List.of("hops history")))))); } - // Enable me when * is supported - @DoNotCall + @Test public void highlight_all_test() { - String query = "SELECT highlight('*') FROM %s WHERE MULTI_MATCH([Title, Body], 'hops') LIMIT 1"; + String query = "SELECT highlight('*') FROM %s WHERE MULTI_MATCH([Title, Tags], 'hops') LIMIT 1"; JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); - verifySchema(response, schema("highlight('*')", null, "nested")); + + verifySchema(response, schema("highlight('*')", null, "object")); assertEquals(1, response.getInt("total")); + + verifyDataRows(response, rows(new JSONObject(ImmutableMap.of( + "Title", new JSONArray(List.of("What uses do hops have outside of brewing?")), + "Tags", new JSONArray(List.of("hops history")))))); } @Test @@ -84,5 +135,15 @@ public void highlight_no_limit_test() { JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); verifySchema(response, schema("highlight(Body)", null, "nested")); assertEquals(2, response.getInt("total")); + + verifyDataRows(response, rows(new JSONArray(List.of("Boiling affects hops, by boiling" + + " off the aroma and extracting more of the organic acids that provide"))), + + rows(new JSONArray(List.of("

Do hops have (or had in the past) any use outside of brewing beer?", + "when-was-the-first-beer-ever-brewed\">dating first modern beers we have the first record" + + " of cultivating hops", + "predating the first record of use of hops in beer by nearly a century.", + "Could the hops have been cultivated for any other purpose than brewing, " + + "or can we safely assume if they")))); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index 646395d790..c26413c622 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -25,9 +25,11 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.search.sort.SortBuilder; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -173,14 +175,33 @@ public void pushDownLimit(Integer limit, Integer offset) { * Add highlight to DSL requests. * @param field name of the field to highlight */ - public void pushDownHighlight(String field) { + public void pushDownHighlight(String field, Map arguments) { + String unquotedField = StringUtils.unquoteText(field); if (sourceBuilder.highlighter() != null) { - sourceBuilder.highlighter().field(StringUtils.unquoteText(field)); + // OS does not allow duplicates of highlight fields + if (sourceBuilder.highlighter().fields().stream() + .anyMatch(f -> f.name().equals(unquotedField))) { + throw new SemanticCheckException(String.format( + "Duplicate field %s in highlight", field)); + } + + sourceBuilder.highlighter().field(unquotedField); } else { HighlightBuilder highlightBuilder = - new HighlightBuilder().field(StringUtils.unquoteText(field)); + new HighlightBuilder().field(unquotedField); sourceBuilder.highlighter(highlightBuilder); } + + // lastFieldIndex denotes previously set highlighter with field parameter + int lastFieldIndex = sourceBuilder.highlighter().fields().size() - 1; + if (arguments.containsKey("pre_tags")) { + sourceBuilder.highlighter().fields().get(lastFieldIndex) + .preTags(arguments.get("pre_tags").toString()); + } + if (arguments.containsKey("post_tags")) { + sourceBuilder.highlighter().fields().get(lastFieldIndex) + .postTags(arguments.get("post_tags").toString()); + } } /** diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index ef6159020f..2849fbbec9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -207,7 +207,8 @@ public PhysicalPlan visitAD(LogicalAD node, OpenSearchIndexScan context) { @Override public PhysicalPlan visitHighlight(LogicalHighlight node, OpenSearchIndexScan context) { - context.getRequestBuilder().pushDownHighlight(node.getHighlightField().toString()); + context.getRequestBuilder().pushDownHighlight( + StringUtils.unquoteText(node.getHighlightField().toString()), node.getArguments()); return visitChild(node, context); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index 64b87aa2c5..ced87a7d31 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -88,6 +88,6 @@ public void visitHighlight() { new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); implementor.visitHighlight(node, indexScan); - verify(requestBuilder).pushDownHighlight(any()); + verify(requestBuilder).pushDownHighlight(any(), any()); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java index a1f2869ca5..9a606750a3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java @@ -9,7 +9,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -20,6 +19,8 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -32,9 +33,12 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; @@ -190,14 +194,49 @@ void pushDownFilters() { @Test void pushDownHighlight() { + Map args = new HashMap<>(); assertThat() .pushDown(QueryBuilders.termQuery("name", "John")) - .pushDownHighlight("Title") - .pushDownHighlight("Body") + .pushDownHighlight("Title", args) + .pushDownHighlight("Body", args) .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), new HighlightBuilder().field("Title").field("Body")); } + @Test + void pushDownHighlightWithArguments() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + HighlightBuilder highlightBuilder = new HighlightBuilder() + .field("Title"); + highlightBuilder.fields().get(0).preTags("").postTags(""); + assertThat() + .pushDown(QueryBuilders.termQuery("name", "John")) + .pushDownHighlight("Title", args) + .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), + highlightBuilder); + } + + @Test + void pushDownHighlightWithRepeatingFields() { + mockResponse( + new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, + new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); + + try (OpenSearchIndexScan indexScan = + new OpenSearchIndexScan(client, settings, "test", 2, exprValueFactory)) { + indexScan.getRequestBuilder().pushDownLimit(3, 0); + indexScan.open(); + Map args = new HashMap<>(); + indexScan.getRequestBuilder().pushDownHighlight("name", args); + indexScan.getRequestBuilder().pushDownHighlight("name", args); + } catch (SemanticCheckException e) { + assertTrue(e.getClass().equals(SemanticCheckException.class)); + } + verify(client).cleanup(any()); + } + private PushDownAssertion assertThat() { return new PushDownAssertion(client, exprValueFactory, settings); } @@ -223,8 +262,8 @@ PushDownAssertion pushDown(QueryBuilder query) { return this; } - PushDownAssertion pushDownHighlight(String query) { - indexScan.getRequestBuilder().pushDownHighlight(query); + PushDownAssertion pushDownHighlight(String query, Map arguments) { + indexScan.getRequestBuilder().pushDownHighlight(query, arguments); return this; } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 230e183855..76b7bf221b 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -333,6 +333,9 @@ SLOP: 'SLOP'; TIE_BREAKER: 'TIE_BREAKER'; TYPE: 'TYPE'; ZERO_TERMS_QUERY: 'ZERO_TERMS_QUERY'; +HIGHLIGHT: 'HIGHLIGHT'; +HIGHLIGHT_PRE_TAGS: 'PRE_TAGS'; +HIGHLIGHT_POST_TAGS: 'POST_TAGS'; // SPAN KEYWORDS SPAN: 'SPAN'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 5773d4975c..c0f965bfa7 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -296,6 +296,7 @@ sortFieldExpression fieldExpression : qualifiedName + | highlightFunction ; wcFieldExpression @@ -307,6 +308,10 @@ evalFunctionCall : evalFunctionName LT_PRTHS functionArgs RT_PRTHS ; +highlightFunction + : HIGHLIGHT LT_PRTHS relevanceField (COMMA highlightArg)* RT_PRTHS #highlightFunctionCall + ; + /** cast function */ dataTypeFunctionCall : CAST LT_PRTHS expression AS convertedDataType RT_PRTHS @@ -349,6 +354,10 @@ relevanceArg : relevanceArgName EQUAL relevanceArgValue ; +highlightArg + : highlightArgName EQUAL highlightArgValue + ; + relevanceArgName : ALLOW_LEADING_WILDCARD | ANALYZER | ANALYZE_WILDCARD | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY | BOOST | CUTOFF_FREQUENCY | DEFAULT_FIELD | DEFAULT_OPERATOR | ENABLE_POSITION_INCREMENTS @@ -359,6 +368,11 @@ relevanceArgName | ZERO_TERMS_QUERY ; + +highlightArgName + : HIGHLIGHT_POST_TAGS| HIGHLIGHT_PRE_TAGS + ; + relevanceFieldAndWeight : field=relevanceField | field=relevanceField weight=relevanceFieldWeight @@ -384,6 +398,10 @@ relevanceArgValue | literalValue ; +highlightArgValue + : stringLiteral + ; + mathematicalFunctionBase : ABS | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI |POW | POWER | RAND | ROUND | SIGN | SQRT | TRUNCATE diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 2f9fed6e62..8a94fbdf34 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -134,12 +134,10 @@ public UnresolvedPlan visitWhereCommand(WhereCommandContext ctx) { */ @Override public UnresolvedPlan visitFieldsCommand(FieldsCommandContext ctx) { + ImmutableList.Builder builder = new ImmutableList.Builder<>(); + ctx.fieldList().fieldExpression().forEach(field -> builder.add(visitFieldsItem(field))); return new Project( - ctx.fieldList() - .fieldExpression() - .stream() - .map(this::internalVisitExpression) - .collect(Collectors.toList()), + builder.build(), ArgumentFactory.getArgumentList(ctx) ); } @@ -381,6 +379,22 @@ public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { return new AD(builder.build()); } + /** + * Returns expression for both fields and highlight functions in fields command. + * @param ctx : field or highlight function context + * @return : Return Alias of highlight expression or Field for field expression + */ + private UnresolvedExpression visitFieldsItem(OpenSearchPPLParser.FieldExpressionContext ctx) { + if (ctx.qualifiedName() != null) { + return internalVisitExpression(ctx); + } + // If not field expression then is a highlight expression + String name = StringUtils.unquoteIdentifier(getTextInQuery(ctx.highlightFunction())); + UnresolvedExpression expr = internalVisitExpression(ctx.highlightFunction()); + + return new Alias(name, expr); + } + /** * Get original text in query. */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 5df1c4ec56..faad34521f 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -66,6 +66,7 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.In; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.IntervalUnit; @@ -208,6 +209,19 @@ public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunction return new AggregateFunction("count", visit(ctx.valueExpression()), true); } + @Override + public UnresolvedExpression visitHighlightFunctionCall( + OpenSearchPPLParser.HighlightFunctionCallContext ctx) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.highlightArg().forEach(v -> builder.put( + v.highlightArgName().getText().toLowerCase(), + new Literal(StringUtils.unquoteText(v.highlightArgValue().getText()), + DataType.STRING)) + ); + return new HighlightFunction(AstDSL.stringLiteral( + StringUtils.unquoteText(ctx.relevanceField().getText())), builder.build()); + } + @Override public UnresolvedExpression visitPercentileAggFunction(PercentileAggFunctionContext ctx) { return new AggregateFunction(ctx.PERCENTILE().getText(), visit(ctx.aggField), diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index ec513c7c4d..94bf58e708 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -21,6 +21,7 @@ import org.opensearch.sql.ast.expression.Compare; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.Interval; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; @@ -300,5 +301,16 @@ public String visitAlias(Alias node, String context) { String expr = node.getDelegated().accept(this, context); return StringUtils.format("%s", expr); } + + @Override + public String visitHighlightFunction(HighlightFunction node, String context) { + String args = node.getArguments().containsKey("pre_tags") + ? ", pre_tags='" + node.getArguments().get("pre_tags") + "'" + : ""; + args += node.getArguments().containsKey("post_tags") + ? ", post_tags='" + node.getArguments().get("post_tags") + "'" + : ""; + return StringUtils.format("highlight(%s%s)", node.getHighlightField(), args); + } } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java index dcf961dc24..3d71db327f 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java @@ -80,6 +80,18 @@ public void testTopCommandWithoutNAndGroupByShouldPass() { assertNotEquals(null, tree); } + @Test + public void testHighlightShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=shakespeare | fields highlight(field)"); + assertNotEquals(null, tree); + } + + @Test + public void testQuotedHighlightShouldPass() { + ParseTree tree = new PPLSyntaxParser().parse("source=shakespeare | fields highlight('field')"); + assertNotEquals(null, tree); + } + @Test public void can_parse_multi_match_relevance_function() { assertNotEquals(null, new PPLSyntaxParser().parse( diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 09909949b3..ba8b74a822 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -25,6 +25,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.filter; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.head; +import static org.opensearch.sql.ast.dsl.AstDSL.highlight; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.let; import static org.opensearch.sql.ast.dsl.AstDSL.map; @@ -41,6 +42,8 @@ import static org.opensearch.sql.utils.SystemIndexUtils.mappingTable; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; +import java.util.Map; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -669,6 +672,46 @@ public void testPatternsCommandWithoutArguments() { ImmutableMap.of())); } + @Test + public void testQuotedHighlightFunction() { + Map args = new HashMap<>(); + assertEqual("source=t | fields highlight('FieldA')", + projectWithArg( + relation("t"), + defaultFieldsArgs(), + alias("highlight('FieldA')", + highlight(stringLiteral("FieldA"), args)) + ) + ); + } + + @Test + public void testUnquotedHighlightFunction() { + Map args = new HashMap<>(); + assertEqual("source=t | fields highlight(FieldA)", + projectWithArg( + relation("t"), + defaultFieldsArgs(), + alias("highlight(FieldA)", + highlight(stringLiteral("FieldA"), args))) + ); + } + + @Test + public void testHighlightCommandWithArguments() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + + assertEqual("source=t | fields highlight(FieldA, pre_tags='', post_tags='')", + projectWithArg( + relation("t"), + defaultFieldsArgs(), + alias("highlight(FieldA, pre_tags='', post_tags='')", + highlight(stringLiteral("FieldA"), args))) + ); + } + @Test public void testKmeansCommand() { assertEqual("source=t | kmeans centroids=3 iterations=2 distance_type='l1'", diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index 7caa4bab13..5d2e7ea0b4 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -167,6 +167,20 @@ public void testDateFunction() { ); } + @Test + public void testHighlightFunction() { + assertEquals("source=t | fields + highlight(field)", + anonymize("source=t | fields highlight(field)") + ); + } + + @Test + public void testHighlightFunctionWithArguments() { + assertEquals("source=t | fields + highlight(field, pre_tags='', post_tags='')", + anonymize("source=t | fields highlight(field, pre_tags='', post_tags='')") + ); + } + @Test public void anonymizeFieldsNoArg() { assertEquals("source=t | fields + f", diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 5f2385bab3..42c302493b 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -361,6 +361,8 @@ TIME_ZONE: 'TIME_ZONE'; TYPE: 'TYPE'; ZERO_TERMS_QUERY: 'ZERO_TERMS_QUERY'; HIGHLIGHT: 'HIGHLIGHT'; +HIGHLIGHT_PRE_TAGS: 'PRE_TAGS'; +HIGHLIGHT_POST_TAGS: 'POST_TAGS'; // RELEVANCE FUNCTIONS MATCH_BOOL_PREFIX: 'MATCH_BOOL_PREFIX'; diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 23e2d9288d..42009bc5b8 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -315,7 +315,7 @@ constantFunction ; highlightFunction - : HIGHLIGHT LR_BRACKET relevanceField RR_BRACKET + : HIGHLIGHT LR_BRACKET relevanceField (COMMA highlightArg)* RR_BRACKET ; scalarFunctionName @@ -443,6 +443,10 @@ relevanceArg : relevanceArgName EQUAL_SYMBOL relevanceArgValue ; +highlightArg + : highlightArgName EQUAL_SYMBOL highlightArgValue + ; + relevanceArgName : ALLOW_LEADING_WILDCARD | ANALYZER | ANALYZE_WILDCARD | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY | BOOST | CUTOFF_FREQUENCY | DEFAULT_FIELD | DEFAULT_OPERATOR | ENABLE_POSITION_INCREMENTS @@ -453,6 +457,10 @@ relevanceArgName | ZERO_TERMS_QUERY ; +highlightArgName + : HIGHLIGHT_POST_TAGS | HIGHLIGHT_PRE_TAGS + ; + relevanceFieldAndWeight : field=relevanceField | field=relevanceField weight=relevanceFieldWeight @@ -478,3 +486,7 @@ relevanceArgValue | constant ; +highlightArgValue + : stringLiteral + ; + diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 006ed5fba2..ebfafeec23 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -50,6 +50,7 @@ import static org.opensearch.sql.sql.parser.ParserUtils.createSortOption; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -137,7 +138,15 @@ public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ct @Override public UnresolvedExpression visitHighlightFunctionCall( OpenSearchSQLParser.HighlightFunctionCallContext ctx) { - return new HighlightFunction(visit(ctx.highlightFunction().relevanceField())); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.highlightFunction().highlightArg().forEach(v -> builder.put( + v.highlightArgName().getText().toLowerCase(), + new Literal(StringUtils.unquoteText(v.highlightArgValue().getText()), + DataType.STRING)) + ); + + return new HighlightFunction(visit(ctx.highlightFunction().relevanceField()), + builder.build()); } @Override diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index c3b9ed245a..0c06754261 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -33,7 +33,9 @@ import static org.opensearch.sql.utils.SystemIndexUtils.mappingTable; import com.google.common.collect.ImmutableList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Stream; import org.antlr.v4.runtime.tree.ParseTree; import org.junit.jupiter.api.Test; @@ -42,6 +44,8 @@ import org.junit.jupiter.params.provider.MethodSource; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.sql.antlr.SQLSyntaxParser; @@ -731,18 +735,36 @@ public void test_now_like_functions(String name, Boolean hasFsp, Boolean hasShor @Test public void can_build_qualified_name_highlight() { + Map args = new HashMap<>(); assertEquals( project(relation("test"), - alias("highlight(fieldA)", highlight(AstDSL.qualifiedName("fieldA")))), + alias("highlight(fieldA)", + highlight(AstDSL.qualifiedName("fieldA"), args))), buildAST("SELECT highlight(fieldA) FROM test") ); } + @Test + public void can_build_qualified_highlight_with_arguments() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + assertEquals( + project(relation("test"), + alias("highlight(fieldA, pre_tags='', post_tags='')", + highlight(AstDSL.qualifiedName("fieldA"), args))), + buildAST("SELECT highlight(fieldA, pre_tags='', post_tags='') " + + "FROM test") + ); + } + @Test public void can_build_string_literal_highlight() { + Map args = new HashMap<>(); assertEquals( project(relation("test"), - alias("highlight(\"fieldA\")", highlight(AstDSL.stringLiteral("fieldA")))), + alias("highlight(\"fieldA\")", + highlight(AstDSL.stringLiteral("fieldA"), args))), buildAST("SELECT highlight(\"fieldA\") FROM test") ); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index ef881275e5..ec0a0dd0d3 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -13,7 +13,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.caseWhen; import static org.opensearch.sql.ast.dsl.AstDSL.dateLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral; -import static org.opensearch.sql.ast.dsl.AstDSL.floatLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.highlight; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; @@ -35,12 +34,14 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import org.antlr.v4.runtime.CommonTokenStream; import org.apache.commons.lang3.tuple.ImmutablePair; import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; @@ -311,16 +312,18 @@ public void canBuildWindowFunctionWithNullOrderSpecified() { @Test public void canBuildStringLiteralHighlightFunction() { + HashMap args = new HashMap<>(); assertEquals( - highlight(AstDSL.stringLiteral("fieldA")), + highlight(AstDSL.stringLiteral("fieldA"), args), buildExprAst("highlight(\"fieldA\")") ); } @Test public void canBuildQualifiedNameHighlightFunction() { + HashMap args = new HashMap<>(); assertEquals( - highlight(AstDSL.qualifiedName("fieldA")), + highlight(AstDSL.qualifiedName("fieldA"), args), buildExprAst("highlight(fieldA)") ); }