Skip to content

Commit

Permalink
Adding IT tests for permissive fields, moving field validation to Ope…
Browse files Browse the repository at this point in the history
…nSearchFunctions.

Signed-off-by: forestmvey <forestv@bitquilltech.com>
  • Loading branch information
forestmvey committed Nov 9, 2022
1 parent b0c840e commit a5d26cf
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 70 deletions.
12 changes: 4 additions & 8 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,10 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
import static org.opensearch.sql.utils.MLCommonsConstants.MODELID;
import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP;
import static org.opensearch.sql.utils.MLCommonsConstants.STATUS;
import static org.opensearch.sql.utils.MLCommonsConstants.TASKID;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAINANDPREDICT;
import static org.opensearch.sql.utils.SystemIndexUtils.CATALOGS_TABLE_NAME;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -76,13 +68,15 @@
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.aggregation.NamedAggregator;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.expression.function.TableFunctionImplementation;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.planner.logical.LogicalAD;
Expand Down Expand Up @@ -225,6 +219,8 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
Expression condition = expressionAnalyzer.analyze(node.getCondition(), context);

OpenSearchFunctions.validateFieldList((FunctionExpression)condition, context);

ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);
Expression optimized = optimizer.optimize(condition, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,19 @@

package org.opensearch.sql.analysis;

import static org.opensearch.sql.expression.function.OpenSearchFunctions.isMultiFieldFunction;
import static org.opensearch.sql.expression.function.OpenSearchFunctions.isSingleFieldFunction;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ExpressionNodeVisitor;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.conditional.cases.CaseClause;
import org.opensearch.sql.expression.conditional.cases.WhenClause;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionImplementation;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor;
Expand Down Expand Up @@ -78,30 +70,7 @@ public Expression visitFunction(FunctionExpression node, AnalysisContext context
final List<Expression> args =
node.getArguments().stream().map(expr -> expr.accept(this, context))
.collect(Collectors.toList());
String funcName = node.getFunctionName().toString();
FunctionImplementation ret = repository.compile(node.getFunctionName(), args);

TypeEnvironment typeEnv = context.peek();
if (isSingleFieldFunction(funcName)) {
ret.getArguments().stream().filter(arg ->
(((NamedArgumentExpression)arg).getArgName().equals("field"))
&& !((NamedArgumentExpression)arg).getValue().toString().contains("*")
).findFirst().ifPresent(arg ->
typeEnv.resolve(new Symbol(Namespace.FIELD_NAME,
StringUtils.unquoteText(((NamedArgumentExpression)arg).getValue().toString()))
)
);
} else if (isMultiFieldFunction(funcName)) {
ret.getArguments().stream().filter(arg ->
((NamedArgumentExpression)arg).getArgName().equals("fields")
).findFirst().ifPresent(fields ->
((NamedArgumentExpression)fields).getValue().valueOf(null).tupleValue()
.entrySet().stream().filter(k -> !(k.getKey().contains("*"))
).forEach(key -> typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, key.getKey())))
);
}

return (Expression) ret;
return (Expression) repository.compile(node.getFunctionName(), args);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.analysis.AnalysisContext;
import org.opensearch.sql.analysis.TypeEnvironment;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
Expand All @@ -36,14 +40,53 @@ public class OpenSearchFunctions {
BuiltinFunctionName.QUERY_STRING.name()
);

/**
 * Check if supplied function name is valid SingleFieldRelevanceFunction.
 *
 * <p>The name is upper-cased with {@link java.util.Locale#ROOT} before it is looked up in
 * {@code singleFieldFunctionNames}, so the result does not depend on the JVM default locale.
 * With the default-locale overload, a Turkish or Azeri locale maps {@code i} to a dotted
 * capital I and a name containing that letter could never match.
 *
 * @param funcName : Name of function, matched case-insensitively
 * @return : True if function is single-field function
 */
public static boolean isSingleFieldFunction(String funcName) {
  return singleFieldFunctionNames.contains(funcName.toUpperCase(java.util.Locale.ROOT));
}

/**
 * Check if supplied function name is valid MultiFieldRelevanceFunction.
 *
 * <p>The name is upper-cased with {@link java.util.Locale#ROOT} before it is looked up in
 * {@code multiFieldFunctionNames}, so the result does not depend on the JVM default locale.
 * With the default-locale overload, a Turkish or Azeri locale maps {@code i} to a dotted
 * capital I, so a name such as {@code query_string} could never match.
 *
 * @param funcName : Name of function, matched case-insensitively
 * @return : True if function is multi-field function
 */
public static boolean isMultiFieldFunction(String funcName) {
  return multiFieldFunctionNames.contains(funcName.toUpperCase(java.util.Locale.ROOT));
}

/**
 * Verify if function queries fields available in type environment.
 *
 * <p>For a single-field function, the first {@code field} argument whose value contains no
 * {@code *} is resolved against the environment returned by {@code context.peek()}. For a
 * multi-field function, every key of the first {@code fields} argument that contains no
 * {@code *} is resolved. Names containing {@code *} are skipped, and functions that are
 * neither single-field nor multi-field are not checked at all.
 *
 * @param node : Function used in query.
 * @param context : Context of fields querying.
 */
public static void validateFieldList(FunctionExpression node, AnalysisContext context) {
  final String functionName = node.getFunctionName().toString();
  final TypeEnvironment environment = context.peek();

  if (isSingleFieldFunction(functionName)) {
    for (Object candidate : node.getArguments()) {
      // Arguments are cast one by one, in order, until the first match is found.
      final NamedArgumentExpression argument = (NamedArgumentExpression) candidate;
      if (argument.getArgName().equals("field")
          && !argument.getValue().toString().contains("*")) {
        environment.resolve(new Symbol(Namespace.FIELD_NAME,
            StringUtils.unquoteText(argument.getValue().toString())));
        break;
      }
    }
    return;
  }

  if (isMultiFieldFunction(functionName)) {
    for (Object candidate : node.getArguments()) {
      final NamedArgumentExpression argument = (NamedArgumentExpression) candidate;
      if (argument.getArgName().equals("fields")) {
        // Only the first "fields" argument is inspected; each non-wildcard key is resolved.
        for (String fieldName : argument.getValue().valueOf(null).tupleValue().keySet()) {
          if (!fieldName.contains("*")) {
            environment.resolve(new Symbol(Namespace.FIELD_NAME, fieldName));
          }
        }
        break;
      }
    }
  }
}

/**
* Add functions specific to OpenSearch to repository.
*/
Expand Down
53 changes: 53 additions & 0 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,23 @@ public void analyze_filter_aggregation_relation() {
aggregate("MIN", qualifiedName("integer_value")), intLiteral(10))));
}

@Test
public void single_field_relevance_query_semantic_exception() {
  // Analyzing match() over a field named "missing_value" must fail with this message.
  final String expectedMessage =
      "can't resolve Symbol(namespace=FIELD_NAME, name=missing_value) in type env";

  SemanticCheckException thrown = assertThrows(SemanticCheckException.class, () -> analyze(
      AstDSL.filter(
          AstDSL.relation("schema"),
          AstDSL.function("match",
              AstDSL.unresolvedArg("field", stringLiteral("missing_value")),
              AstDSL.unresolvedArg("query", stringLiteral("query_value"))))));

  assertEquals(expectedMessage, thrown.getMessage());
}

@Test
public void single_field_relevance_query() {
assertAnalyzeEqual(
Expand Down Expand Up @@ -299,6 +316,42 @@ public void single_field_wildcard_relevance_query() {
AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
}

@Test
public void multi_field_relevance_query_semantic_exception() {
  // Both listed fields are named as missing; the expected message names the first one listed.
  final String expectedMessage =
      "can't resolve Symbol(namespace=FIELD_NAME, name=missing_value1) in type env";

  SemanticCheckException thrown = assertThrows(SemanticCheckException.class, () -> analyze(
      AstDSL.filter(
          AstDSL.relation("schema"),
          AstDSL.function("query_string",
              AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
                  "missing_value1", 1.F, "missing_value2", .3F))),
              AstDSL.unresolvedArg("query", stringLiteral("query_value"))))));

  assertEquals(expectedMessage, thrown.getMessage());
}

@Test
public void multi_field_relevance_query_mixed_fields_semantic_exception() {
  // Mixed field list: only "missing_value" appears in the expected failure message.
  // NOTE(review): presumably "string_value" resolves in the test type env - confirm.
  final String expectedMessage =
      "can't resolve Symbol(namespace=FIELD_NAME, name=missing_value) in type env";

  SemanticCheckException thrown = assertThrows(SemanticCheckException.class, () -> analyze(
      AstDSL.filter(
          AstDSL.relation("schema"),
          AstDSL.function("query_string",
              AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
                  "string_value", 1.F, "missing_value", .3F))),
              AstDSL.unresolvedArg("query", stringLiteral("query_value"))))));

  assertEquals(expectedMessage, thrown.getMessage());
}

@Test
public void multi_field_relevance_query() {
assertAnalyzeEqual(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,13 @@

import static java.util.Collections.emptyList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.LinkedHashMap;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.config.ExpressionConfig;
Expand Down Expand Up @@ -52,29 +46,6 @@ void group_expression_should_be_replaced() {
);
}

@Test
void missing_field_single_field_relevance_query() {
  // Optimizing match() over a field named "missing_field" must raise SemanticCheckException.
  assertThrows(
      SemanticCheckException.class,
      () -> {
        optimize(dsl.match(
            dsl.namedArgument("field", DSL.literal("missing_field")),
            dsl.namedArgument("query", DSL.literal("query_value"))));
      });
}

@Test
void missing_field_multi_field_relevance_query() {
  // Optimizing query_string() over two fields named as missing must raise
  // SemanticCheckException.
  assertThrows(
      SemanticCheckException.class,
      () -> {
        optimize(dsl.query_string(
            dsl.namedArgument("fields", DSL.literal(
                new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
                    "missing_field1", ExprValueUtils.floatValue(1.F),
                    "missing_field2", ExprValueUtils.floatValue(.3F))))))));
      });
}

@Test
void aggregation_expression_should_be_replaced() {
assertEquals(
Expand Down
11 changes: 11 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.json.JSONObject;
import org.junit.Test;
import org.opensearch.sql.legacy.SQLIntegTestCase;
import org.opensearch.sql.legacy.utils.StringUtils;

public class MatchIT extends SQLIntegTestCase {
@Override
Expand All @@ -35,4 +36,14 @@ public void match_in_having() throws IOException {
verifySchema(result, schema("lastname", "text"));
verifyDataRows(result, rows("Bates"));
}

@Test
public void missing_field_test() {
  // match() on the name "invalid" must be rejected with a semantic-check error.
  String query = StringUtils.format(
      "SELECT * FROM %s WHERE match(invalid, 'Bates')", TEST_INDEX_ACCOUNT);

  final RuntimeException thrown =
      expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));
  final String message = thrown.getMessage();

  // Checked separately so a failure shows which part of the message is missing.
  assertTrue(message.contains(
      "can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));
  assertTrue(message.contains("SemanticCheckException"));
}
}
11 changes: 11 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/QueryStringIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.json.JSONObject;
import org.junit.Test;
import org.opensearch.sql.legacy.SQLIntegTestCase;
import org.opensearch.sql.legacy.utils.StringUtils;

public class QueryStringIT extends SQLIntegTestCase {
@Override
Expand Down Expand Up @@ -65,4 +66,14 @@ public void wildcard_test() throws IOException {
JSONObject result3 = executeJdbcRequest(query3);
assertEquals(10, result3.getInt("total"));
}

@Test
public void missing_field_test() {
  // query_string() on the name "invalid" must be rejected with a semantic-check error.
  String query = StringUtils.format(
      "SELECT * FROM %s WHERE query_string([invalid], 'beer')", TEST_INDEX_BEER);

  final RuntimeException thrown =
      expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));
  final String message = thrown.getMessage();

  // Checked separately so a failure shows which part of the message is missing.
  assertTrue(message.contains(
      "can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));
  assertTrue(message.contains("SemanticCheckException"));
}
}

0 comments on commit a5d26cf

Please sign in to comment.