Skip to content

Commit

Permalink
[Backport 2.7] opensearch-project#639: allow metadata fields and scor…
Browse files Browse the repository at this point in the history
…e opensearch function (#228) (opensearch-project#1509)

* opensearch-project#639: Support OpenSearch metadata fields and the score OpenSearch function (#228) (opensearch-project#1456)

Signed-off-by: Andrew Carbonetto <andrewc@bitquilltech.com>
Co-authored-by: Andrew Carbonetto <andrewc@bitquilltech.com>
  • Loading branch information
opensearch-trigger-bot[bot] and acarbonetto authored Apr 12, 2023
1 parent a796433 commit 5ec28a9
Show file tree
Hide file tree
Showing 42 changed files with 1,537 additions and 61 deletions.
7 changes: 7 additions & 0 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.data.model.ExprMissingValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
Expand Down Expand Up @@ -150,6 +151,9 @@ public LogicalPlan visitRelation(Relation node, AnalysisContext context) {
dataSourceSchemaIdentifierNameResolver.getIdentifierName());
}
table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v));
table.getReservedFieldTypes().forEach(
(k, v) -> curEnv.addReservedWord(new Symbol(Namespace.FIELD_NAME, k), v)
);

// Put index name or its alias in index namespace on type environment so qualifier
// can be removed when analyzing qualified name. The value (expr type) here doesn't matter.
Expand Down Expand Up @@ -193,6 +197,9 @@ public LogicalPlan visitTableFunction(TableFunction node, AnalysisContext contex
TypeEnvironment curEnv = context.peek();
Table table = tableFunctionImplementation.applyArguments();
table.getFieldTypes().forEach((k, v) -> curEnv.define(new Symbol(Namespace.FIELD_NAME, k), v));
table.getReservedFieldTypes().forEach(
(k, v) -> curEnv.addReservedWord(new Symbol(Namespace.FIELD_NAME, k), v)
);
curEnv.define(new Symbol(Namespace.INDEX_NAME,
dataSourceSchemaIdentifierNameResolver.getIdentifierName()), STRUCT);
return new LogicalRelation(dataSourceSchemaIdentifierNameResolver.getIdentifierName(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

import static org.opensearch.sql.ast.dsl.AstDSL.and;
import static org.opensearch.sql.ast.dsl.AstDSL.compare;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.GTE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.LTE;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -31,6 +29,7 @@
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
Expand All @@ -42,6 +41,7 @@
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.ScoreFunction;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedArgument;
import org.opensearch.sql.ast.expression.UnresolvedAttribute;
Expand All @@ -51,6 +51,7 @@
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
Expand All @@ -67,6 +68,7 @@
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.span.SpanExpression;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
Expand Down Expand Up @@ -207,6 +209,65 @@ public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext
return new HighlightExpression(expr);
}

/**
* visitScoreFunction removes the score function from the AST and replaces it with the child
* relevance function node. If the optional boost variable is provided, the boost argument
* of the relevance function is combined.
*
* @param node score function node
* @param context analysis context for the query
* @return resolved relevance function
*/
public Expression visitScoreFunction(ScoreFunction node, AnalysisContext context) {
Literal boostArg = node.getRelevanceFieldWeight();
if (!boostArg.getType().equals(DataType.DOUBLE)) {
throw new SemanticCheckException(String.format("Expected boost type '%s' but got '%s'",
DataType.DOUBLE.name(), boostArg.getType().name()));
}
Double thisBoostValue = ((Double) boostArg.getValue());

// update the existing unresolved expression to add a boost argument if it doesn't exist
// OR multiply the existing boost argument
Function relevanceQueryUnresolvedExpr = (Function) node.getRelevanceQuery();
List<UnresolvedExpression> relevanceFuncArgs = relevanceQueryUnresolvedExpr.getFuncArgs();

boolean doesFunctionContainBoostArgument = false;
List<UnresolvedExpression> updatedFuncArgs = new ArrayList<>();
for (UnresolvedExpression expr : relevanceFuncArgs) {
String argumentName = ((UnresolvedArgument) expr).getArgName();
if (argumentName.equalsIgnoreCase("boost")) {
doesFunctionContainBoostArgument = true;
Literal boostArgLiteral = (Literal) ((UnresolvedArgument) expr).getValue();
Double boostValue =
Double.parseDouble((String) boostArgLiteral.getValue()) * thisBoostValue;
UnresolvedArgument newBoostArg = new UnresolvedArgument(
argumentName,
new Literal(boostValue.toString(), DataType.STRING)
);
updatedFuncArgs.add(newBoostArg);
} else {
updatedFuncArgs.add(expr);
}
}

// since nothing was found, add an argument
if (!doesFunctionContainBoostArgument) {
UnresolvedArgument newBoostArg = new UnresolvedArgument(
"boost", new Literal(Double.toString(thisBoostValue), DataType.STRING));
updatedFuncArgs.add(newBoostArg);
}

// create a new function expression with boost argument and resolve it
Function updatedRelevanceQueryUnresolvedExpr = new Function(
relevanceQueryUnresolvedExpr.getFuncName(),
updatedFuncArgs);
OpenSearchFunctions.OpenSearchFunction relevanceQueryExpr =
(OpenSearchFunctions.OpenSearchFunction) updatedRelevanceQueryUnresolvedExpr
.accept(this, context);
relevanceQueryExpr.setScoreTracked(true);
return relevanceQueryExpr;
}

@Override
public Expression visitIn(In node, AnalysisContext context) {
return visitIn(node.getField(), node.getValueList(), context);
Expand Down Expand Up @@ -297,6 +358,23 @@ public Expression visitAllFields(AllFields node, AnalysisContext context) {
@Override
public Expression visitQualifiedName(QualifiedName node, AnalysisContext context) {
QualifierAnalyzer qualifierAnalyzer = new QualifierAnalyzer(context);

// check for reserved words in the identifier
for (String part : node.getParts()) {
for (TypeEnvironment typeEnv = context.peek();
typeEnv != null;
typeEnv = typeEnv.getParent()) {
Optional<ExprType> exprType = typeEnv.getReservedSymbolTable().lookup(
new Symbol(Namespace.FIELD_NAME, part));
if (exprType.isPresent()) {
return visitMetadata(
qualifierAnalyzer.unqualified(node),
(ExprCoreType) exprType.get(),
context
);
}
}
}
return visitIdentifier(qualifierAnalyzer.unqualified(node), context);
}

Expand All @@ -313,6 +391,19 @@ public Expression visitUnresolvedArgument(UnresolvedArgument node, AnalysisConte
return new NamedArgumentExpression(node.getArgName(), node.getValue().accept(this, context));
}

/**
* If QualifiedName is actually a reserved metadata field, return the expr type associated
* with the metadata field.
* @param ident metadata field name
* @param context analysis context
* @return DSL reference
*/
private Expression visitMetadata(String ident,
ExprCoreType exprCoreType,
AnalysisContext context) {
return DSL.ref(ident, exprCoreType);
}

private Expression visitIdentifier(String ident, AnalysisContext context) {
// ParseExpression will always override ReferenceExpression when ident conflicts
for (NamedExpression expr : context.getNamedParseExpressions()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.sql.expression.conditional.cases.CaseClause;
import org.opensearch.sql.expression.conditional.cases.WhenClause;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor;
Expand Down Expand Up @@ -70,8 +71,17 @@ public Expression visitFunction(FunctionExpression node, AnalysisContext context
final List<Expression> args =
node.getArguments().stream().map(expr -> expr.accept(this, context))
.collect(Collectors.toList());
return (Expression) repository.compile(context.getFunctionProperties(),
node.getFunctionName(), args);
Expression optimizedFunctionExpression = (Expression) repository.compile(
context.getFunctionProperties(),
node.getFunctionName(),
args
);
// Propagate scoreTracked for OpenSearch functions
if (optimizedFunctionExpression instanceof OpenSearchFunctions.OpenSearchFunction) {
((OpenSearchFunctions.OpenSearchFunction) optimizedFunctionExpression).setScoreTracked(
((OpenSearchFunctions.OpenSearchFunction)node).isScoreTracked());
}
return optimizedFunctionExpression;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,30 @@ public class TypeEnvironment implements Environment<Symbol, ExprType> {
private final TypeEnvironment parent;
private final SymbolTable symbolTable;

@Getter
private final SymbolTable reservedSymbolTable;

/**
* Constructor with empty symbol tables.
*
* @param parent parent environment
*/
public TypeEnvironment(TypeEnvironment parent) {
this.parent = parent;
this.symbolTable = new SymbolTable();
this.reservedSymbolTable = new SymbolTable();
}

/**
* Constructor with empty reserved symbol table.
*
* @param parent parent environment
* @param symbolTable type table
*/
public TypeEnvironment(TypeEnvironment parent, SymbolTable symbolTable) {
this.parent = parent;
this.symbolTable = symbolTable;
this.reservedSymbolTable = new SymbolTable();
}

/**
Expand All @@ -59,6 +75,7 @@ public ExprType resolve(Symbol symbol) {

/**
* Resolve all fields in the current environment.
*
* @param namespace a namespace
* @return all symbols in the namespace
*/
Expand Down Expand Up @@ -102,7 +119,11 @@ public void remove(ReferenceExpression ref) {
* Clear all fields in the current environment.
*/
public void clearAllFields() {
lookupAllFields(FIELD_NAME).keySet().stream()
.forEach(v -> remove(new Symbol(Namespace.FIELD_NAME, v)));
lookupAllFields(FIELD_NAME).keySet().forEach(
v -> remove(new Symbol(Namespace.FIELD_NAME, v)));
}

public void addReservedWord(Symbol symbol, ExprType type) {
reservedSymbolTable.store(symbol, type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.ScoreFunction;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedArgument;
import org.opensearch.sql.ast.expression.UnresolvedAttribute;
Expand Down Expand Up @@ -278,6 +279,10 @@ public T visitHighlightFunction(HighlightFunction node, C context) {
return visitChildren(node, context);
}

public T visitScoreFunction(ScoreFunction node, C context) {
return visitChildren(node, context);
}

public T visitStatement(Statement node, C context) {
return visit(node, context);
}
Expand Down
7 changes: 6 additions & 1 deletion core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.ScoreFunction;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedArgument;
Expand All @@ -60,7 +61,6 @@
import org.opensearch.sql.ast.tree.TableFunction;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Values;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

/**
* Class of static methods to create specific node instances.
Expand Down Expand Up @@ -285,6 +285,11 @@ public UnresolvedExpression highlight(UnresolvedExpression fieldName,
return new HighlightFunction(fieldName, arguments);
}

public UnresolvedExpression score(UnresolvedExpression relevanceQuery,
Literal relevanceFieldWeight) {
return new ScoreFunction(relevanceQuery, relevanceFieldWeight);
}

public UnresolvedExpression window(UnresolvedExpression function,
List<UnresolvedExpression> partitionByList,
List<Pair<SortOption, UnresolvedExpression>> sortList) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.List;
import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
* Expression node of Score function.
* Score takes a relevance-search expression as an argument and returns it
*/
@AllArgsConstructor
@EqualsAndHashCode(callSuper = false)
@Getter
@ToString
public class ScoreFunction extends UnresolvedExpression {
private final UnresolvedExpression relevanceQuery;
private final Literal relevanceFieldWeight;

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitScoreFunction(this, context);
}

@Override
public List<UnresolvedExpression> getChild() {
return List.of(relevanceQuery);
}
}
14 changes: 13 additions & 1 deletion core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,19 @@ public static FunctionExpression match_bool_prefix(Expression... args) {
}

public static FunctionExpression wildcard_query(Expression... args) {
return compile(FunctionProperties.None,BuiltinFunctionName.WILDCARD_QUERY, args);
return compile(FunctionProperties.None, BuiltinFunctionName.WILDCARD_QUERY, args);
}

public static FunctionExpression score(Expression... args) {
return compile(FunctionProperties.None, BuiltinFunctionName.SCORE, args);
}

public static FunctionExpression scorequery(Expression... args) {
return compile(FunctionProperties.None, BuiltinFunctionName.SCOREQUERY, args);
}

public static FunctionExpression score_query(Expression... args) {
return compile(FunctionProperties.None, BuiltinFunctionName.SCORE_QUERY, args);
}

public static FunctionExpression now(FunctionProperties functionProperties,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ public enum BuiltinFunctionName {
WEEK_OF_YEAR(FunctionName.of("week_of_year")),
YEAR(FunctionName.of("year")),
YEARWEEK(FunctionName.of("yearweek")),

// `now`-like functions
NOW(FunctionName.of("now")),
CURDATE(FunctionName.of("curdate")),
Expand All @@ -132,6 +133,7 @@ public enum BuiltinFunctionName {
CURRENT_TIMESTAMP(FunctionName.of("current_timestamp")),
LOCALTIMESTAMP(FunctionName.of("localtimestamp")),
SYSDATE(FunctionName.of("sysdate")),

/**
* Text Functions.
*/
Expand Down Expand Up @@ -255,6 +257,10 @@ public enum BuiltinFunctionName {
MATCH_BOOL_PREFIX(FunctionName.of("match_bool_prefix")),
HIGHLIGHT(FunctionName.of("highlight")),
MATCH_PHRASE_PREFIX(FunctionName.of("match_phrase_prefix")),
SCORE(FunctionName.of("score")),
SCOREQUERY(FunctionName.of("scorequery")),
SCORE_QUERY(FunctionName.of("score_query")),

/**
* Legacy Relevance Function.
*/
Expand Down
Loading

0 comments on commit 5ec28a9

Please sign in to comment.