Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC Refactor: Move relevance search functions from :core to :opensearch #2025

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ dependencies {
api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}"
api group: 'com.google.code.gson', name: 'gson', version: '2.8.9'
api project(':common')
implementation project(':sql')
implementation project(':ppl')

testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1'
Expand Down
117 changes: 66 additions & 51 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Argument;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Let;
Expand Down Expand Up @@ -108,30 +109,45 @@
* Analyze the {@link UnresolvedPlan} in the {@link AnalysisContext} to construct the {@link
* LogicalPlan}.
*/
public class Analyzer extends AbstractNodeVisitor<LogicalPlan, AnalysisContext> {
// TODO make Analyzer abstract, don't create it; delete `visit` function
public class Analyzer implements AbstractNodeVisitor<LogicalPlan, AnalysisContext> {

private final ExpressionAnalyzer expressionAnalyzer;
protected final ExpressionAnalyzer expressionAnalyzer;

private final SelectExpressionAnalyzer selectExpressionAnalyzer;
protected final SelectExpressionAnalyzer selectExpressionAnalyzer;

private final NamedExpressionAnalyzer namedExpressionAnalyzer;
protected final NamedExpressionAnalyzer namedExpressionAnalyzer;

private final DataSourceService dataSourceService;
protected final DataSourceService dataSourceService;

private final BuiltinFunctionRepository repository;
protected final BuiltinFunctionRepository repository;

/** Constructor. */
public Analyzer(
protected Analyzer(
ExpressionAnalyzer expressionAnalyzer,
SelectExpressionAnalyzer selectExpressionAnalyzer,
NamedExpressionAnalyzer namedExpressionAnalyzer,
DataSourceService dataSourceService,
BuiltinFunctionRepository repository) {
this.expressionAnalyzer = expressionAnalyzer;
this.selectExpressionAnalyzer = selectExpressionAnalyzer;
this.namedExpressionAnalyzer = namedExpressionAnalyzer;
this.dataSourceService = dataSourceService;
this.selectExpressionAnalyzer = new SelectExpressionAnalyzer(expressionAnalyzer);
this.namedExpressionAnalyzer = new NamedExpressionAnalyzer(expressionAnalyzer);
this.repository = repository;
}

/** Constructor. */
public Analyzer(
ExpressionAnalyzer expressionAnalyzer,
DataSourceService dataSourceService,
BuiltinFunctionRepository repository) {
this(
expressionAnalyzer,
new SelectExpressionAnalyzer(expressionAnalyzer),
new NamedExpressionAnalyzer(expressionAnalyzer),
dataSourceService,
repository);
}

public LogicalPlan analyze(UnresolvedPlan unresolved, AnalysisContext context) {
return unresolved.accept(this, context);
}
Expand Down Expand Up @@ -242,28 +258,6 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
return new LogicalFilter(child, optimized);
}

/**
* Ensure NESTED function is not used in GROUP BY, and HAVING clauses. Fallback to legacy engine.
* Can remove when support is added for NESTED function in WHERE, GROUP BY, ORDER BY, and HAVING
* clauses.
*
* @param condition : Filter condition
*/
private void verifySupportsCondition(Expression condition) {
if (condition instanceof FunctionExpression) {
if (((FunctionExpression) condition)
.getFunctionName()
.getFunctionName()
.equalsIgnoreCase(BuiltinFunctionName.NESTED.name())) {
throw new SyntaxCheckException(
"Falling back to legacy engine. Nested function is not supported in WHERE,"
+ " GROUP BY, and HAVING clauses.");
}
((FunctionExpression) condition)
.getArguments().stream().forEach(e -> verifySupportsCondition(e));
}
}

/** Build {@link LogicalRename}. */
@Override
public LogicalPlan visitRename(Rename node, AnalysisContext context) {
Expand Down Expand Up @@ -384,35 +378,18 @@ public LogicalPlan visitProject(Project node, AnalysisContext context) {
node.getProjectList().stream()
.map(expr -> (ReferenceExpression) expressionAnalyzer.analyze(expr, context))
.collect(Collectors.toList());
referenceExpressions.forEach(ref -> curEnv.remove(ref));
referenceExpressions.forEach(curEnv::remove);
return new LogicalRemove(child, ImmutableSet.copyOf(referenceExpressions));
}
}

// For each unresolved window function, analyze it by "insert" a window and sort operator
// between project and its child.
for (UnresolvedExpression expr : node.getProjectList()) {
WindowExpressionAnalyzer windowAnalyzer =
new WindowExpressionAnalyzer(expressionAnalyzer, child);
child = windowAnalyzer.analyze(expr, context);
}

for (UnresolvedExpression expr : node.getProjectList()) {
HighlightAnalyzer highlightAnalyzer = new HighlightAnalyzer(expressionAnalyzer, child);
child = highlightAnalyzer.analyze(expr, context);
}

List<NamedExpression> namedExpressions =
selectExpressionAnalyzer.analyze(
node.getProjectList(),
context,
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child));

for (UnresolvedExpression expr : node.getProjectList()) {
NestedAnalyzer nestedAnalyzer =
new NestedAnalyzer(namedExpressions, expressionAnalyzer, child);
child = nestedAnalyzer.analyze(expr, context);
}
child = visitProjectList(node.getProjectList(), namedExpressions, child, context);

// new context
context.push();
Expand All @@ -424,6 +401,20 @@ public LogicalPlan visitProject(Project node, AnalysisContext context) {
return new LogicalProject(child, namedExpressions, namedParseExpressions);
}

@Override
public LogicalPlan visitProjectList(
List<UnresolvedExpression> columns, List<NamedExpression> namedExpressions, LogicalPlan child, AnalysisContext context) {

// For each unresolved window function, analyze it by "insert" a window and sort operator
// between project and its child.
for (UnresolvedExpression expr : columns) {
WindowExpressionAnalyzer windowAnalyzer =
new WindowExpressionAnalyzer(expressionAnalyzer, child);
child = windowAnalyzer.analyze(expr, context);
}
return child;
}

/** Build {@link LogicalEval}. */
@Override
public LogicalPlan visitEval(Eval node, AnalysisContext context) {
Expand Down Expand Up @@ -603,4 +594,28 @@ private SortOption analyzeSortOption(List<Argument> fieldArgs) {
}
return asc ? SortOption.DEFAULT_ASC : SortOption.DEFAULT_DESC;
}

@Override
public LogicalPlan visit(Node node, AnalysisContext context) {
// TODO
// actually would be never called for opensearch,
// because it doesn't override any visit*(node, context) func
// TODO
// rework: this code may call analyzers from different datasources while processing one tree (query)
for (var metadata : dataSourceService.getDataSourceMetadata(true)) {
var analyzer = dataSourceService
.getDataSource(metadata.getName())
.getStorageEngine()
.getAnalyzer(dataSourceService, repository);
if (analyzer == null) {
continue;
}
var res = node.accept(analyzer, context);
if (res != null) {
return res;
}
}
throw new SemanticCheckException(String.format("Unknown node: %s", node));
//return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,16 @@
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cast;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.HighlightFunction;
import org.opensearch.sql.ast.expression.In;
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.ScoreFunction;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedArgument;
import org.opensearch.sql.ast.expression.UnresolvedAttribute;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.When;
Expand All @@ -54,9 +49,6 @@
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.HighlightExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.AggregationState;
Expand All @@ -66,7 +58,6 @@
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.expression.span.SpanExpression;
import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction;
Expand All @@ -75,8 +66,8 @@
* Analyze the {@link UnresolvedExpression} in the {@link AnalysisContext} to construct the {@link
* Expression}.
*/
public class ExpressionAnalyzer extends AbstractNodeVisitor<Expression, AnalysisContext> {
@Getter private final BuiltinFunctionRepository repository;
public class ExpressionAnalyzer implements AbstractNodeVisitor<Expression, AnalysisContext> {
@Getter protected final BuiltinFunctionRepository repository;

@Override
public Expression visitCast(Cast node, AnalysisContext context) {
Expand Down Expand Up @@ -177,30 +168,15 @@ public Expression visitAggregateFunction(AggregateFunction node, AnalysisContext
}
}

@Override
public Expression visitRelevanceFieldList(RelevanceFieldList node, AnalysisContext context) {
return new LiteralExpression(
ExprValueUtils.tupleValue(ImmutableMap.copyOf(node.getFieldList())));
}

@Override
public Expression visitFunction(Function node, AnalysisContext context) {
FunctionName functionName = FunctionName.of(node.getFuncName());
List<Expression> arguments =
node.getFuncArgs().stream()
.map(
unresolvedExpression -> {
var ret = analyze(unresolvedExpression, context);
if (ret == null) {
throw new UnsupportedOperationException(
String.format("Invalid use of expression %s", unresolvedExpression));
} else {
return ret;
}
})
.map(unresolvedExpression -> analyze(unresolvedExpression, context))
.collect(Collectors.toList());
return (Expression)
repository.compile(context.getFunctionProperties(), functionName, arguments);
return (Expression) repository.compile(context.getFunctionProperties(),
functionName, arguments);
}

@SuppressWarnings("unchecked")
Expand All @@ -214,72 +190,6 @@ public Expression visitWindowFunction(WindowFunction node, AnalysisContext conte
return expr;
}

@Override
public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext context) {
Expression expr = node.getHighlightField().accept(this, context);
return new HighlightExpression(expr);
}

/**
* visitScoreFunction removes the score function from the AST and replaces it with the child
* relevance function node. If the optional boost variable is provided, the boost argument of the
* relevance function is combined.
*
* @param node score function node
* @param context analysis context for the query
* @return resolved relevance function
*/
public Expression visitScoreFunction(ScoreFunction node, AnalysisContext context) {
Literal boostArg = node.getRelevanceFieldWeight();
if (!boostArg.getType().equals(DataType.DOUBLE)) {
throw new SemanticCheckException(
String.format(
"Expected boost type '%s' but got '%s'",
DataType.DOUBLE.name(), boostArg.getType().name()));
}
Double thisBoostValue = ((Double) boostArg.getValue());

// update the existing unresolved expression to add a boost argument if it doesn't exist
// OR multiply the existing boost argument
Function relevanceQueryUnresolvedExpr = (Function) node.getRelevanceQuery();
List<UnresolvedExpression> relevanceFuncArgs = relevanceQueryUnresolvedExpr.getFuncArgs();

boolean doesFunctionContainBoostArgument = false;
List<UnresolvedExpression> updatedFuncArgs = new ArrayList<>();
for (UnresolvedExpression expr : relevanceFuncArgs) {
String argumentName = ((UnresolvedArgument) expr).getArgName();
if (argumentName.equalsIgnoreCase("boost")) {
doesFunctionContainBoostArgument = true;
Literal boostArgLiteral = (Literal) ((UnresolvedArgument) expr).getValue();
Double boostValue =
Double.parseDouble((String) boostArgLiteral.getValue()) * thisBoostValue;
UnresolvedArgument newBoostArg =
new UnresolvedArgument(
argumentName, new Literal(boostValue.toString(), DataType.STRING));
updatedFuncArgs.add(newBoostArg);
} else {
updatedFuncArgs.add(expr);
}
}

// since nothing was found, add an argument
if (!doesFunctionContainBoostArgument) {
UnresolvedArgument newBoostArg =
new UnresolvedArgument(
"boost", new Literal(Double.toString(thisBoostValue), DataType.STRING));
updatedFuncArgs.add(newBoostArg);
}

// create a new function expression with boost argument and resolve it
Function updatedRelevanceQueryUnresolvedExpr =
new Function(relevanceQueryUnresolvedExpr.getFuncName(), updatedFuncArgs);
OpenSearchFunctions.OpenSearchFunction relevanceQueryExpr =
(OpenSearchFunctions.OpenSearchFunction)
updatedRelevanceQueryUnresolvedExpr.accept(this, context);
relevanceQueryExpr.setScoreTracked(true);
return relevanceQueryExpr;
}

@Override
public Expression visitIn(In node, AnalysisContext context) {
return visitIn(node.getField(), node.getValueList(), context);
Expand Down Expand Up @@ -396,11 +306,6 @@ public Expression visitSpan(Span node, AnalysisContext context) {
node.getUnit());
}

@Override
public Expression visitUnresolvedArgument(UnresolvedArgument node, AnalysisContext context) {
return new NamedArgumentExpression(node.getArgName(), node.getValue().accept(this, context));
}

/**
* If QualifiedName is actually a reserved metadata field, return the expr type associated with
* the metadata field.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* NamedExpression}.
*/
@RequiredArgsConstructor
public class NamedExpressionAnalyzer extends AbstractNodeVisitor<NamedExpression, AnalysisContext> {
public class NamedExpressionAnalyzer implements AbstractNodeVisitor<NamedExpression, AnalysisContext> {
private final ExpressionAnalyzer expressionAnalyzer;

/** Analyze Select fields. */
Expand Down
Loading