diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index efc0dc93a3..f105076e16 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -364,7 +364,7 @@ public LogicalPlan visitHighlight(Highlight node, AnalysisContext context) { HighlightFunction unresolved = (HighlightFunction) ((Alias)node.getExpression()).getDelegated(); Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context); - return new LogicalHighlight(child, field); + return new LogicalHighlight(child, field, node.getArguments(), node.getName()); } diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 0939e59716..d8370792ec 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -193,10 +193,9 @@ public Expression visitWindowFunction(WindowFunction node, AnalysisContext conte @Override public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) { - Expression expr = node.getHighlightField().accept(this, context); - String highlightStr = "highlight(" + StringUtils.unquoteText(expr.toString()) + ")"; - return new ReferenceExpression(highlightStr, expr.toString().contains("*") - ? ExprCoreType.STRUCT : ExprCoreType.ARRAY); + return (node.getName() != null) ? new ReferenceExpression(node.getName(), + node.getName().contains("*") ? ExprCoreType.STRUCT : ExprCoreType.ARRAY) + : null; } @Override diff --git a/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java index 06a601327c..705a2332e1 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/HighlightAnalyzer.java @@ -30,12 +30,15 @@ public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext con @Override public LogicalPlan visitAlias(Alias node, AnalysisContext context) { - if (!(node.getDelegated() instanceof HighlightFunction)) { + UnresolvedExpression delegated = node.getDelegated(); + if (!(delegated instanceof HighlightFunction)) { return null; } - HighlightFunction unresolved = (HighlightFunction) node.getDelegated(); + // Must set name here else operator cannot resolve delegated reference + ((HighlightFunction) delegated).setName(node.getName()); + HighlightFunction unresolved = (HighlightFunction) delegated; Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context); - return new LogicalHighlight(child, field); + return new LogicalHighlight(child, field, unresolved.getArguments(), unresolved.getName()); } } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 99d8aaa882..9a3fb62193 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -267,8 +267,9 @@ public When when(UnresolvedExpression condition, UnresolvedExpression result) { return new When(condition, result); } - public UnresolvedExpression highlight(UnresolvedExpression fieldName) { - return new HighlightFunction(fieldName); + public UnresolvedExpression highlight(UnresolvedExpression fieldName, + java.util.Map arguments) { + return new HighlightFunction(fieldName, arguments); } public UnresolvedExpression window(UnresolvedExpression function, diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java index 97ccb4cf5e..f3e44795e9 100644 --- a/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java +++ b/core/src/main/java/org/opensearch/sql/ast/expression/HighlightFunction.java @@ -6,21 +6,29 @@ package org.opensearch.sql.ast.expression; import java.util.List; -import lombok.AllArgsConstructor; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; +import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; /** * Expression node of Highlight function. */ -@AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter @ToString public class HighlightFunction extends UnresolvedExpression { private final UnresolvedExpression highlightField; + private final Map arguments; + @Setter + private String name; + + public HighlightFunction(UnresolvedExpression highlightField, Map arguments) { + this.highlightField = highlightField; + this.arguments = arguments; + } @Override public T accept(AbstractNodeVisitor nodeVisitor, C context) { diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Highlight.java b/core/src/main/java/org/opensearch/sql/ast/tree/Highlight.java index 93900b49a8..9e929b82cb 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Highlight.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Highlight.java @@ -7,12 +7,14 @@ import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.Setter; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.UnresolvedExpression; /** @@ -25,6 +27,8 @@ @RequiredArgsConstructor public class Highlight extends UnresolvedPlan { private final UnresolvedExpression expression; + private final Map arguments; + private final String name; private UnresolvedPlan child; @Override diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java index 986a545486..9eb5d48c2d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalHighlight.java @@ -6,9 +6,11 @@ package org.opensearch.sql.planner.logical; import java.util.Collections; +import java.util.Map; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.expression.Expression; @EqualsAndHashCode(callSuper = true) @@ -16,10 +18,18 @@ @ToString public class LogicalHighlight extends LogicalPlan { private final Expression highlightField; + private final Map arguments; + private final String name; - public LogicalHighlight(LogicalPlan childPlan, Expression field) { + /** + * Constructor of LogicalHighlight. + */ + public LogicalHighlight(LogicalPlan childPlan, Expression highlightField, + Map arguments, String name) { super(Collections.singletonList(childPlan)); - highlightField = field; + this.highlightField = highlightField; + this.arguments = arguments; + this.name = name; } @Override diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 005a5d84fd..2c115d7ac9 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -63,8 +63,9 @@ public LogicalPlan window(LogicalPlan input, return new LogicalWindow(input, windowFunction, windowDefinition); } - public LogicalPlan highlight(LogicalPlan input, Expression field) { - return new LogicalHighlight(input, field); + public LogicalPlan highlight(LogicalPlan input, Expression field, + Map arguments, String name) { + return new LogicalHighlight(input, field, arguments, name); } public static LogicalPlan remove(LogicalPlan input, ReferenceExpression... fields) { diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/HighlightOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/HighlightOperator.java index 77419f6821..014b6fc18c 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/HighlightOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/HighlightOperator.java @@ -16,10 +16,12 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; +import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -37,16 +39,16 @@ * */ @EqualsAndHashCode +@AllArgsConstructor public class HighlightOperator extends PhysicalPlan { @Getter private final PhysicalPlan input; @Getter private final Expression highlight; - - public HighlightOperator(PhysicalPlan input, Expression highlight) { - this.input = input; - this.highlight = highlight; - } + @Getter + private final Map arguments; + @Getter + private final String name; @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { @@ -83,9 +85,9 @@ public ExprValue next() { */ private Pair mapHighlight(Environment env) { String osHighlightKey = "_highlight"; - String highlightStr = StringUtils.unquoteText(highlight.toString()); - if (!highlightStr.contains("*")) { - osHighlightKey += "." + highlightStr; + String highlightFieldStr = StringUtils.unquoteText(highlight.toString()); + if (!highlightFieldStr.contains("*")) { + osHighlightKey += "." + highlightFieldStr; } ReferenceExpression osOutputVar = DSL.ref(osHighlightKey, STRING); @@ -94,19 +96,18 @@ private Pair mapHighlight(Environment // In the event of multiple returned highlights and wildcard being // used in conjunction with other highlight calls, we need to ensure // only wildcard regex matching is mapped to wildcard call. - if (highlightStr.contains("*") && value.type() == STRUCT) { + if (highlightFieldStr.contains("*") && value.type() == STRUCT) { value = new ExprTupleValue( new LinkedHashMap(value.tupleValue() .entrySet() .stream() - .filter(s -> matchesHighlightRegex(s.getKey(), highlightStr)) + .filter(s -> matchesHighlightRegex(s.getKey(), highlightFieldStr)) .collect(Collectors.toMap( e -> e.getKey(), e -> e.getValue())))); } - String sqlHighlightKey = "highlight(" + highlightStr + ")"; - ReferenceExpression sqlOutputVar = DSL.ref(sqlHighlightKey, STRING); + ReferenceExpression sqlOutputVar = DSL.ref(name, STRING); // Add mapping for sql output and opensearch returned highlight fields extendEnv(env, sqlOutputVar, value); diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 9e0bc4b17d..53437efcd8 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -232,32 +232,41 @@ public void top_source() { @Test public void project_highlight() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), - DSL.literal("fieldA")), - DSL.named("highlight(fieldA)", DSL.ref("highlight(fieldA)", ARRAY)) + DSL.literal("fieldA"), args, + "highlight(fieldA, pre_tags='', post_tags='')"), + DSL.named("highlight(fieldA, pre_tags='', post_tags='')", + DSL.ref("highlight(fieldA, pre_tags='', post_tags='')", ARRAY)) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("highlight(fieldA)", new HighlightFunction(AstDSL.stringLiteral("fieldA"))) + AstDSL.alias("highlight(fieldA, pre_tags='', post_tags='')", + new HighlightFunction(AstDSL.stringLiteral("fieldA"), args)) ) ); } @Test public void project_highlight_wildcard() { + Map args = new HashMap<>(); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), - DSL.literal("*")), + DSL.literal("*"), args, "highlight(*)"), DSL.named("highlight(*)", DSL.ref("highlight(*)", STRUCT)) ), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("highlight(*)", new HighlightFunction(AstDSL.stringLiteral("*"))) + AstDSL.alias("highlight(*)", + new HighlightFunction(AstDSL.stringLiteral("*"), args)) ) ); } diff --git a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java index 2293d125aa..f724e47103 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java @@ -8,15 +8,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.HighlightFunction; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.springframework.context.annotation.Configuration; @@ -40,8 +39,25 @@ void visit_named_select_item() { @Test void visit_highlight() { + Map args = new HashMap<>(); + HighlightFunction highlightFunction = new HighlightFunction( + AstDSL.stringLiteral("fieldA"), args); + highlightFunction.setName("highlight(fieldA)"); Alias alias = AstDSL.alias("highlight(fieldA)", - new HighlightFunction(AstDSL.stringLiteral("fieldA"))); + highlightFunction); + NamedExpressionAnalyzer analyzer = + new NamedExpressionAnalyzer(expressionAnalyzer); + + NamedExpression analyze = analyzer.analyze(alias, analysisContext); + assertEquals("highlight(fieldA)", analyze.getNameOrAlias()); + } + + @Test + void visit_highlight_no_set_name() { + Map args = new HashMap<>(); + Alias alias = AstDSL.alias("highlight(fieldA)", + new HighlightFunction( + AstDSL.stringLiteral("fieldA"), args)); NamedExpressionAnalyzer analyzer = new NamedExpressionAnalyzer(expressionAnalyzer); diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalHighlightTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalHighlightTest.java index 237f8a9f1b..1f89bd17d1 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalHighlightTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalHighlightTest.java @@ -10,10 +10,13 @@ import static org.opensearch.sql.ast.dsl.AstDSL.relation; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.analysis.AnalyzerTestBase; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.Highlight; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.config.ExpressionConfig; @@ -28,13 +31,15 @@ public class LogicalHighlightTest extends AnalyzerTestBase { @Test public void analyze_highlight_with_one_field() { + Map args = new HashMap<>(); assertAnalyzeEqual( LogicalPlanDSL.highlight( LogicalPlanDSL.relation("schema", table), - DSL.literal("field")), + DSL.literal("field"), args, "highlight(field)"), new Highlight( alias("highlight('field')", - highlight(stringLiteral("field")))) + highlight(stringLiteral("field"), args)), + args, "highlight(field)") .attach(relation("schema"))); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index c90ea365d2..a75c46f673 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.HashMap; +import java.util.Map; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; @@ -117,8 +118,9 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { assertNull(rareTopN.accept(new LogicalPlanNodeVisitor() { }, null)); + Map args = new HashMap<>(); LogicalPlan highlight = new LogicalHighlight(filter, - new LiteralExpression(ExprValueUtils.stringValue("fieldA"))); + new LiteralExpression(ExprValueUtils.stringValue("fieldA")), args, "highlight(fieldA)"); assertNull(highlight.accept(new LogicalPlanNodeVisitor() { }, null)); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/HighlightOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/HighlightOperatorTest.java index 55f6c94caa..27b30f1588 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/HighlightOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/HighlightOperatorTest.java @@ -17,11 +17,14 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -35,10 +38,11 @@ class HighlightOperatorTest extends PhysicalPlanTestBase { @Test public void do_nothing_with_none_tuple_value() { + Map args = new HashMap<>(); when(inputPlan.hasNext()).thenReturn(true, false); when(inputPlan.next()).thenReturn(ExprValueUtils.integerValue(1)); ReferenceExpression highlightReferenceExp = DSL.ref("reference", STRING); - PhysicalPlan plan = new HighlightOperator(inputPlan, highlightReferenceExp); + PhysicalPlan plan = new HighlightOperator(inputPlan, highlightReferenceExp, args, "reference"); List result = execute(plan); assertTrue(((HighlightOperator)plan).getInput().equals(inputPlan)); @@ -48,6 +52,7 @@ public void do_nothing_with_none_tuple_value() { @Test public void highlight_one_field() { + Map args = new HashMap<>(); when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) .thenReturn( @@ -61,7 +66,8 @@ public void highlight_one_field() { "_highlight.region", "us-east-1", "action", "PUT", "response", 200))); assertThat( - execute(new HighlightOperator(inputPlan, DSL.ref("region", STRING))), + execute(new HighlightOperator(inputPlan, DSL.ref("region", STRING), args, + "highlight(region)")), contains( tupleValue(ImmutableMap.of( "_highlight.region", "us-east-1", "action", "GET", @@ -77,6 +83,7 @@ public void highlight_one_field() { @Test public void highlight_wildcard() { + Map args = new HashMap<>(); when(inputPlan.hasNext()).thenReturn(true, true, false); when(inputPlan.next()) .thenReturn( @@ -90,7 +97,7 @@ public void highlight_wildcard() { "action", "GET", "response", 200))); assertThat( - execute(new HighlightOperator(inputPlan, DSL.ref("r*", STRUCT))), + execute(new HighlightOperator(inputPlan, DSL.ref("r*", STRUCT), args, "highlight(r*)")), contains( tupleValue(ImmutableMap.of( "_highlight", ExprNullValue.of(), diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index 6aa6630749..fac29bf6fe 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -17,11 +17,14 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; +import java.util.Map; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.DSL; @@ -134,7 +137,9 @@ public void test_PhysicalPlanVisitor_should_return_null() { assertNull(limit.accept(new PhysicalPlanNodeVisitor() { }, null)); - PhysicalPlan highlight = new HighlightOperator(plan, DSL.ref("reference", STRING)); + Map args = new HashMap<>(); + PhysicalPlan highlight = new HighlightOperator(plan, DSL.ref("reference", STRING), + args, "highlight(reference)"); assertNull(highlight.accept(new PhysicalPlanNodeVisitor() { }, null)); } diff --git a/docs/user/ppl/functions/relevance.rst b/docs/user/ppl/functions/relevance.rst index 6cdb3e10f7..7535a9dd8e 100644 --- a/docs/user/ppl/functions/relevance.rst +++ b/docs/user/ppl/functions/relevance.rst @@ -369,10 +369,10 @@ Please refer to examples below: Example searching for field Tags:: - os> source=books | where query_string(['title'], 'Pooh House') | highlight(title); + os> source=books | where query_string(['title'], 'Pooh House') | highlight title; fetched rows / total rows = 2/2 +------+--------------------------+----------------------+----------------------------------------------+ - | id | title | author | highlight(title) | + | id | title | author | highlight title | |------+--------------------------+----------------------+----------------------------------------------| | 1 | The House at Pooh Corner | Alan Alexander Milne | [The House at Pooh Corner] | | 2 | Winnie-the-Pooh | Alan Alexander Milne | [Winnie-the-Pooh] | diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/HighlightFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/HighlightFunctionIT.java index 6feaf2e623..7a9ad8fa10 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/HighlightFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/HighlightFunctionIT.java @@ -7,11 +7,14 @@ import java.io.IOException; import org.json.JSONObject; -import org.junit.jupiter.api.Test; +import org.junit.Test; import org.opensearch.sql.legacy.TestsConstants; public class HighlightFunctionIT extends PPLIntegTestCase { - + // allFields is returned since we can't use highlight in a fields command. + // Additional highlight fields begin at index 19 + private int firstHighlightFieldIndex = 19; + private int secondHighlightFieldIndex = 20; @Override public void init() throws IOException { loadIndex(Index.BEER); @@ -22,28 +25,64 @@ public void single_highlight_test() throws IOException { JSONObject result = executeQuery( String.format( - "SOURCE=%s | WHERE match(Title, 'Cicerone') | highlight(Title)", TestsConstants.TEST_INDEX_BEER)); + "SOURCE=%s | WHERE match(Title, 'Cicerone') | highlight Title", TestsConstants.TEST_INDEX_BEER)); assertEquals(1, result.getInt("total")); assertTrue( - result.getJSONArray("datarows") - .getJSONArray(0) - .getJSONArray(19) - .getString(0) - .equals("What exactly is a Cicerone? What do they do?")); + verifyFirstIndexHighlight(result, + "What exactly is a Cicerone? What do they do?") + ); + + assertTrue( + verifyFirstIndexHighlight(result, "What exactly is a Cicerone? What do they do?") + ); + + assertTrue( + verifyHighlightSchema(result, "highlight Title") + ); + } + + @Test + public void highlight_optional_arguments_test() throws IOException { + JSONObject result = + executeQuery( + String.format( + "SOURCE=%s | WHERE match(Title, 'Cicerone') | " + + "highlight Title, pre_tags='', post_tags=''", TestsConstants.TEST_INDEX_BEER)); + + assertEquals(1, result.getInt("total")); assertTrue( - result.getJSONArray("schema") - .getJSONObject(19) - .getString("name") - .equals("highlight(Title)")); + verifyFirstIndexHighlight(result, "What exactly is a Cicerone? What do they do?") + ); assertTrue( - result.getJSONArray("schema") - .getJSONObject(19) - .getString("type") - .equals("nested")); + verifyHighlightSchema(result, "highlight Title, pre_tags='', post_tags=''") + ); + } + + @Test + public void highlight_multiple_optional_arguments_test() throws IOException { + JSONObject result = + executeQuery( + String.format( + "SOURCE=%s | WHERE multi_match([Title, Body], 'IPA') | highlight Title | highlight Body, " + + "pre_tags='', post_tags=''", + TestsConstants.TEST_INDEX_BEER)); + + assertEquals(3, result.getInt("total")); + + assertTrue( + verifyFirstIndexHighlight(result, "What are the differences between an IPA and its variants?") + ); + + assertTrue( + verifySecondIndexHighlight(result, + "

I know what makes an IPA an " + + "IPA, but what are the unique characteristics " + + "of it's common variants?") + ); } @Test @@ -51,8 +90,15 @@ public void quoted_highlight_test() throws IOException { JSONObject result = executeQuery( String.format( - "SOURCE=%s | WHERE match(Title, 'Cicerone') | highlight('Title')", TestsConstants.TEST_INDEX_BEER)); + "SOURCE=%s | WHERE match(Title, 'Cicerone') | highlight 'Title'", TestsConstants.TEST_INDEX_BEER)); + assertEquals(1, result.getInt("total")); + + assertTrue( + verifyFirstIndexHighlight(result, + "What exactly is a Cicerone? What do they do?") + ); + } @Test @@ -60,9 +106,21 @@ public void multiple_highlights_test() throws IOException { JSONObject result = executeQuery( String.format( - "SOURCE=%s | WHERE multi_match([Title, Body], 'hops') | highlight('Title') | highlight(Body)", + "SOURCE=%s | WHERE multi_match([Title, Body], 'IPA') | highlight 'Title' | highlight Body", TestsConstants.TEST_INDEX_BEER)); - assertEquals(2, result.getInt("total")); + + assertEquals(3, result.getInt("total")); + + assertTrue( + verifyFirstIndexHighlight(result, + "What are the differences between an IPA and its variants?") + ); + + assertTrue( + verifySecondIndexHighlight(result, + "

I know what makes an IPA an IPA, but what are the unique " + + "characteristics of it's common variants?") + ); } @Test @@ -70,24 +128,19 @@ public void highlight_wildcard_test() throws IOException { JSONObject result = executeQuery( String.format( - "SOURCE=%s | WHERE multi_match([Title, Body], 'Cicerone') | highlight('T*')", + "SOURCE=%s | WHERE multi_match([Title, Body], 'Cicerone') | highlight 'T*'", TestsConstants.TEST_INDEX_BEER)); assertEquals(1, result.getInt("total")); assertTrue( - result.getJSONArray("datarows") - .getJSONArray(0) - .getJSONObject(19) - .getJSONArray("Title") - .get(0) - .equals("What exactly is a Cicerone? What do they do?")); + verifyFirstIndexHighlightWildcard(result, "Title", + "What exactly is a Cicerone? What do they do?") + ); assertTrue( - result.getJSONArray("schema") - .getJSONObject(19) - .getString("name") - .equals("highlight('T*')")); + verifyHighlightSchema(result, "highlight 'T*'") + ); } @Test @@ -95,37 +148,55 @@ public void highlight_all_test() throws IOException { JSONObject result = executeQuery( String.format( - "SOURCE=%s | WHERE multi_match([Title, Body], 'Cicerone') | highlight('*')", + "SOURCE=%s | WHERE multi_match([Title, Body], 'Cicerone') | highlight '*'", TestsConstants.TEST_INDEX_BEER)); assertEquals(1, result.getInt("total")); assertTrue( - result.getJSONArray("datarows") - .getJSONArray(0) - .getJSONObject(19) - .getJSONArray("Title") - .get(0) - .equals("What exactly is a Cicerone? What do they do?")); + verifyFirstIndexHighlightWildcard(result, "Title", + "What exactly is a Cicerone? What do they do?") + ); assertTrue( - result.getJSONArray("datarows") - .getJSONArray(0) - .getJSONObject(19) - .getJSONArray("Body") - .get(0) - .equals("

Recently I've started seeing references to the term 'Cicerone' pop up around the internet; generally")); - + verifyFirstIndexHighlightWildcard(result, "Body", + "

Recently I've started seeing references to the term 'Cicerone' " + + "pop up around the internet; generally") + ); assertTrue( - result.getJSONArray("schema") - .getJSONObject(19) - .getString("name") - .equals("highlight('*')")); + verifyHighlightSchema(result, "highlight '*'") + ); + } - assertTrue( - result.getJSONArray("schema") - .getJSONObject(19) - .getString("type") - .equals("object")); + private boolean verifyFirstIndexHighlightWildcard(JSONObject result, String highlightField, String match) { + return result.getJSONArray("datarows") + .getJSONArray(0) + .getJSONObject(firstHighlightFieldIndex) + .getJSONArray(highlightField) + .get(0) + .equals(match); + } + + private boolean verifySecondIndexHighlight(JSONObject result, String match) { + return result.getJSONArray("datarows") + .getJSONArray(0) + .getJSONArray(secondHighlightFieldIndex) + .getString(0) + .equals(match); + } + + private boolean verifyFirstIndexHighlight(JSONObject result, String match) { + return result.getJSONArray("datarows") + .getJSONArray(0) + .getJSONArray(firstHighlightFieldIndex) + .getString(0) + .equals(match); + } + + private boolean verifyHighlightSchema(JSONObject result, String name) { + return result.getJSONArray("schema") + .getJSONObject(firstHighlightFieldIndex) + .getString("name") + .equals(name); } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java index 9e400f28fb..a53dedbe20 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/HighlightFunctionIT.java @@ -5,13 +5,17 @@ package org.opensearch.sql.sql; +import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; import static org.opensearch.sql.util.MatcherUtils.verifySchema; +import org.json.JSONArray; import org.json.JSONObject; import org.junit.Test; import org.opensearch.sql.legacy.SQLIntegTestCase; import org.opensearch.sql.legacy.TestsConstants; +import java.util.List; public class HighlightFunctionIT extends SQLIntegTestCase { @@ -29,6 +33,40 @@ public void single_highlight_test() { assertEquals(1, response.getInt("total")); } + @Test + public void highlight_optional_arguments_test() { + String query = "SELECT highlight('Tags', pre_tags='', post_tags='') " + + "FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + verifySchema(response, schema("highlight('Tags', pre_tags='', post_tags='')", + null, "nested")); + + assertEquals(1, response.getInt("total")); + + verifyDataRows(response, + rows(new JSONArray(List.of("alcohol-level yeast home-brew champagne")))); + } + + @Test + public void highlight_multiple_optional_arguments_test() { + String query = "SELECT highlight(Title), highlight(Body, pre_tags='', post_tags='') FROM %s WHERE multi_match([Title, Body], 'IPA') LIMIT 1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_BEER)); + + verifySchema(response, schema("highlight(Title)", null, "nested"), + schema("highlight(Body, pre_tags='', " + + "post_tags='')", null, "nested")); + + assertEquals(1, response.getInt("total")); + + assertEquals("What are the differences between an IPA and its variants?", + response.getJSONArray("datarows").getJSONArray(0).getJSONArray(0).getString(0)); + + assertEquals("

I know what makes an IPA an IPA, but what are the unique characteristics of it's common variants?", + response.getJSONArray("datarows").getJSONArray(0).getJSONArray(1).getString(0)); + } + @Test public void accepts_unquoted_test() { String query = "SELECT Tags, highlight(Tags) FROM %s WHERE match(Tags, 'yeast') LIMIT 1"; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 78918ca552..d7e521bb27 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -156,7 +156,9 @@ public PhysicalPlan visitHighlight(PhysicalPlan node, Object context) { HighlightOperator hlOperator = (HighlightOperator) node; return doProtect( new HighlightOperator(visitInput(hlOperator.getInput(), context), - ((HighlightOperator) node).getHighlight()) + ((HighlightOperator) node).getHighlight(), + ((HighlightOperator) node).getArguments(), + ((HighlightOperator) node).getName()) ); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index 646395d790..c26413c622 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -25,9 +25,11 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; import org.opensearch.search.sort.SortBuilder; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -173,14 +175,33 @@ public void pushDownLimit(Integer limit, Integer offset) { * Add highlight to DSL requests. * @param field name of the field to highlight */ - public void pushDownHighlight(String field) { + public void pushDownHighlight(String field, Map arguments) { + String unquotedField = StringUtils.unquoteText(field); if (sourceBuilder.highlighter() != null) { - sourceBuilder.highlighter().field(StringUtils.unquoteText(field)); + // OS does not allow duplicates of highlight fields + if (sourceBuilder.highlighter().fields().stream() + .anyMatch(f -> f.name().equals(unquotedField))) { + throw new SemanticCheckException(String.format( + "Duplicate field %s in highlight", field)); + } + + sourceBuilder.highlighter().field(unquotedField); } else { HighlightBuilder highlightBuilder = - new HighlightBuilder().field(StringUtils.unquoteText(field)); + new HighlightBuilder().field(unquotedField); sourceBuilder.highlighter(highlightBuilder); } + + // lastFieldIndex denotes previously set highlighter with field parameter + int lastFieldIndex = sourceBuilder.highlighter().fields().size() - 1; + if (arguments.containsKey("pre_tags")) { + sourceBuilder.highlighter().fields().get(lastFieldIndex) + .preTags(arguments.get("pre_tags").toString()); + } + if (arguments.containsKey("post_tags")) { + sourceBuilder.highlighter().fields().get(lastFieldIndex) + .postTags(arguments.get("post_tags").toString()); + } } /** diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 629708c054..978cfb4b07 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -209,8 +209,9 @@ public PhysicalPlan visitAD(LogicalAD node, OpenSearchIndexScan context) { @Override public PhysicalPlan visitHighlight(LogicalHighlight node, OpenSearchIndexScan context) { context.getRequestBuilder().pushDownHighlight( - StringUtils.unquoteText(node.getHighlightField().toString())); - return new HighlightOperator(visitChild(node, context), node.getHighlightField()); + StringUtils.unquoteText(node.getHighlightField().toString()), node.getArguments()); + return new HighlightOperator(visitChild(node, context), node.getHighlightField(), + node.getArguments(), node.getName()); } } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index be63ef342c..fe3a68fed3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -296,10 +296,13 @@ public void testVisitAD() { @Test public void testVisitHighlight() { + Map args = new HashMap<>(); HighlightOperator hlOperator = new HighlightOperator( values(emptyList()), - DSL.ref("reference", STRING)); + DSL.ref("reference", STRING), + args, + "highlight(reference)"); assertEquals(executionProtector.doProtect(hlOperator), executionProtector.visitHighlight(hlOperator, null)); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 43b9353190..00b5e3fc98 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -7,8 +7,11 @@ package org.opensearch.sql.opensearch.request; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -16,7 +19,9 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.common.unit.TimeValue; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; @ExtendWith(MockitoExtension.class) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index 64b87aa2c5..0bae307a00 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -13,12 +13,16 @@ import static org.mockito.Mockito.verify; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Answers; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.planner.logical.LogicalAD; @@ -88,6 +92,6 @@ public void visitHighlight() { new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); implementor.visitHighlight(node, indexScan); - verify(requestBuilder).pushDownHighlight(any()); + verify(requestBuilder).pushDownHighlight(any(), any()); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java index a1f2869ca5..9a606750a3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexScanTest.java @@ -9,7 +9,6 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -20,6 +19,8 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -32,9 +33,12 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.request.OpenSearchQueryRequest; @@ -190,14 +194,49 @@ void pushDownFilters() { @Test void pushDownHighlight() { + Map args = new HashMap<>(); assertThat() .pushDown(QueryBuilders.termQuery("name", "John")) - .pushDownHighlight("Title") - .pushDownHighlight("Body") + .pushDownHighlight("Title", args) + .pushDownHighlight("Body", args) .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), new HighlightBuilder().field("Title").field("Body")); } + @Test + void pushDownHighlightWithArguments() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + HighlightBuilder highlightBuilder = new HighlightBuilder() + .field("Title"); + highlightBuilder.fields().get(0).preTags("").postTags(""); + assertThat() + .pushDown(QueryBuilders.termQuery("name", "John")) + .pushDownHighlight("Title", args) + .shouldQueryHighlight(QueryBuilders.termQuery("name", "John"), + highlightBuilder); + } + + @Test + void pushDownHighlightWithRepeatingFields() { + mockResponse( + new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, + new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); + + try (OpenSearchIndexScan indexScan = + new OpenSearchIndexScan(client, settings, "test", 2, exprValueFactory)) { + indexScan.getRequestBuilder().pushDownLimit(3, 0); + indexScan.open(); + Map args = new HashMap<>(); + indexScan.getRequestBuilder().pushDownHighlight("name", args); + indexScan.getRequestBuilder().pushDownHighlight("name", args); + } catch (SemanticCheckException e) { + assertTrue(e.getClass().equals(SemanticCheckException.class)); + } + verify(client).cleanup(any()); + } + private PushDownAssertion assertThat() { return new PushDownAssertion(client, exprValueFactory, settings); } @@ -223,8 +262,8 @@ PushDownAssertion pushDown(QueryBuilder query) { return this; } - PushDownAssertion pushDownHighlight(String query) { - indexScan.getRequestBuilder().pushDownHighlight(query); + PushDownAssertion pushDownHighlight(String query, Map arguments) { + indexScan.getRequestBuilder().pushDownHighlight(query, arguments); return this; } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 2c65483868..298a917155 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -312,6 +312,8 @@ TIE_BREAKER: 'TIE_BREAKER'; TYPE: 'TYPE'; ZERO_TERMS_QUERY: 'ZERO_TERMS_QUERY'; HIGHLIGHT: 'HIGHLIGHT'; +HIGHLIGHT_PRE_TAGS: 'PRE_TAGS'; +HIGHLIGHT_POST_TAGS: 'POST_TAGS'; // SPAN KEYWORDS SPAN: 'SPAN'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index d54c127737..beff3873ed 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -290,7 +290,7 @@ evalFunctionCall ; highlightFunction - : HIGHLIGHT LT_PRTHS field=relevanceField RT_PRTHS #highlightFunctionCall + : HIGHLIGHT relevanceField (COMMA highlightArg)* #highlightFunctionCall ; /** cast function */ @@ -335,6 +335,10 @@ relevanceArg : relevanceArgName EQUAL relevanceArgValue ; +highlightArg + : highlightArgName EQUAL highlightArgValue + ; + relevanceArgName : ALLOW_LEADING_WILDCARD | ANALYZER | ANALYZE_WILDCARD | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY | BOOST | CUTOFF_FREQUENCY | DEFAULT_FIELD | DEFAULT_OPERATOR | ENABLE_POSITION_INCREMENTS @@ -345,6 +349,11 @@ relevanceArgName | ZERO_TERMS_QUERY ; + +highlightArgName + : HIGHLIGHT_POST_TAGS| HIGHLIGHT_PRE_TAGS + ; + relevanceFieldAndWeight : field=relevanceField | field=relevanceField weight=relevanceFieldWeight @@ -370,6 +379,10 @@ relevanceArgValue | literalValue ; +highlightArgValue + : stringLiteral + ; + mathematicalFunctionBase : ABS | CEIL | CEILING | CONV | CRC32 | E | EXP | FLOOR | LN | LOG | LOG10 | LOG2 | MOD | PI |POW | POWER | RAND | ROUND | SIGN | SQRT | TRUNCATE diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index f264a4ff56..ff6312ce1c 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -37,7 +37,9 @@ import org.antlr.v4.runtime.Token; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.sql.ast.expression.Alias; +import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.Let; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Map; @@ -133,8 +135,11 @@ public UnresolvedPlan visitWhereCommand(WhereCommandContext ctx) { */ @Override public UnresolvedPlan visitHighlightCommand(OpenSearchPPLParser.HighlightCommandContext ctx) { - return new Highlight(new Alias(StringUtils.unquoteText(getTextInQuery(ctx)), - internalVisitExpression(ctx.highlightFunction().getRuleContext()))); + Alias highlightFunction = new Alias(StringUtils.unquoteText(getTextInQuery(ctx)), + internalVisitExpression(ctx.highlightFunction().getRuleContext())); + return new Highlight(highlightFunction, + ((HighlightFunction) highlightFunction.getDelegated()).getArguments(), + highlightFunction.getName()); } /** diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index c16436f4d7..2aaf64c951 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -51,6 +51,7 @@ import java.util.stream.Collectors; import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.RuleContext; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.Alias; @@ -208,7 +209,14 @@ public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunction @Override public UnresolvedExpression visitHighlightFunctionCall( OpenSearchPPLParser.HighlightFunctionCallContext ctx) { - return new HighlightFunction(AstDSL.stringLiteral(ctx.relevanceField().getText())); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.highlightArg().forEach(v -> builder.put( + v.highlightArgName().getText().toLowerCase(), + new Literal(StringUtils.unquoteText(v.highlightArgValue().getText()), + DataType.STRING)) + ); + return new HighlightFunction(AstDSL.stringLiteral( + StringUtils.unquoteText(ctx.relevanceField().getText())), builder.build()); } @Override diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java index ca77de87d6..140bcbb76a 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java @@ -82,7 +82,7 @@ public void testTopCommandWithoutNAndGroupByShouldPass() { @Test public void testHighlightShouldPass() { - ParseTree tree = new PPLSyntaxParser().parse("source=shakespeare | highlight(text_entry)"); + ParseTree tree = new PPLSyntaxParser().parse("source=shakespeare | highlight field"); assertNotEquals(null, tree); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 570c0f5687..93c1f3f4f6 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -42,6 +42,8 @@ import static org.opensearch.sql.utils.SystemIndexUtils.mappingTable; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; +import java.util.Map; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -630,24 +632,42 @@ public void testParseCommand() { @Test public void testQuotedHighlightCommand() { - assertEqual("source=t | highlight('FieldA')", + Map args = new HashMap<>(); + assertEqual("source=t | highlight 'FieldA'", new Highlight( - alias("highlight('FieldA')", - highlight(stringLiteral("'FieldA'")))) + alias("highlight 'FieldA'", + highlight(stringLiteral("FieldA"), args)), + args, "highlight 'FieldA'") .attach(relation("t")) ); } @Test public void testUnquotedHighlightCommand() { - assertEqual("source=t | highlight(FieldA)", + Map args = new HashMap<>(); + assertEqual("source=t | highlight fieldA", new Highlight( - alias("highlight(FieldA)", - highlight(stringLiteral("FieldA")))) + alias("highlight fieldA", + highlight(stringLiteral("fieldA"), args)), + args, "highlight fieldA") .attach(relation("t")) ); } + @Test + public void testHighlightCommandWithArguments() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + assertEqual("source=t | highlight fieldA, pre_tags='', post_tags=''", + new Highlight( + alias("highlight fieldA, pre_tags='', post_tags=''", + highlight(stringLiteral("fieldA"), args)), + args, "highlight fieldA, pre_tags='', post_tags=''") + .attach(relation("t")) + ); + } + @Test public void testKmeansCommand() { assertEqual("source=t | kmeans centroids=3 iterations=2 distance_type='l1'", diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 6d2d7d8a64..72f331d39b 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -348,6 +348,8 @@ TIME_ZONE: 'TIME_ZONE'; TYPE: 'TYPE'; ZERO_TERMS_QUERY: 'ZERO_TERMS_QUERY'; HIGHLIGHT: 'HIGHLIGHT'; +HIGHLIGHT_PRE_TAGS: 'PRE_TAGS'; +HIGHLIGHT_POST_TAGS: 'POST_TAGS'; // RELEVANCE FUNCTIONS MATCH_BOOL_PREFIX: 'MATCH_BOOL_PREFIX'; diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 40207df82a..9d4bcd8fcf 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -304,7 +304,7 @@ functionCall ; highlightFunction - : HIGHLIGHT LR_BRACKET relevanceField RR_BRACKET + : HIGHLIGHT LR_BRACKET relevanceField (COMMA highlightArg)* RR_BRACKET ; scalarFunctionName @@ -425,6 +425,10 @@ relevanceArg : relevanceArgName EQUAL_SYMBOL relevanceArgValue ; +highlightArg + : highlightArgName EQUAL_SYMBOL highlightArgValue + ; + relevanceArgName : ALLOW_LEADING_WILDCARD | ANALYZER | ANALYZE_WILDCARD | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY | BOOST | CUTOFF_FREQUENCY | DEFAULT_FIELD | DEFAULT_OPERATOR | ENABLE_POSITION_INCREMENTS @@ -435,6 +439,10 @@ relevanceArgName | ZERO_TERMS_QUERY ; +highlightArgName + : HIGHLIGHT_POST_TAGS | HIGHLIGHT_PRE_TAGS + ; + relevanceFieldAndWeight : field=relevanceField | field=relevanceField weight=relevanceFieldWeight @@ -460,3 +468,7 @@ relevanceArgValue | constant ; +highlightArgValue + : stringLiteral + ; + diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 453162e335..f2e4683a57 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -48,6 +48,7 @@ import static org.opensearch.sql.sql.parser.ParserUtils.createSortOption; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -134,7 +135,15 @@ public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ct @Override public UnresolvedExpression visitHighlightFunctionCall( OpenSearchSQLParser.HighlightFunctionCallContext ctx) { - return new HighlightFunction(visit(ctx.highlightFunction().relevanceField())); + ImmutableMap.Builder builder = ImmutableMap.builder(); + ctx.highlightFunction().highlightArg().forEach(v -> builder.put( + v.highlightArgName().getText().toLowerCase(), + new Literal(StringUtils.unquoteText(v.highlightArgValue().getText()), + DataType.STRING)) + ); + + return new HighlightFunction(visit(ctx.highlightFunction().relevanceField()), + builder.build()); } @Override diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index 8bf38b14a6..9bd92f77bc 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -32,10 +32,14 @@ import static org.opensearch.sql.utils.SystemIndexUtils.mappingTable; import com.google.common.collect.ImmutableList; +import java.util.HashMap; +import java.util.Map; import org.antlr.v4.runtime.tree.ParseTree; import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.sql.antlr.SQLSyntaxParser; @@ -671,18 +675,36 @@ public void can_build_limit_clause_with_offset() { @Test public void can_build_qualified_name_highlight() { + Map args = new HashMap<>(); assertEquals( project(relation("test"), - alias("highlight(fieldA)", highlight(AstDSL.qualifiedName("fieldA")))), + alias("highlight(fieldA)", + highlight(AstDSL.qualifiedName("fieldA"), args))), buildAST("SELECT highlight(fieldA) FROM test") ); } + @Test + public void can_build_qualified_highlight_with_arguments() { + Map args = new HashMap<>(); + args.put("pre_tags", new Literal("", DataType.STRING)); + args.put("post_tags", new Literal("", DataType.STRING)); + assertEquals( + project(relation("test"), + alias("highlight(fieldA, pre_tags='', post_tags='')", + highlight(AstDSL.qualifiedName("fieldA"), args))), + buildAST("SELECT highlight(fieldA, pre_tags='', post_tags='') " + + "FROM test") + ); + } + @Test public void can_build_string_literal_highlight() { + Map args = new HashMap<>(); assertEquals( project(relation("test"), - alias("highlight(\"fieldA\")", highlight(AstDSL.stringLiteral("fieldA")))), + alias("highlight(\"fieldA\")", + highlight(AstDSL.stringLiteral("fieldA"), args))), buildAST("SELECT highlight(\"fieldA\") FROM test") ); } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index ef881275e5..4be6cf6f24 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -13,7 +13,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.caseWhen; import static org.opensearch.sql.ast.dsl.AstDSL.dateLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral; -import static org.opensearch.sql.ast.dsl.AstDSL.floatLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.highlight; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; @@ -35,12 +34,15 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.HashMap; +import java.util.Map; import org.antlr.v4.runtime.CommonTokenStream; import org.apache.commons.lang3.tuple.ImmutablePair; import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.DataType; +import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; @@ -311,16 +313,18 @@ public void canBuildWindowFunctionWithNullOrderSpecified() { @Test public void canBuildStringLiteralHighlightFunction() { + Map args = new HashMap<>(); assertEquals( - highlight(AstDSL.stringLiteral("fieldA")), + highlight(AstDSL.stringLiteral("fieldA"), args), buildExprAst("highlight(\"fieldA\")") ); } @Test public void canBuildQualifiedNameHighlightFunction() { + Map args = new HashMap<>(); assertEquals( - highlight(AstDSL.qualifiedName("fieldA")), + highlight(AstDSL.qualifiedName("fieldA"), args), buildExprAst("highlight(fieldA)") ); }