Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate Field and Fields Parameters in Relevance Search Functions #1067

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

package org.opensearch.sql.expression.function;

import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;

import com.google.common.collect.ImmutableMap;
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
Expand Down Expand Up @@ -48,46 +44,46 @@ public void register(BuiltinFunctionRepository repository) {

private static FunctionResolver match_bool_prefix() {
FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName();
return new RelevanceFunctionResolver(name, STRING);
return new RelevanceFunctionResolver(name);
}

private static FunctionResolver match(BuiltinFunctionName match) {
FunctionName funcName = match.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver match_phrase_prefix() {
FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) {
FunctionName funcName = matchPhrase.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver multi_match(BuiltinFunctionName multiMatchName) {
return new RelevanceFunctionResolver(multiMatchName.getName(), STRUCT);
return new RelevanceFunctionResolver(multiMatchName.getName());
}

private static FunctionResolver simple_query_string() {
FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName();
return new RelevanceFunctionResolver(funcName, STRUCT);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver query() {
FunctionName funcName = BuiltinFunctionName.QUERY.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver query_string() {
FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName();
return new RelevanceFunctionResolver(funcName, STRUCT);
return new RelevanceFunctionResolver(funcName);
}

private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery) {
FunctionName funcName = wildcardQuery.getName();
return new RelevanceFunctionResolver(funcName, STRING);
return new RelevanceFunctionResolver(funcName);
}

public static class OpenSearchFunction extends FunctionExpression {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,13 @@ public class RelevanceFunctionResolver
@Getter
private final FunctionName functionName;

@Getter
private final ExprType declaredFirstParamType;

@Override
public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unresolvedSignature) {
if (!unresolvedSignature.getFunctionName().equals(functionName)) {
throw new SemanticCheckException(String.format("Expected '%s' but got '%s'",
functionName.getFunctionName(), unresolvedSignature.getFunctionName().getFunctionName()));
}
List<ExprType> paramTypes = unresolvedSignature.getParamTypeList();
ExprType providedFirstParamType = paramTypes.get(0);

// Check if the first parameter is of the specified type.
if (!declaredFirstParamType.equals(providedFirstParamType)) {
throw new SemanticCheckException(
getWrongParameterErrorMessage(0, providedFirstParamType, declaredFirstParamType));
}

// Check if all but the first parameter are of type STRING.
for (int i = 1; i < paramTypes.size(); i++) {
ExprType paramType = paramTypes.get(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,10 @@ public void named_non_parse_expression() {
void match_bool_prefix_expression() {
assertAnalyzeEqual(
DSL.match_bool_prefix(
DSL.namedArgument("field", DSL.literal("fieldA")),
DSL.namedArgument("field", DSL.literal("field_value1")),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("match_bool_prefix",
AstDSL.unresolvedArg("field", stringLiteral("fieldA")),
AstDSL.unresolvedArg("field", stringLiteral("field_value1")),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand Down Expand Up @@ -402,11 +402,11 @@ void multi_match_expression() {
DSL.multi_match(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("multi_match",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -416,12 +416,12 @@ void multi_match_expression_with_params() {
DSL.multi_match(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query")),
DSL.namedArgument("analyzer", DSL.literal("keyword"))),
AstDSL.function("multi_match",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query")),
AstDSL.unresolvedArg("analyzer", stringLiteral("keyword"))));
}
Expand All @@ -432,12 +432,12 @@ void multi_match_expression_two_fields() {
DSL.multi_match(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field1", ExprValueUtils.floatValue(1.F),
"field2", ExprValueUtils.floatValue(.3F)))))),
"field_value1", ExprValueUtils.floatValue(1.F),
"field_value2", ExprValueUtils.floatValue(.3F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("multi_match",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"field1", 1.F, "field2", .3F))),
"field_value1", 1.F, "field_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -447,11 +447,11 @@ void simple_query_string_expression() {
DSL.simple_query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("simple_query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -461,12 +461,12 @@ void simple_query_string_expression_with_params() {
DSL.simple_query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("sample query")),
DSL.namedArgument("analyzer", DSL.literal("keyword"))),
AstDSL.function("simple_query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query")),
AstDSL.unresolvedArg("analyzer", stringLiteral("keyword"))));
}
Expand All @@ -477,12 +477,12 @@ void simple_query_string_expression_two_fields() {
DSL.simple_query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field1", ExprValueUtils.floatValue(1.F),
"field2", ExprValueUtils.floatValue(.3F)))))),
"field_value1", ExprValueUtils.floatValue(1.F),
"field_value2", ExprValueUtils.floatValue(.3F)))))),
DSL.namedArgument("query", DSL.literal("sample query"))),
AstDSL.function("simple_query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"field1", 1.F, "field2", .3F))),
"field_value1", 1.F, "field_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("sample query"))));
}

Expand All @@ -501,11 +501,11 @@ void query_string_expression() {
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("query_value"))),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value"))));
}

Expand All @@ -515,12 +515,12 @@ void query_string_expression_with_params() {
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field", ExprValueUtils.floatValue(1.F)))))),
"field_value1", ExprValueUtils.floatValue(1.F)))))),
DSL.namedArgument("query", DSL.literal("query_value")),
DSL.namedArgument("escape", DSL.literal("false"))),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of(
"field", 1.F))),
"field_value1", 1.F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value")),
AstDSL.unresolvedArg("escape", stringLiteral("false"))));
}
Expand All @@ -531,12 +531,12 @@ void query_string_expression_two_fields() {
DSL.query_string(
DSL.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"field1", ExprValueUtils.floatValue(1.F),
"field2", ExprValueUtils.floatValue(.3F)))))),
"field_value1", ExprValueUtils.floatValue(1.F),
"field_value2", ExprValueUtils.floatValue(.3F)))))),
DSL.namedArgument("query", DSL.literal("query_value"))),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"field1", 1.F, "field2", .3F))),
"field_value1", 1.F, "field_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value"))));
}

Expand Down Expand Up @@ -572,7 +572,7 @@ void wildcard_query_expression_all_params() {
public void match_phrase_prefix_all_params() {
assertAnalyzeEqual(
DSL.match_phrase_prefix(
DSL.namedArgument("field", "test"),
DSL.namedArgument("field", "field_value1"),
DSL.namedArgument("query", "search query"),
DSL.namedArgument("slop", "3"),
DSL.namedArgument("boost", "1.5"),
Expand All @@ -581,7 +581,7 @@ public void match_phrase_prefix_all_params() {
DSL.namedArgument("zero_terms_query", "NONE")
),
AstDSL.function("match_phrase_prefix",
unresolvedArg("field", stringLiteral("test")),
unresolvedArg("field", stringLiteral("field_value1")),
unresolvedArg("query", stringLiteral("search query")),
unresolvedArg("slop", stringLiteral("3")),
unresolvedArg("boost", stringLiteral("1.5")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ public class TestConfig {
.put("struct_value", ExprCoreType.STRUCT)
.put("array_value", ExprCoreType.ARRAY)
.put("timestamp_value", ExprCoreType.TIMESTAMP)
.put("field_value1", ExprCoreType.STRING)
.put("field_value2", ExprCoreType.STRING)
.build();

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class RelevanceFunctionResolverTest {

@BeforeEach
void setUp() {
resolver = new RelevanceFunctionResolver(sampleFuncName, STRING);
resolver = new RelevanceFunctionResolver(sampleFuncName);
}

@Test
Expand All @@ -44,15 +44,6 @@ void resolve_invalid_name_test() {
exception.getMessage());
}

@Test
void resolve_invalid_first_param_type_test() {
var sig = new FunctionSignature(sampleFuncName, List.of(INTEGER));
Exception exception = assertThrows(SemanticCheckException.class,
() -> resolver.resolve(sig));
assertEquals("Expected type STRING instead of INTEGER for parameter #1",
exception.getMessage());
}

@Test
void resolve_invalid_third_param_type_test() {
var sig = new FunctionSignature(sampleFuncName, List.of(STRING, STRING, INTEGER, STRING));
Expand Down
37 changes: 37 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.json.JSONObject;
import org.junit.Test;
import org.opensearch.sql.legacy.SQLIntegTestCase;
import org.opensearch.sql.legacy.utils.StringUtils;

public class MatchIT extends SQLIntegTestCase {
@Override
Expand All @@ -36,6 +37,42 @@ public void match_in_having() throws IOException {
verifyDataRows(result, rows("Bates"));
}

@Test
public void missing_field_test() {
String query = StringUtils.format("SELECT * FROM %s WHERE match(invalid, 'Bates')", TEST_INDEX_ACCOUNT);
final RuntimeException exception =
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));

assertTrue(exception.getMessage()
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));

assertTrue(exception.getMessage().contains("SemanticCheckException"));
}

@Test
public void missing_quoted_field_test() {
String query = StringUtils.format("SELECT * FROM %s WHERE match('invalid', 'Bates')", TEST_INDEX_ACCOUNT);
final RuntimeException exception =
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));

assertTrue(exception.getMessage()
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));

assertTrue(exception.getMessage().contains("SemanticCheckException"));
}

@Test
public void missing_backtick_field_test() {
String query = StringUtils.format("SELECT * FROM %s WHERE match(`invalid`, 'Bates')", TEST_INDEX_ACCOUNT);
final RuntimeException exception =
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));

assertTrue(exception.getMessage()
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env"));

assertTrue(exception.getMessage().contains("SemanticCheckException"));
}

@Test
public void matchquery_in_where() throws IOException {
JSONObject result = executeJdbcRequest("SELECT firstname FROM " + TEST_INDEX_ACCOUNT + " WHERE matchquery(lastname, 'Bates')");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.NamedArgumentExpression;
import org.opensearch.sql.expression.ReferenceExpression;

/**
* Base class to represent builder class for relevance queries like match_query, match_bool_prefix,
Expand All @@ -36,7 +37,7 @@ protected T createQueryBuilder(List<NamedArgumentExpression> arguments) {
.orElseThrow(() -> new SemanticCheckException("'query' parameter is missing"));

return createBuilder(
field.getValue().valueOf().stringValue(),
((ReferenceExpression)field.getValue()).getAttr(),
query.getValue().valueOf().stringValue());
}

Expand Down
Loading