Skip to content

Commit

Permalink
Validating field and fields parameters in TypeEnv as part of Expressi…
Browse files Browse the repository at this point in the history
…onAnalyzer.

Signed-off-by: forestmvey <forestv@bitquilltech.com>
  • Loading branch information
forestmvey committed Dec 8, 2022
1 parent 63750a1 commit 8fa4d22
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 114 deletions.
5 changes: 0 additions & 5 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
import org.opensearch.sql.expression.aggregation.NamedAggregator;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.expression.function.TableFunctionImplementation;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.planner.logical.LogicalAD;
Expand Down Expand Up @@ -221,10 +220,6 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
Expression condition = expressionAnalyzer.analyze(node.getCondition(), context);

if (condition instanceof OpenSearchFunctions.OpenSearchFunction) {
((OpenSearchFunctions.OpenSearchFunction)condition).validateParameters(context);
}

ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);
Expression optimized = optimizer.optimize(condition, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.analysis;

import static org.opensearch.sql.expression.function.OpenSearchFunctions.isMultiFieldFunction;
import static org.opensearch.sql.expression.function.OpenSearchFunctions.isSingleFieldFunction;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -44,6 +46,7 @@
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.exception.SemanticCheckException;
Expand Down Expand Up @@ -180,6 +183,25 @@ public Expression visitFunction(Function node, AnalysisContext context) {
node.getFuncArgs().stream()
.map(unresolvedExpression -> analyze(unresolvedExpression, context))
.collect(Collectors.toList());

// Resolve fields and field parameters in TypeEnv
if (isSingleFieldFunction(functionName.getFunctionName())
|| isMultiFieldFunction(functionName.getFunctionName())) {
for (Expression arg : arguments) {
if (((NamedArgumentExpression) arg).getArgName().equals("field")
&& !((NamedArgumentExpression)arg).getValue().toString().contains("*")) {
visitQualifiedName(new QualifiedName(StringUtils.unquoteText(
((NamedArgumentExpression)arg).getValue().toString())), context);
} else if (((NamedArgumentExpression)arg).getArgName().equals("fields")) {
((NamedArgumentExpression) arg).getValue().valueOf().tupleValue().entrySet().stream()
.filter(field -> !field.getKey().contains("*")
).forEach(
entry -> visitQualifiedName(new QualifiedName(entry.getKey()), context)
);
}
}
}

return (Expression) repository.compile(context.getFunctionProperties(),
functionName, arguments);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.opensearch.sql.analysis.AnalysisContext;
import org.opensearch.sql.expression.function.FunctionImplementation;
import org.opensearch.sql.expression.function.FunctionName;

Expand All @@ -33,11 +32,4 @@ public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
return visitor.visitFunction(this, context);
}

/**
* Verify if function queries fields available in type environment.
* @param context : Context of fields querying.
*/
public void validateParameters(AnalysisContext context) {
return;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.analysis.AnalysisContext;
import org.opensearch.sql.analysis.TypeEnvironment;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
Expand Down Expand Up @@ -160,35 +155,5 @@ public String toString() {
.collect(Collectors.toList());
return String.format("%s(%s)", functionName, String.join(", ", args));
}

/**
* Verify if function queries fields available in type environment.
* @param context : Context of fields querying.
*/
@Override
public void validateParameters(AnalysisContext context) {
String funcName = this.getFunctionName().toString();

TypeEnvironment typeEnv = context.peek();
if (isSingleFieldFunction(funcName)) {
this.getArguments().stream().map(NamedArgumentExpression.class::cast).filter(arg ->
((arg.getArgName().equals("field")
&& !arg.getValue().toString().contains("*"))
)).findFirst().ifPresent(arg ->
typeEnv.resolve(new Symbol(Namespace.FIELD_NAME,
StringUtils.unquoteText(arg.getValue().toString()))
)
);
} else if (isMultiFieldFunction(funcName)) {
this.getArguments().stream().map(NamedArgumentExpression.class::cast).filter(arg ->
arg.getArgName().equals("fields")
).findFirst().ifPresent(fields ->
fields.getValue().valueOf(null).tupleValue()
.entrySet().stream().filter(k -> !(k.getKey().contains("*"))
).forEach(key -> typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, key.getKey())))
);
}
}

}
}
63 changes: 23 additions & 40 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -275,23 +275,6 @@ public void analyze_filter_aggregation_relation() {
aggregate("MIN", qualifiedName("integer_value")), intLiteral(10))));
}

@Test
public void test_base_class_validate_parameters_method_does_nothing() {
var funcExpr = new FunctionExpression(FunctionName.of("func_name"),
List.of()) {
@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
return null;
}

@Override
public ExprType type() {
return null;
}
};
funcExpr.validateParameters(analysisContext);
}

@Test
public void single_field_relevance_query_semantic_exception() {
SemanticCheckException exception =
Expand Down Expand Up @@ -328,10 +311,10 @@ public void single_field_relevance_query() {
public void single_field_wildcard_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
LogicalPlanDSL.relation("schema", table),
DSL.match(
DSL.namedArgument("field", DSL.literal("wildcard_field*")),
DSL.namedArgument("query", DSL.literal("query_value")))),
DSL.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("match",
Expand All @@ -357,6 +340,27 @@ public void multi_field_relevance_query_semantic_exception() {
exception.getMessage());
}

@Test
public void multi_field_wildcard_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"wildcard_field1*", ExprValueUtils.floatValue(1.F),
"wildcard_field2*", ExprValueUtils.floatValue(.3F))
))
)),
DSL.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"wildcard_field1*", 1.F, "wildcard_field2*", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
}

@Test
public void multi_field_relevance_query_mixed_fields_semantic_exception() {
SemanticCheckException exception =
Expand Down Expand Up @@ -409,27 +413,6 @@ public void multi_field_relevance_query() {
AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
}

@Test
public void multi_field_wildcard_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"wildcard_field1*", ExprValueUtils.floatValue(1.F),
"wildcard_field2*", ExprValueUtils.floatValue(.3F))
))
)),
DSL.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"wildcard_field1*", 1.F, "wildcard_field2*", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
}

@Test
public void rename_relation() {
assertAnalyzeEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,10 @@ public void named_non_parse_expression() {
void match_bool_prefix_expression() {
assertAnalyzeEqual(
DSL.match_bool_prefix(
DSL.namedArgument("field", DSL.literal("fieldA")),
DSL.namedArgument("field", DSL.literal("field_value1")),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("match_bool_prefix",
AstDSL.unresolvedArg("field", stringLiteral("fieldA")),
AstDSL.unresolvedArg("field", stringLiteral("field_value1")),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand Down Expand Up @@ -402,11 +402,11 @@ void multi_match_expression() {
DSL.multi_match(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("multi_match",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -416,12 +416,12 @@ void multi_match_expression_with_params() {
DSL.multi_match(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query")),
DSL.namedArgument("analyzer", DSL.literal("keyword"))),
AstDSL.function("multi_match",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query")),
AstDSL.unresolvedArg("analyzer", stringLiteral("keyword"))));
}
Expand All @@ -432,12 +432,12 @@ void multi_match_expression_two_fields() {
DSL.multi_match(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field1", ExprValueUtils.floatValue(1.F),
"field2", ExprValueUtils.floatValue(.3F)))))),
"field_value1", ExprValueUtils.floatValue(1.F),
"field_value2", ExprValueUtils.floatValue(.3F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("multi_match",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"field1", 1.F, "field2", .3F))),
"field_value1", 1.F, "field_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -447,11 +447,11 @@ void simple_query_string_expression() {
DSL.simple_query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("simple_query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -461,12 +461,12 @@ void simple_query_string_expression_with_params() {
DSL.simple_query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query")),
DSL.namedArgument("analyzer", DSL.literal("keyword"))),
AstDSL.function("simple_query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query")),
AstDSL.unresolvedArg("analyzer", stringLiteral("keyword"))));
}
Expand All @@ -477,12 +477,12 @@ void simple_query_string_expression_two_fields() {
DSL.simple_query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field1", ExprValueUtils.floatValue(1.F),
"field2", ExprValueUtils.floatValue(.3F)))))),
"field_value1", ExprValueUtils.floatValue(1.F),
"field_value2", ExprValueUtils.floatValue(.3F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("simple_query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"field1", 1.F, "field2", .3F))),
"field_value1", 1.F, "field_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -501,11 +501,11 @@ void query_string_expression() {
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("query_value"))),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value"))));
}

Expand All @@ -515,12 +515,12 @@ void query_string_expression_with_params() {
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("query_value")),
DSL.namedArgument("escape", DSL.literal("false"))),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value")),
AstDSL.unresolvedArg("escape", stringLiteral("false"))));
}
Expand All @@ -531,12 +531,12 @@ void query_string_expression_two_fields() {
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field1", ExprValueUtils.floatValue(1.F),
"field2", ExprValueUtils.floatValue(.3F)))))),
"field_value1", ExprValueUtils.floatValue(1.F),
"field_value2", ExprValueUtils.floatValue(.3F)))))),
DSL.namedArgument("query", DSL.literal("query_value"))),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"field1", 1.F, "field2", .3F))),
"field_value1", 1.F, "field_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value"))));
}

Expand Down Expand Up @@ -572,7 +572,7 @@ void wildcard_query_expression_all_params() {
public void match_phrase_prefix_all_params() {
assertAnalyzeEqual(
DSL.match_phrase_prefix(
DSL.namedArgument("field", "test"),
DSL.namedArgument("field", "field_value1"),
DSL.namedArgument("query", "search query"),
DSL.namedArgument("slop", "3"),
DSL.namedArgument("boost", "1.5"),
Expand All @@ -581,7 +581,7 @@ public void match_phrase_prefix_all_params() {
DSL.namedArgument("zero_terms_query", "NONE")
),
AstDSL.function("match_phrase_prefix",
unresolvedArg("field", stringLiteral("test")),
unresolvedArg("field", stringLiteral("field_value1")),
unresolvedArg("query", stringLiteral("search query")),
unresolvedArg("slop", stringLiteral("3")),
unresolvedArg("boost", stringLiteral("1.5")),
Expand Down
Loading

0 comments on commit 8fa4d22

Please sign in to comment.