Skip to content

Commit

Permalink
Adding optional arguments for highlight.
Browse files Browse the repository at this point in the history
Signed-off-by: forestmvey <forestv@bitquilltech.com>
  • Loading branch information
forestmvey committed Sep 13, 2022
1 parent bf0be6a commit 511fd88
Show file tree
Hide file tree
Showing 36 changed files with 483 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ public LogicalPlan visitHighlight(Highlight node, AnalysisContext context) {

HighlightFunction unresolved = (HighlightFunction) ((Alias)node.getExpression()).getDelegated();
Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context);
return new LogicalHighlight(child, field);
return new LogicalHighlight(child, field, node.getArguments(), node.getName());
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,9 @@ public Expression visitWindowFunction(WindowFunction node, AnalysisContext conte

@Override
public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) {
Expression expr = node.getHighlightField().accept(this, context);
String highlightStr = "highlight(" + StringUtils.unquoteText(expr.toString()) + ")";
return new ReferenceExpression(highlightStr, expr.toString().contains("*")
? ExprCoreType.STRUCT : ExprCoreType.ARRAY);
return (node.getName() != null) ? new ReferenceExpression(node.getName(),
node.getName().contains("*") ? ExprCoreType.STRUCT : ExprCoreType.ARRAY)
: null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext con

@Override
public LogicalPlan visitAlias(Alias node, AnalysisContext context) {
if (!(node.getDelegated() instanceof HighlightFunction)) {
UnresolvedExpression delegated = node.getDelegated();
if (!(delegated instanceof HighlightFunction)) {
return null;
}

HighlightFunction unresolved = (HighlightFunction) node.getDelegated();
// Must set name here else operator cannot resolve delegated reference
((HighlightFunction) delegated).setName(node.getName());
HighlightFunction unresolved = (HighlightFunction) delegated;
Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context);
return new LogicalHighlight(child, field);
return new LogicalHighlight(child, field, unresolved.getArguments(), unresolved.getName());
}
}
5 changes: 3 additions & 2 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,9 @@ public When when(UnresolvedExpression condition, UnresolvedExpression result) {
return new When(condition, result);
}

public UnresolvedExpression highlight(UnresolvedExpression fieldName) {
return new HighlightFunction(fieldName);
public UnresolvedExpression highlight(UnresolvedExpression fieldName,
java.util.Map<String, Literal> arguments) {
return new HighlightFunction(fieldName, arguments);
}

public UnresolvedExpression window(UnresolvedExpression function,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,29 @@
package org.opensearch.sql.ast.expression;

import java.util.List;
import lombok.AllArgsConstructor;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
* Expression node of Highlight function.
*/
@AllArgsConstructor
@EqualsAndHashCode(callSuper = false)
@Getter
@ToString
public class HighlightFunction extends UnresolvedExpression {
private final UnresolvedExpression highlightField;
private final Map<String, Literal> arguments;
@Setter
private String name;

public HighlightFunction(UnresolvedExpression highlightField, Map<String, Literal> arguments) {
this.highlightField = highlightField;
this.arguments = arguments;
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/tree/Highlight.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

/**
Expand All @@ -25,6 +27,8 @@
@RequiredArgsConstructor
public class Highlight extends UnresolvedPlan {
private final UnresolvedExpression expression;
private final Map<String, Literal> arguments;
private final String name;
private UnresolvedPlan child;

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,30 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.expression.Expression;

@EqualsAndHashCode(callSuper = true)
@Getter
@ToString
public class LogicalHighlight extends LogicalPlan {
private final Expression highlightField;
private final Map<String, Literal> arguments;
private final String name;

public LogicalHighlight(LogicalPlan childPlan, Expression field) {
/**
* Constructor of LogicalHighlight.
*/
public LogicalHighlight(LogicalPlan childPlan, Expression highlightField,
Map<String, Literal> arguments, String name) {
super(Collections.singletonList(childPlan));
highlightField = field;
this.highlightField = highlightField;
this.arguments = arguments;
this.name = name;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ public LogicalPlan window(LogicalPlan input,
return new LogicalWindow(input, windowFunction, windowDefinition);
}

public LogicalPlan highlight(LogicalPlan input, Expression field) {
return new LogicalHighlight(input, field);
public LogicalPlan highlight(LogicalPlan input, Expression field,
Map<String, Literal> arguments, String name) {
return new LogicalHighlight(input, field, arguments, name);
}

public static LogicalPlan remove(LogicalPlan input, ReferenceExpression... fields) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
Expand All @@ -37,16 +39,16 @@
*
*/
@EqualsAndHashCode
@AllArgsConstructor
public class HighlightOperator extends PhysicalPlan {
@Getter
private final PhysicalPlan input;
@Getter
private final Expression highlight;

public HighlightOperator(PhysicalPlan input, Expression highlight) {
this.input = input;
this.highlight = highlight;
}
@Getter
private final Map<String, Literal> arguments;
@Getter
private final String name;

@Override
public <R, C> R accept(PhysicalPlanNodeVisitor<R, C> visitor, C context) {
Expand Down Expand Up @@ -83,9 +85,9 @@ public ExprValue next() {
*/
private Pair<String, ExprValue> mapHighlight(Environment<Expression, ExprValue> env) {
String osHighlightKey = "_highlight";
String highlightStr = StringUtils.unquoteText(highlight.toString());
if (!highlightStr.contains("*")) {
osHighlightKey += "." + highlightStr;
String highlightFieldStr = StringUtils.unquoteText(highlight.toString());
if (!highlightFieldStr.contains("*")) {
osHighlightKey += "." + highlightFieldStr;
}

ReferenceExpression osOutputVar = DSL.ref(osHighlightKey, STRING);
Expand All @@ -94,19 +96,18 @@ private Pair<String, ExprValue> mapHighlight(Environment<Expression, ExprValue>
// In the event of multiple returned highlights and wildcard being
// used in conjunction with other highlight calls, we need to ensure
// only wildcard regex matching is mapped to wildcard call.
if (highlightStr.contains("*") && value.type() == STRUCT) {
if (highlightFieldStr.contains("*") && value.type() == STRUCT) {
value = new ExprTupleValue(
new LinkedHashMap<String, ExprValue>(value.tupleValue()
.entrySet()
.stream()
.filter(s -> matchesHighlightRegex(s.getKey(), highlightStr))
.filter(s -> matchesHighlightRegex(s.getKey(), highlightFieldStr))
.collect(Collectors.toMap(
e -> e.getKey(),
e -> e.getValue()))));
}

String sqlHighlightKey = "highlight(" + highlightStr + ")";
ReferenceExpression sqlOutputVar = DSL.ref(sqlHighlightKey, STRING);
ReferenceExpression sqlOutputVar = DSL.ref(name, STRING);

// Add mapping for sql output and opensearch returned highlight fields
extendEnv(env, sqlOutputVar, value);
Expand Down
19 changes: 14 additions & 5 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -232,32 +232,41 @@ public void top_source() {

@Test
public void project_highlight() {
Map<String, Literal> args = new HashMap<>();
args.put("pre_tags", new Literal("<mark>", DataType.STRING));
args.put("post_tags", new Literal("</mark>", DataType.STRING));

assertAnalyzeEqual(
LogicalPlanDSL.project(
LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table),
DSL.literal("fieldA")),
DSL.named("highlight(fieldA)", DSL.ref("highlight(fieldA)", ARRAY))
DSL.literal("fieldA"), args,
"highlight(fieldA, pre_tags='<mark>', post_tags='</mark>')"),
DSL.named("highlight(fieldA, pre_tags='<mark>', post_tags='</mark>')",
DSL.ref("highlight(fieldA, pre_tags='<mark>', post_tags='</mark>')", ARRAY))
),
AstDSL.projectWithArg(
AstDSL.relation("schema"),
AstDSL.defaultFieldsArgs(),
AstDSL.alias("highlight(fieldA)", new HighlightFunction(AstDSL.stringLiteral("fieldA")))
AstDSL.alias("highlight(fieldA, pre_tags='<mark>', post_tags='</mark>')",
new HighlightFunction(AstDSL.stringLiteral("fieldA"), args))
)
);
}

@Test
public void project_highlight_wildcard() {
Map<String, Literal> args = new HashMap<>();
assertAnalyzeEqual(
LogicalPlanDSL.project(
LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table),
DSL.literal("*")),
DSL.literal("*"), args, "highlight(*)"),
DSL.named("highlight(*)", DSL.ref("highlight(*)", STRUCT))
),
AstDSL.projectWithArg(
AstDSL.relation("schema"),
AstDSL.defaultFieldsArgs(),
AstDSL.alias("highlight(*)", new HighlightFunction(AstDSL.stringLiteral("*")))
AstDSL.alias("highlight(*)",
new HighlightFunction(AstDSL.stringLiteral("*"), args))
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.HighlightFunction;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.config.ExpressionConfig;
import org.springframework.context.annotation.Configuration;
Expand All @@ -40,8 +39,25 @@ void visit_named_select_item() {

@Test
void visit_highlight() {
Map<String, Literal> args = new HashMap<>();
HighlightFunction highlightFunction = new HighlightFunction(
AstDSL.stringLiteral("fieldA"), args);
highlightFunction.setName("highlight(fieldA)");
Alias alias = AstDSL.alias("highlight(fieldA)",
new HighlightFunction(AstDSL.stringLiteral("fieldA")));
highlightFunction);
NamedExpressionAnalyzer analyzer =
new NamedExpressionAnalyzer(expressionAnalyzer);

NamedExpression analyze = analyzer.analyze(alias, analysisContext);
assertEquals("highlight(fieldA)", analyze.getNameOrAlias());
}

@Test
void visit_highlight_no_set_name() {
Map<String, Literal> args = new HashMap<>();
Alias alias = AstDSL.alias("highlight(fieldA)",
new HighlightFunction(
AstDSL.stringLiteral("fieldA"), args));
NamedExpressionAnalyzer analyzer =
new NamedExpressionAnalyzer(expressionAnalyzer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
import static org.opensearch.sql.ast.dsl.AstDSL.relation;
import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral;

import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.analysis.AnalyzerTestBase;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.tree.Highlight;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.config.ExpressionConfig;
Expand All @@ -28,13 +31,15 @@
public class LogicalHighlightTest extends AnalyzerTestBase {
@Test
public void analyze_highlight_with_one_field() {
Map<String, Literal> args = new HashMap<>();
assertAnalyzeEqual(
LogicalPlanDSL.highlight(
LogicalPlanDSL.relation("schema", table),
DSL.literal("field")),
DSL.literal("field"), args, "highlight(field)"),
new Highlight(
alias("highlight('field')",
highlight(stringLiteral("field"))))
highlight(stringLiteral("field"), args)),
args, "highlight(field)")
.attach(relation("schema")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -117,8 +118,9 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() {
assertNull(rareTopN.accept(new LogicalPlanNodeVisitor<Integer, Object>() {
}, null));

Map<String, Literal> args = new HashMap<>();
LogicalPlan highlight = new LogicalHighlight(filter,
new LiteralExpression(ExprValueUtils.stringValue("fieldA")));
new LiteralExpression(ExprValueUtils.stringValue("fieldA")), args, "highlight(fieldA)");
assertNull(highlight.accept(new LogicalPlanNodeVisitor<Integer, Object>() {
}, null));

Expand Down
Loading

0 comments on commit 511fd88

Please sign in to comment.