Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Highlight Wildcard in SQL #827

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,6 @@ public LogicalPlan visitProject(Project node, AnalysisContext context) {
for (UnresolvedExpression expr : node.getProjectList()) {
HighlightAnalyzer highlightAnalyzer = new HighlightAnalyzer(expressionAnalyzer, child);
child = highlightAnalyzer.analyze(expr, context);

}

List<NamedExpression> namedExpressions =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ public Expression visitWindowFunction(WindowFunction node, AnalysisContext conte
}

@Override
public Expression visitHighlight(HighlightFunction node, AnalysisContext context) {
public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) {
Expression expr = node.getHighlightField().accept(this, context);
return new HighlightExpression(expr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ public LogicalPlan analyze(UnresolvedExpression projectItem, AnalysisContext con

@Override
public LogicalPlan visitAlias(Alias node, AnalysisContext context) {
if (!(node.getDelegated() instanceof HighlightFunction)) {
UnresolvedExpression delegated = node.getDelegated();
if (!(delegated instanceof HighlightFunction)) {
return null;
}

HighlightFunction unresolved = (HighlightFunction) node.getDelegated();
HighlightFunction unresolved = (HighlightFunction) delegated;
Expression field = expressionAnalyzer.analyze(unresolved.getHighlightField(), context);
return new LogicalHighlight(child, field);
return new LogicalHighlight(child, field, unresolved.getArguments());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ public T visitAD(AD node, C context) {
return visitChildren(node, context);
}

public T visitHighlight(HighlightFunction node, C context) {
public T visitHighlightFunction(HighlightFunction node, C context) {
return visitChildren(node, context);
}
}
5 changes: 3 additions & 2 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ public When when(UnresolvedExpression condition, UnresolvedExpression result) {
return new When(condition, result);
}

public UnresolvedExpression highlight(UnresolvedExpression fieldName) {
return new HighlightFunction(fieldName);
public UnresolvedExpression highlight(UnresolvedExpression fieldName,
java.util.Map<String, Literal> arguments) {
return new HighlightFunction(fieldName, arguments);
}

public UnresolvedExpression window(UnresolvedExpression function,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.sql.ast.expression;

import java.util.List;
import java.util.Map;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
Expand All @@ -21,10 +22,11 @@
@ToString
public class HighlightFunction extends UnresolvedExpression {
private final UnresolvedExpression highlightField;
private final Map<String, Literal> arguments;

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitHighlight(this, context);
return nodeVisitor.visitHighlightFunction(this, context);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@

package org.opensearch.sql.expression;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.Getter;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.env.Environment;
Expand All @@ -20,6 +26,7 @@
@Getter
public class HighlightExpression extends FunctionExpression {
private final Expression highlightField;
private final ExprType type;

/**
* HighlightExpression Constructor.
Expand All @@ -28,6 +35,8 @@ public class HighlightExpression extends FunctionExpression {
public HighlightExpression(Expression highlightField) {
super(BuiltinFunctionName.HIGHLIGHT.getName(), List.of(highlightField));
this.highlightField = highlightField;
this.type = this.highlightField.toString().contains("*")
? ExprCoreType.STRUCT : ExprCoreType.ARRAY;
}

/**
Expand All @@ -37,21 +46,57 @@ public HighlightExpression(Expression highlightField) {
*/
@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
String refName = "_highlight" + "." + StringUtils.unquoteText(getHighlightField().toString());
return valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING));
String refName = "_highlight";
// Not a wilcard expression
if (this.type == ExprCoreType.ARRAY) {
refName += "." + StringUtils.unquoteText(getHighlightField().toString());
}
ExprValue value = valueEnv.resolve(DSL.ref(refName, ExprCoreType.STRING));

// In the event of multiple returned highlights and wildcard being
// used in conjunction with other highlight calls, we need to ensure
// only wildcard regex matching is mapped to wildcard call.
if (this.type == ExprCoreType.STRUCT && value.type() == ExprCoreType.STRUCT) {
value = new ExprTupleValue(
new LinkedHashMap<String, ExprValue>(value.tupleValue()
.entrySet()
.stream()
.filter(s -> matchesHighlightRegex(s.getKey(),
StringUtils.unquoteText(highlightField.toString())))
.collect(Collectors.toMap(
e -> e.getKey(),
e -> e.getValue()))));
if (value.tupleValue().isEmpty()) {
value = ExprValueUtils.missingValue();
}
}

return value;
}

/**
* Get type for HighlightExpression.
* @return : String type.
* @return : Expression type.
*/
@Override
public ExprType type() {
return ExprCoreType.ARRAY;
return this.type;
}

@Override
public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
return visitor.visitHighlight(this, context);
}

/**
* Check if field matches the wildcard pattern used in highlight query.
* @param field Highlight selected field for query
* @param pattern Wildcard regex to match field against
* @return True if field matches wildcard pattern
*/
private boolean matchesHighlightRegex(String field, String pattern) {
Pattern p = Pattern.compile(pattern.replace("*", ".*"));
Matcher matcher = p.matcher(field);
return matcher.matches();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.env.Environment;

Expand All @@ -37,15 +36,6 @@ public void register(BuiltinFunctionRepository repository) {
repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE));
repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE));
repository.register(match_phrase_prefix());
repository.register(highlight());
}

private static FunctionResolver highlight() {
FunctionName functionName = BuiltinFunctionName.HIGHLIGHT.getName();
FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING));
FunctionBuilder functionBuilder = arguments -> new HighlightExpression(arguments.get(0));
return new DefaultFunctionResolver(functionName,
ImmutableMap.of(functionSignature, functionBuilder));
}

private static FunctionResolver match_bool_prefix() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,28 @@
package org.opensearch.sql.planner.logical;

import java.util.Collections;
import java.util.Map;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.expression.Expression;

@EqualsAndHashCode(callSuper = true)
@Getter
@ToString
public class LogicalHighlight extends LogicalPlan {
private final Expression highlightField;
private final Map<String, Literal> arguments;

public LogicalHighlight(LogicalPlan childPlan, Expression field) {
/**
* Constructor of LogicalHighlight.
*/
public LogicalHighlight(LogicalPlan childPlan, Expression highlightField,
Map<String, Literal> arguments) {
super(Collections.singletonList(childPlan));
highlightField = field;
this.highlightField = highlightField;
this.arguments = arguments;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ public LogicalPlan window(LogicalPlan input,
return new LogicalWindow(input, windowFunction, windowDefinition);
}

public LogicalPlan highlight(LogicalPlan input, Expression field) {
return new LogicalHighlight(input, field);
public LogicalPlan highlight(LogicalPlan input, Expression field,
Map<String, Literal> arguments) {
return new LogicalHighlight(input, field, arguments);
}

public static LogicalPlan remove(LogicalPlan input, ReferenceExpression... fields) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder;
import static org.opensearch.sql.data.model.ExprValueUtils.integerValue;
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.LONG;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -270,16 +272,41 @@ public void project_source() {

@Test
public void project_highlight() {
Map<String, Literal> args = new HashMap<>();
args.put("pre_tags", new Literal("<mark>", DataType.STRING));
args.put("post_tags", new Literal("</mark>", DataType.STRING));

assertAnalyzeEqual(
LogicalPlanDSL.project(
LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table),
DSL.literal("fieldA"), args),
DSL.named("highlight(fieldA, pre_tags='<mark>', post_tags='</mark>')",
new HighlightExpression(DSL.literal("fieldA")))
),
AstDSL.projectWithArg(
AstDSL.relation("schema"),
AstDSL.defaultFieldsArgs(),
AstDSL.alias("highlight(fieldA, pre_tags='<mark>', post_tags='</mark>')",
new HighlightFunction(AstDSL.stringLiteral("fieldA"), args))
)
);
}

@Test
public void project_highlight_wildcard() {
Map<String, Literal> args = new HashMap<>();
assertAnalyzeEqual(
LogicalPlanDSL.project(
LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table),
DSL.literal("fieldA")),
DSL.named("highlight(fieldA)", new HighlightExpression(DSL.literal("fieldA")))
DSL.literal("*"), args),
DSL.named("highlight(*)",
new HighlightExpression(DSL.literal("*")))
),
AstDSL.projectWithArg(
AstDSL.relation("schema"),
AstDSL.defaultFieldsArgs(),
AstDSL.alias("highlight(fieldA)", new HighlightFunction(AstDSL.stringLiteral("fieldA")))
AstDSL.alias("highlight(*)",
new HighlightFunction(AstDSL.stringLiteral("*"), args))
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import static java.util.Collections.emptyList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.opensearch.sql.ast.dsl.AstDSL.field;
Expand Down Expand Up @@ -38,7 +37,6 @@
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.AllFields;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.HighlightFunction;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
Expand All @@ -50,7 +48,6 @@
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.config.ExpressionConfig;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
Expand Down Expand Up @@ -592,12 +589,6 @@ public void now_as_a_function_not_cached() {
assertTrue(values.stream().noneMatch(v -> v.valueOf(null) == referenceValue));
}

@Test
void highlight() {
assertAnalyzeEqual(new HighlightExpression(DSL.literal("fieldA")),
new HighlightFunction(stringLiteral("fieldA")));
}

protected Expression analyze(UnresolvedExpression unresolvedExpression) {
return expressionAnalyzer.analyze(unresolvedExpression, analysisContext);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.HighlightFunction;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.config.ExpressionConfig;
import org.springframework.context.annotation.Configuration;
Expand All @@ -40,8 +39,10 @@ void visit_named_select_item() {

@Test
void visit_highlight() {
Map<String, Literal> args = new HashMap<>();
Alias alias = AstDSL.alias("highlight(fieldA)",
new HighlightFunction(AstDSL.stringLiteral("fieldA")));
new HighlightFunction(
AstDSL.stringLiteral("fieldA"), args));
NamedExpressionAnalyzer analyzer =
new NamedExpressionAnalyzer(expressionAnalyzer);

Expand Down
Loading