diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 39aa1b8553..c559276688 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -651,7 +651,16 @@ public FunctionExpression castDatetime(Expression value) { } public FunctionExpression match(Expression... args) { - return (FunctionExpression) repository - .compile(BuiltinFunctionName.MATCH.getName(), Arrays.asList(args.clone())); + return compile(BuiltinFunctionName.MATCH, args); + } + + public FunctionExpression match_phrase(Expression... args) { + return compile(BuiltinFunctionName.MATCH_PHRASE, args); } + + private FunctionExpression compile(BuiltinFunctionName bfn, Expression... args) { + return (FunctionExpression) repository.compile(bfn.getName(), Arrays.asList(args.clone())); + } + + } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index c52ff150cd..a36f289024 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -187,6 +187,8 @@ public enum BuiltinFunctionName { * Relevance Function. */ MATCH(FunctionName.of("match")), + MATCH_PHRASE(FunctionName.of("match_phrase")), + MATCHPHRASE(FunctionName.of("matchphrase")), /** * Legacy Relevance Function. diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index 9b7325bb59..4b9aefd8e5 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -7,11 +7,11 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; -import lombok.ToString; import lombok.experimental.UtilityClass; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; @@ -23,56 +23,49 @@ @UtilityClass public class OpenSearchFunctions { + + public static final int MATCH_MAX_NUM_PARAMETERS = 12; + public static final int MATCH_PHRASE_MAX_NUM_PARAMETERS = 3; + public static final int MIN_NUM_PARAMETERS = 2; + + /** + * Add functions specific to OpenSearch to repository. + */ public void register(BuiltinFunctionRepository repository) { repository.register(match()); + // Register MATCHPHRASE as MATCH_PHRASE as well for backwards + // compatibility. + repository.register(match_phrase(BuiltinFunctionName.MATCH_PHRASE)); + repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE)); } private static FunctionResolver match() { FunctionName funcName = BuiltinFunctionName.MATCH.getName(); + return getRelevanceFunctionResolver(funcName, MATCH_MAX_NUM_PARAMETERS); + } + + private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) { + FunctionName funcName = matchPhrase.getName(); + return getRelevanceFunctionResolver(funcName, MATCH_PHRASE_MAX_NUM_PARAMETERS); + } + + private static FunctionResolver getRelevanceFunctionResolver( + FunctionName funcName, int maxNumParameters) { return new FunctionResolver(funcName, - ImmutableMap.builder() - .put(new FunctionSignature(funcName, ImmutableList.of(STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList.of(STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList.of(STRING, STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, - STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, - STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, - STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, - STRING, STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .put(new FunctionSignature(funcName, ImmutableList - .of(STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, STRING, - STRING, STRING, STRING, STRING, STRING)), - args -> new OpenSearchFunction(funcName, args)) - .build()); + getRelevanceFunctionSignatureMap(funcName, maxNumParameters)); + } + + private static Map getRelevanceFunctionSignatureMap( + FunctionName funcName, int numOptionalParameters) { + FunctionBuilder buildFunction = args -> new OpenSearchFunction(funcName, args); + var signatureMapBuilder = ImmutableMap.builder(); + for (int numParameters = MIN_NUM_PARAMETERS; + numParameters <= MIN_NUM_PARAMETERS + numOptionalParameters; + numParameters++) { + List args = Collections.nCopies(numParameters, STRING); + signatureMapBuilder.put(new FunctionSignature(funcName, args), buildFunction); + } + return signatureMapBuilder.build(); } private static class OpenSearchFunction extends FunctionExpression { diff --git a/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java b/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java index c425be704b..02dbc40545 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java @@ -9,12 +9,15 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; +import java.util.List; import org.junit.jupiter.api.Test; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.ExpressionTestBase; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.NamedArgumentExpression; + + public class OpenSearchFunctionsTest extends ExpressionTestBase { private final NamedArgumentExpression field = new NamedArgumentExpression( "field", DSL.literal("message")); @@ -40,10 +43,14 @@ public class OpenSearchFunctionsTest extends ExpressionTestBase { "operator", DSL.literal("OR")); private final NamedArgumentExpression minimumShouldMatch = new NamedArgumentExpression( "minimum_should_match", DSL.literal("1")); - private final NamedArgumentExpression zeroTermsQuery = new NamedArgumentExpression( - "zero_terms_query", DSL.literal("ALL")); + private final NamedArgumentExpression zeroTermsQueryAll = new NamedArgumentExpression( + "zero_terms_query", DSL.literal("ALL")); + private final NamedArgumentExpression zeroTermsQueryNone = new NamedArgumentExpression( + "zero_terms_query", DSL.literal("None")); private final NamedArgumentExpression boost = new NamedArgumentExpression( "boost", DSL.literal("2.0")); + private final NamedArgumentExpression slop = new NamedArgumentExpression( + "slop", DSL.literal("3")); @Test void match() { @@ -98,16 +105,34 @@ void match() { expr = dsl.match( field, query, analyzer, autoGenerateSynonymsPhrase, fuzziness, maxExpansions, prefixLength, - fuzzyTranspositions, fuzzyRewrite, lenient, operator, minimumShouldMatch, zeroTermsQuery); + fuzzyTranspositions, fuzzyRewrite, lenient, operator, minimumShouldMatch, + zeroTermsQueryAll); assertEquals(BOOLEAN, expr.type()); expr = dsl.match( field, query, analyzer, autoGenerateSynonymsPhrase, fuzziness, maxExpansions, prefixLength, - fuzzyTranspositions, fuzzyRewrite, lenient, operator, minimumShouldMatch, zeroTermsQuery, + fuzzyTranspositions, fuzzyRewrite, lenient, operator, minimumShouldMatch, zeroTermsQueryAll, boost); assertEquals(BOOLEAN, expr.type()); } + @Test + void match_phrase() { + for (FunctionExpression expr : match_phrase_dsl_expressions()) { + assertEquals(BOOLEAN, expr.type()); + } + } + + + List match_phrase_dsl_expressions() { + return List.of( + dsl.match_phrase(field, query), + dsl.match_phrase(field, query, analyzer), + dsl.match_phrase(field, query, analyzer, zeroTermsQueryAll), + dsl.match_phrase(field, query, analyzer, zeroTermsQueryNone, slop) + ); + } + @Test void match_in_memory() { FunctionExpression expr = dsl.match(field, query); diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 188c326f6d..cde48ae25d 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -2195,3 +2195,43 @@ Another example to show how to set custom values for the optional parameters:: | Bond | +------------+ +MATCH_PHRASE +----- + +Description +>>>>>>>>>>> + +``match_phrase(field_expression, query_expression[, option=]*)`` + +The match_phrase function maps to the match_phrase query used in search engine, to return the documents that match a provided text with a given field. Available parameters include: + +- analyzer +- slop +- zero_terms_query + +For backward compatibility, matchphrase is also supported and mapped to match_phrase query as well. + +Example with only ``field`` and ``query`` expressions, and all other parameters are set default values:: + + os> SELECT author, title FROM books WHERE match_phrase(author, 'Alexander Milne'); + fetched rows / total rows = 2/2 + +----------------------+--------------------------+ + | author | title | + |----------------------+--------------------------| + | Alan Alexander Milne | The House at Pooh Corner | + | Alan Alexander Milne | Winnie-the-Pooh | + +----------------------+--------------------------+ + + + +Another example to show how to set custom values for the optional parameters:: + + os> SELECT author, title FROM books WHERE match_phrase(author, 'Alan Milne', slop = 2); + fetched rows / total rows = 2/2 + +----------------------+--------------------------+ + | author | title | + |----------------------+--------------------------| + | Alan Alexander Milne | The House at Pooh Corner | + | Alan Alexander Milne | Winnie-the-Pooh | + +----------------------+--------------------------+ + diff --git a/docs/user/ppl/functions/relevance.rst b/docs/user/ppl/functions/relevance.rst index 0b00f382ec..204e942e70 100644 --- a/docs/user/ppl/functions/relevance.rst +++ b/docs/user/ppl/functions/relevance.rst @@ -56,6 +56,46 @@ Another example to show how to set custom values for the optional parameters:: | Bond | +------------+ +MATCH_PHRASE +----- + +Description +>>>>>>>>>>> + +``match_phrase(field_expression, query_expression[, option=]*)`` + +The match_phrase function maps to the match_phrase query used in search engine, to return the documents that match a provided text with a given field. Available parameters include: + +- analyzer +- slop +- zero_terms_query + +For backward compatibility, matchphrase is also supported and mapped to match_phrase query as well. + +Example with only ``field`` and ``query`` expressions, and all other parameters are set default values:: + + os> source=books | where match_phrase(author, 'Alexander Milne') | fields author, title + fetched rows / total rows = 2/2 + +----------------------+--------------------------+ + | author | title | + |----------------------+--------------------------| + | Alan Alexander Milne | The House at Pooh Corner | + | Alan Alexander Milne | Winnie-the-Pooh | + +----------------------+--------------------------+ + + + +Another example to show how to set custom values for the optional parameters:: + + os> source=books | where match_phrase(author, 'Alan Milne', slop = 2) | fields author, title + fetched rows / total rows = 2/2 + +----------------------+--------------------------+ + | author | title | + |----------------------+--------------------------| + | Alan Alexander Milne | The House at Pooh Corner | + | Alan Alexander Milne | Winnie-the-Pooh | + +----------------------+--------------------------+ + Limitations >>>>>>>>>>> diff --git a/doctest/test_data/books.json b/doctest/test_data/books.json new file mode 100644 index 0000000000..f0bb81e9e3 --- /dev/null +++ b/doctest/test_data/books.json @@ -0,0 +1,2 @@ +{"id": 1, "author": "Alan Alexander Milne", "title": "The House at Pooh Corner"} +{"id": 2, "author": "Alan Alexander Milne", "title": "Winnie-the-Pooh"} diff --git a/doctest/test_docs.py b/doctest/test_docs.py index c6c07d1d0d..76af1fa9e7 100644 --- a/doctest/test_docs.py +++ b/doctest/test_docs.py @@ -24,6 +24,7 @@ PEOPLE = "people" ACCOUNT2 = "account2" NYC_TAXI = "nyc_taxi" +BOOKS = "books" class DocTestConnection(OpenSearchConnection): @@ -88,6 +89,7 @@ def set_up_test_indices(test): load_file("people.json", index_name=PEOPLE) load_file("account2.json", index_name=ACCOUNT2) load_file("nyc_taxi.json", index_name=NYC_TAXI) + load_file("books.json", index_name=BOOKS) def load_file(filename, index_name): @@ -116,7 +118,7 @@ def set_up(test): def tear_down(test): # drop leftover tables after each test - test_data_client.indices.delete(index=[ACCOUNTS, EMPLOYEES, PEOPLE, ACCOUNT2, NYC_TAXI], ignore_unavailable=True) + test_data_client.indices.delete(index=[ACCOUNTS, EMPLOYEES, PEOPLE, ACCOUNT2, NYC_TAXI, BOOKS], ignore_unavailable=True) docsuite = partial(doctest.DocFileSuite, diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/WhereCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/WhereCommandIT.java index 5415b6e286..e5f8f35448 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/WhereCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/WhereCommandIT.java @@ -9,6 +9,7 @@ import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_ACCOUNT; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK_WITH_NULL_VALUES; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_PHRASE; import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; @@ -23,6 +24,8 @@ public void init() throws IOException { loadIndex(Index.ACCOUNT); loadIndex(Index.BANK_WITH_NULL_VALUES); loadIndex(Index.BANK); + loadIndex(Index.GAME_OF_THRONES); + loadIndex(Index.PHRASE); } @Test @@ -110,4 +113,22 @@ public void testRelevanceFunction() throws IOException { TEST_INDEX_BANK)); verifyDataRows(result, rows("Hattie")); } + + @Test + public void testMatchPhraseFunction() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | where match_phrase(phrase, 'quick fox') | fields phrase", TEST_INDEX_PHRASE)); + verifyDataRows(result, rows("quick fox"), rows("quick fox here")); + } + + @Test + public void testMathPhraseWithSlop() throws IOException { + JSONObject result = + executeQuery( + String.format( + "source=%s | where match_phrase(phrase, 'brown fox', slop = 2) | fields phrase", TEST_INDEX_PHRASE)); + verifyDataRows(result, rows("brown fox"), rows("fox brown")); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java index 94179c3369..ec54e8854e 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java @@ -29,6 +29,7 @@ import org.opensearch.sql.opensearch.storage.script.filter.lucene.RangeQuery.Comparison; import org.opensearch.sql.opensearch.storage.script.filter.lucene.TermQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.WildcardQuery; +import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchPhraseQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchQuery; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -52,6 +53,8 @@ public class FilterQueryBuilder extends ExpressionNodeVisitor { + /** + * Default constructor for MatchPhraseQuery configures how RelevanceQuery.build() handles + * named arguments. + */ + public MatchPhraseQuery() { + super(ImmutableMap.>builder() + .put("analyzer", (b, v) -> b.analyzer(v.stringValue())) + .put("slop", (b, v) -> b.slop(Integer.parseInt(v.stringValue()))) + .put("zero_terms_query", (b, v) -> b.zeroTermsQuery( + org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue()))) + .build()); + } + + @Override + protected MatchPhraseQueryBuilder createQueryBuilder(String field, String query) { + return QueryBuilders.matchPhraseQuery(field, query); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java index a1ed67cd9b..c69b43cbcb 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java @@ -6,82 +6,41 @@ package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; import com.google.common.collect.ImmutableMap; -import java.util.Iterator; -import java.util.function.BiFunction; +import java.util.Map; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.Operator; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.exception.SemanticCheckException; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.NamedArgumentExpression; -import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery; -public class MatchQuery extends LuceneQuery { - private final BiFunction analyzer = - (b, v) -> b.analyzer(v.stringValue()); - private final BiFunction synonymsPhrase = - (b, v) -> b.autoGenerateSynonymsPhraseQuery(Boolean.parseBoolean(v.stringValue())); - private final BiFunction fuzziness = - (b, v) -> b.fuzziness(v.stringValue()); - private final BiFunction maxExpansions = - (b, v) -> b.maxExpansions(Integer.parseInt(v.stringValue())); - private final BiFunction prefixLength = - (b, v) -> b.prefixLength(Integer.parseInt(v.stringValue())); - private final BiFunction fuzzyTranspositions = - (b, v) -> b.fuzzyTranspositions(Boolean.parseBoolean(v.stringValue())); - private final BiFunction fuzzyRewrite = - (b, v) -> b.fuzzyRewrite(v.stringValue()); - private final BiFunction lenient = - (b, v) -> b.lenient(Boolean.parseBoolean(v.stringValue())); - private final BiFunction operator = - (b, v) -> b.operator(Operator.fromString(v.stringValue())); - private final BiFunction minimumShouldMatch = - (b, v) -> b.minimumShouldMatch(v.stringValue()); - private final BiFunction zeroTermsQuery = - (b, v) -> b.zeroTermsQuery( - org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue())); - private final BiFunction boost = - (b, v) -> b.boost(Float.parseFloat(v.stringValue())); - - ImmutableMap argAction = ImmutableMap.builder() - .put("analyzer", analyzer) - .put("auto_generate_synonyms_phrase_query", synonymsPhrase) - .put("fuzziness", fuzziness) - .put("max_expansions", maxExpansions) - .put("prefix_length", prefixLength) - .put("fuzzy_transpositions", fuzzyTranspositions) - .put("fuzzy_rewrite", fuzzyRewrite) - .put("lenient", lenient) - .put("operator", operator) - .put("minimum_should_match", minimumShouldMatch) - .put("zero_terms_query", zeroTermsQuery) - .put("boost", boost) - .build(); +/** + * Initializes MatchQueryBuilder from a FunctionExpression. + */ +public class MatchQuery extends RelevanceQuery { + /** + * Default constructor for MatchQuery configures how RelevanceQuery.build() handles + * named arguments. + */ + public MatchQuery() { + super(ImmutableMap.>builder() + .put("analyzer", (b, v) -> b.analyzer(v.stringValue())) + .put("auto_generate_synonyms_phrase_query", + (b, v) -> b.autoGenerateSynonymsPhraseQuery(Boolean.parseBoolean(v.stringValue()))) + .put("fuzziness", (b, v) -> b.fuzziness(v.stringValue())) + .put("max_expansions", (b, v) -> b.maxExpansions(Integer.parseInt(v.stringValue()))) + .put("prefix_length", (b, v) -> b.prefixLength(Integer.parseInt(v.stringValue()))) + .put("fuzzy_transpositions", + (b, v) -> b.fuzzyTranspositions(Boolean.parseBoolean(v.stringValue()))) + .put("fuzzy_rewrite", (b, v) -> b.fuzzyRewrite(v.stringValue())) + .put("lenient", (b, v) -> b.lenient(Boolean.parseBoolean(v.stringValue()))) + .put("operator", (b, v) -> b.operator(Operator.fromString(v.stringValue()))) + .put("minimum_should_match", (b, v) -> b.minimumShouldMatch(v.stringValue())) + .put("zero_terms_query", (b, v) -> b.zeroTermsQuery( + org.opensearch.index.search.MatchQuery.ZeroTermsQuery.valueOf(v.stringValue()))) + .put("boost", (b, v) -> b.boost(Float.parseFloat(v.stringValue()))) + .build()); + } @Override - public QueryBuilder build(FunctionExpression func) { - if (func.getArguments().size() < 2) { - throw new SemanticCheckException("match must have at least two arguments"); - } - Iterator iterator = func.getArguments().iterator(); - NamedArgumentExpression field = (NamedArgumentExpression) iterator.next(); - NamedArgumentExpression query = (NamedArgumentExpression) iterator.next(); - MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery( - field.getValue().valueOf(null).stringValue(), - query.getValue().valueOf(null).stringValue()); - while (iterator.hasNext()) { - NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); - if (!argAction.containsKey(arg.getArgName())) { - throw new SemanticCheckException(String - .format("Parameter %s is invalid for match function.", arg.getArgName())); - } - ((BiFunction) argAction - .get(arg.getArgName())) - .apply(queryBuilder, arg.getValue().valueOf(null)); - } - return queryBuilder; + protected MatchQueryBuilder createQueryBuilder(String field, String query) { + return QueryBuilders.matchQuery(field, query); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java new file mode 100644 index 0000000000..fb0852c18b --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiFunction; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery; + +/** + * Base class for query abstraction that builds a relevance query from function expression. + */ +public abstract class RelevanceQuery extends LuceneQuery { + protected Map> queryBuildActions; + + protected RelevanceQuery(Map> actionMap) { + queryBuildActions = actionMap; + } + + @Override + public QueryBuilder build(FunctionExpression func) { + List arguments = func.getArguments(); + if (arguments.size() < 2) { + String queryName = createQueryBuilder("dummy_field", "").getWriteableName(); + throw new SyntaxCheckException( + String.format("%s requires at least two parameters", queryName)); + } + NamedArgumentExpression field = (NamedArgumentExpression) arguments.get(0); + NamedArgumentExpression query = (NamedArgumentExpression) arguments.get(1); + T queryBuilder = createQueryBuilder( + field.getValue().valueOf(null).stringValue(), + query.getValue().valueOf(null).stringValue()); + + Iterator iterator = arguments.listIterator(2); + while (iterator.hasNext()) { + NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); + if (!queryBuildActions.containsKey(arg.getArgName())) { + throw new SemanticCheckException( + String.format("Parameter %s is invalid for %s function.", + arg.getArgName(), queryBuilder.getWriteableName())); + } + (Objects.requireNonNull( + queryBuildActions + .get(arg.getArgName()))) + .apply(queryBuilder, arg.getValue().valueOf(null)); + } + return queryBuilder; + } + + protected abstract T createQueryBuilder(String field, String query); + + /** + * Convenience interface for a function that updates a QueryBuilder + * based on ExprValue. + * @param Concrete query builder + */ + public interface QueryBuilderStep extends + BiFunction { + + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index 15b4ad2bad..9bc6ed5076 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.opensearch.storage.script.filter; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; @@ -45,6 +46,7 @@ import org.opensearch.sql.data.model.ExprTimeValue; import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; @@ -361,8 +363,117 @@ void match_invalid_parameter() { dsl.namedArgument("field", literal("message")), dsl.namedArgument("query", literal("search query")), dsl.namedArgument("invalid_parameter", literal("invalid_value"))); - assertThrows(SemanticCheckException.class, () -> buildQuery(expr), - "Parameter invalid_parameter is invalid for match function."); + var msg = assertThrows(SemanticCheckException.class, () -> buildQuery(expr)).getMessage(); + assertEquals("Parameter invalid_parameter is invalid for match function.", msg); + } + + @Test + void should_build_match_phrase_query_with_default_parameters() { + assertJsonEquals( + "{\n" + + " \"match_phrase\" : {\n" + + " \"message\" : {\n" + + " \"query\" : \"search query\",\n" + + " \"slop\" : 0,\n" + + " \"zero_terms_query\" : \"NONE\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + dsl.match_phrase( + dsl.namedArgument("field", literal("message")), + dsl.namedArgument("query", literal("search query"))))); + } + + @Test + void should_build_match_phrase_query_with_custom_parameters() { + assertJsonEquals( + "{\n" + + " \"match_phrase\" : {\n" + + " \"message\" : {\n" + + " \"query\" : \"search query\",\n" + + " \"analyzer\" : \"keyword\"," + + " \"slop\" : 2,\n" + + " \"zero_terms_query\" : \"ALL\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery( + dsl.match_phrase( + dsl.namedArgument("field", literal("message")), + dsl.namedArgument("query", literal("search query")), + dsl.namedArgument("analyzer", literal("keyword")), + dsl.namedArgument("slop", literal("2")), + dsl.namedArgument("zero_terms_query", literal("ALL"))))); + } + + @Test + void match_phrase_invalid_parameter() { + FunctionExpression expr = dsl.match_phrase( + dsl.namedArgument("field", literal("message")), + dsl.namedArgument("query", literal("search query")), + dsl.namedArgument("invalid_parameter", literal("invalid_value"))); + var msg = assertThrows(SemanticCheckException.class, () -> buildQuery(expr)).getMessage(); + assertEquals("Parameter invalid_parameter is invalid for match_phrase function.", msg); + } + + @Test + void match_phrase_invalid_value_slop() { + FunctionExpression expr = dsl.match_phrase( + dsl.namedArgument("field", literal("message")), + dsl.namedArgument("query", literal("search query")), + dsl.namedArgument("slop", literal("1.5"))); + var msg = assertThrows(NumberFormatException.class, () -> buildQuery(expr)).getMessage(); + assertEquals("For input string: \"1.5\"", msg); + } + + @Test + void match_phrase_invalid_value_ztq() { + FunctionExpression expr = dsl.match_phrase( + dsl.namedArgument("field", literal("message")), + dsl.namedArgument("query", literal("search query")), + dsl.namedArgument("zero_terms_query", literal("meow"))); + var msg = assertThrows(IllegalArgumentException.class, () -> buildQuery(expr)).getMessage(); + assertEquals("No enum constant org.opensearch.index.search.MatchQuery.ZeroTermsQuery.meow", + msg); + } + + @Test + void match_phrase_missing_field() { + var msg = assertThrows(ExpressionEvaluationException.class, () -> + dsl.match_phrase( + dsl.namedArgument("query", literal("search query")))).getMessage(); + assertEquals("match_phrase function expected {[STRING,STRING],[STRING,STRING,STRING]," + + "[STRING,STRING,STRING,STRING],[STRING,STRING,STRING,STRING,STRING]}, but get [STRING]", + msg); + } + + @Test + void match_phrase_missing_query() { + var msg = assertThrows(ExpressionEvaluationException.class, () -> + dsl.match_phrase( + dsl.namedArgument("field", literal("message")))).getMessage(); + assertEquals("match_phrase function expected {[STRING,STRING],[STRING,STRING,STRING]," + + "[STRING,STRING,STRING,STRING],[STRING,STRING,STRING,STRING,STRING]}, but get [STRING]", + msg); + } + + @Test + void match_phrase_too_many_args() { + var msg = assertThrows(ExpressionEvaluationException.class, () -> + dsl.match_phrase( + dsl.namedArgument("one", literal("1")), + dsl.namedArgument("two", literal("2")), + dsl.namedArgument("three", literal("3")), + dsl.namedArgument("four", literal("4")), + dsl.namedArgument("fix", literal("5")), + dsl.namedArgument("six", literal("6")) + )).getMessage(); + assertEquals("match_phrase function expected {[STRING,STRING],[STRING,STRING,STRING]," + + "[STRING,STRING,STRING,STRING],[STRING,STRING,STRING,STRING,STRING]}, but get " + + "[STRING,STRING,STRING,STRING,STRING,STRING]", msg); } @Test diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java new file mode 100644 index 0000000000..fef3d64f95 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java @@ -0,0 +1,117 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene; + + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchPhraseQuery; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class MatchPhraseQueryTest { + + private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + private final MatchPhraseQuery matchPhraseQuery = new MatchPhraseQuery(); + private final FunctionName matchPhrase = FunctionName.of("match_phrase"); + + private NamedArgumentExpression namedArgument(String name, String value) { + return dsl.namedArgument(name, DSL.literal(value)); + } + + @Test + public void test_SyntaxCheckException_when_no_arguments() { + List arguments = List.of(); + assertThrows(SyntaxCheckException.class, + () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); + } + + @Test + public void test_SyntaxCheckException_when_one_argument() { + List arguments = List.of(namedArgument("field", "test")); + assertThrows(SyntaxCheckException.class, + () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); + } + + @Test + public void test_SyntaxCheckException_when_invalid_parameter() { + List arguments = List.of( + namedArgument("field", "test"), + namedArgument("query", "test2"), + namedArgument("unsupported", "3")); + Assertions.assertThrows(SemanticCheckException.class, + () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); + } + + @Test + public void test_analyzer_parameter() { + List arguments = List.of( + namedArgument("field", "t1"), + namedArgument("query", "t2"), + namedArgument("analyzer", "standard") + ); + Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); + } + + @Test + public void build_succeeds_with_two_arguments() { + List arguments = List.of( + namedArgument("field", "test"), + namedArgument("query", "test2")); + Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); + } + + @Test + public void test_slop_parameter() { + List arguments = List.of( + namedArgument("field", "t1"), + namedArgument("query", "t2"), + namedArgument("slop", "2") + ); + Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); + } + + @Test + public void test_zero_terms_query_parameter() { + List arguments = List.of( + namedArgument("field", "t1"), + namedArgument("query", "t2"), + namedArgument("zero_terms_query", "ALL") + ); + Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); + } + + private class MatchPhraseExpression extends FunctionExpression { + public MatchPhraseExpression(List arguments) { + super(MatchPhraseQueryTest.this.matchPhrase, arguments); + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + return null; + } + + @Override + public ExprType type() { + return null; + } + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchQueryTest.java index 7e3e5e0862..99cf132a3e 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchQueryTest.java @@ -15,6 +15,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.exception.SemanticCheckException; @@ -110,16 +111,16 @@ public void test_valid_parameters(List validArgs) { } @Test - public void test_SemanticCheckException_when_no_arguments() { + public void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> matchQuery.build(new MatchExpression(arguments))); } @Test - public void test_SemanticCheckException_when_one_argument() { + public void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("field", "field_value")); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> matchQuery.build(new MatchExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java new file mode 100644 index 0000000000..1186031f5f --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java @@ -0,0 +1,122 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.common.collect.ImmutableMap; +import java.util.List; +import java.util.stream.Stream; +import org.apache.commons.lang3.NotImplementedException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.NamedArgumentExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.FunctionName; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class RelevanceQueryBuildTest { + + public static final NamedArgumentExpression FIELD_ARG = namedArgument("field", "field_A"); + public static final NamedArgumentExpression QUERY_ARG = namedArgument("query", "find me"); + private RelevanceQuery query; + private QueryBuilder queryBuilder; + + @BeforeEach + public void setUp() { + query = mock(RelevanceQuery.class, withSettings().useConstructor( + ImmutableMap.>builder() + .put("boost", (k, v) -> k.boost(Float.parseFloat(v.stringValue()))).build()) + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + queryBuilder = mock(QueryBuilder.class); + when(query.createQueryBuilder(any(), any())).thenReturn(queryBuilder); + when(queryBuilder.queryName()).thenReturn("mocked_query"); + when(queryBuilder.getWriteableName()).thenReturn("mock_query"); + } + + @Test + void first_arg_field_second_arg_query_test() { + query.build(createCall(List.of(FIELD_ARG, QUERY_ARG))); + verify(query, times(1)).createQueryBuilder("field_A", "find me"); + } + + @Test + void throws_SemanticCheckException_when_wrong_argument_name() { + FunctionExpression expr = + createCall(List.of(FIELD_ARG, QUERY_ARG, namedArgument("wrongArg", "value"))); + + SemanticCheckException exception = + assertThrows(SemanticCheckException.class, () -> query.build(expr)); + assertEquals("Parameter wrongArg is invalid for mock_query function.", exception.getMessage()); + } + + @Test + void calls_action_when_correct_argument_name() { + FunctionExpression expr = + createCall(List.of(FIELD_ARG, QUERY_ARG, namedArgument("boost", "2.3"))); + query.build(expr); + + verify(queryBuilder, times(1)).boost(2.3f); + } + + @ParameterizedTest + @MethodSource("insufficientArguments") + public void throws_SyntaxCheckException_when_no_required_arguments(List arguments) { + SyntaxCheckException exception = assertThrows(SyntaxCheckException.class, + () -> query.build(createCall(arguments))); + assertEquals("mock_query requires at least two parameters", exception.getMessage()); + } + + public static Stream> insufficientArguments() { + return Stream.of(List.of(), + List.of(namedArgument("field", "field_A"))); + } + + private static NamedArgumentExpression namedArgument(String field, String fieldValue) { + return new NamedArgumentExpression(field, createLiteral(fieldValue)); + } + + @Test + private static Expression createLiteral(String value) { + return new LiteralExpression(new ExprStringValue(value)); + } + + private static FunctionExpression createCall(List arguments) { + return new FunctionExpression(new FunctionName("mock_function"), arguments) { + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new NotImplementedException("FunctionExpression.valueOf"); + } + + @Override + public ExprType type() { + throw new NotImplementedException("FunctionExpression.type"); + } + }; + } +} diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index aee51d0a10..be4ad272d1 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -264,6 +264,7 @@ IF: 'IF'; // RELEVANCE FUNCTIONS AND PARAMETERS MATCH: 'MATCH'; +MATCH_PHRASE: 'MATCH_PHRASE'; ANALYZER: 'ANALYZER'; FUZZINESS: 'FUZZINESS'; AUTO_GENERATE_SYNONYMS_PHRASE_QUERY:'AUTO_GENERATE_SYNONYMS_PHRASE_QUERY'; @@ -276,7 +277,7 @@ OPERATOR: 'OPERATOR'; MINIMUM_SHOULD_MATCH: 'MINIMUM_SHOULD_MATCH'; ZERO_TERMS_QUERY: 'ZERO_TERMS_QUERY'; BOOST: 'BOOST'; - +SLOP: 'SLOP'; // SPAN KEYWORDS SPAN: 'SPAN'; MS: 'MS'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 100547ae7a..11a42d8d0b 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -171,7 +171,7 @@ statsFunctionName ; percentileAggFunction - : PERCENTILE '<' value=integerLiteral '>' LT_PRTHS aggField=fieldExpression RT_PRTHS + : PERCENTILE LESS value=integerLiteral GREATER LT_PRTHS aggField=fieldExpression RT_PRTHS ; /** expressions */ @@ -305,7 +305,7 @@ relevanceArg relevanceArgName : ANALYZER | FUZZINESS | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY | MAX_EXPANSIONS | PREFIX_LENGTH | FUZZY_TRANSPOSITIONS | FUZZY_REWRITE | LENIENT | OPERATOR | MINIMUM_SHOULD_MATCH | ZERO_TERMS_QUERY - | BOOST + | BOOST | SLOP ; relevanceArgValue @@ -351,6 +351,7 @@ binaryOperator relevanceFunctionName : MATCH + | MATCH_PHRASE ; /** literals and values*/ diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserMatchPhraseSamplesTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserMatchPhraseSamplesTest.java new file mode 100644 index 0000000000..a4fbee44e3 --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserMatchPhraseSamplesTest.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.antlr; + +import static org.junit.Assert.assertNotEquals; + +import java.util.List; +import org.antlr.v4.runtime.tree.ParseTree; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + + +@RunWith(Parameterized.class) +public class PPLSyntaxParserMatchPhraseSamplesTest { + + + /** Returns sample queries that the PPLSyntaxParser is expected to parse successfully. + * @return an Iterable of sample queries. + */ + @Parameterized.Parameters(name = "{0}") + public static Iterable sampleQueries() { + return List.of( + "source=t a= 1 | where match_phrase(a, 'hello world')", + "source=t a = 1 | where match_phrase(a, 'hello world', slop = 3)", + "source=t a = 1 | where match_phrase(a, 'hello world', analyzer = 'standard'," + + "zero_terms_query = 'none', slop = 3)", + "source=t a = 1 | where match_phrase(a, 'hello world', zero_terms_query = all)"); + } + + private final String query; + + public PPLSyntaxParserMatchPhraseSamplesTest(String query) { + this.query = query; + } + + @Test + public void test() { + ParseTree tree = new PPLSyntaxParser().analyzeSyntax(query); + assertNotEquals(null, tree); + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java index 3d5f8d453c..89f608ebe5 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/antlr/PPLSyntaxParserTest.java @@ -80,3 +80,4 @@ public void testTopCommandWithoutNAndGroupByShouldPass() { assertNotEquals(null, tree); } } + diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index ed2d7e8160..902a8ed190 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -310,18 +310,33 @@ STRCMP: 'STRCMP'; ADDDATE: 'ADDDATE'; // RELEVANCE FUNCTIONS AND PARAMETERS +ALLOW_LEADING_WILDCARD: 'ALLOW_LEADING_WILDCARD'; +ANALYZE_WILDCARD: 'ANALYZE_WILDCARD'; ANALYZER: 'ANALYZER'; -FUZZINESS: 'FUZZINESS'; AUTO_GENERATE_SYNONYMS_PHRASE_QUERY:'AUTO_GENERATE_SYNONYMS_PHRASE_QUERY'; -MAX_EXPANSIONS: 'MAX_EXPANSIONS'; -PREFIX_LENGTH: 'PREFIX_LENGTH'; +BOOST: 'BOOST'; +CUTOFF_FREQUENCY: 'CUTOFF_FREQUENCY'; +ENABLE_POSITION_INCREMENTS: 'ENABLE_POSITION_INCREMENTS'; +FIELDS: 'FIELDS'; +FLAGS: 'FLAGS'; +FUZZINESS: 'FUZZINESS'; FUZZY_TRANSPOSITIONS: 'FUZZY_TRANSPOSITIONS'; FUZZY_REWRITE: 'FUZZY_REWRITE'; LENIENT: 'LENIENT'; -OPERATOR: 'OPERATOR'; +LOW_FREQ_OPERATOR: 'LOW_FREQ_OPERATOR'; +MAX_DETERMINIZED_STATES: 'MAX_DETERMINIZED_STATES'; +MAX_EXPANSIONS: 'MAX_EXPANSIONS'; MINIMUM_SHOULD_MATCH: 'MINIMUM_SHOULD_MATCH'; +OPERATOR: 'OPERATOR'; +PHRASE_SLOP: 'PHRASE_SLOP'; +PREFIX_LENGTH: 'PREFIX_LENGTH'; +QUOTE_FIELD_SUFFIX: 'QUOTE_FIELD_SUFFIX'; +REWRITE: 'REWRITE'; +SLOP: 'SLOP'; +TIE_BREAKER: 'TIE_BREAKER'; +TIME_ZONE: 'TIME_ZONE'; +TYPE: 'TYPE'; ZERO_TERMS_QUERY: 'ZERO_TERMS_QUERY'; -BOOST: 'BOOST'; // Operators diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 0b8f3c5250..85ebe2703b 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -383,7 +383,7 @@ flowControlFunctionName ; relevanceFunctionName - : MATCH + : MATCH | MATCH_PHRASE | MATCHPHRASE ; legacyRelevanceFunctionName @@ -403,9 +403,11 @@ relevanceArg ; relevanceArgName - : ANALYZER | FUZZINESS | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY | MAX_EXPANSIONS | PREFIX_LENGTH - | FUZZY_TRANSPOSITIONS | FUZZY_REWRITE | LENIENT | OPERATOR | MINIMUM_SHOULD_MATCH | ZERO_TERMS_QUERY - | BOOST + : ALLOW_LEADING_WILDCARD | ANALYZE_WILDCARD | ANALYZER | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY | BOOST + | CUTOFF_FREQUENCY | ENABLE_POSITION_INCREMENTS | FIELDS | FLAGS | FUZZINESS | FUZZY_TRANSPOSITIONS + | FUZZY_REWRITE | LENIENT | LOW_FREQ_OPERATOR | MAX_DETERMINIZED_STATES | MAX_EXPANSIONS | MINIMUM_SHOULD_MATCH + | OPERATOR | PHRASE_SLOP | PREFIX_LENGTH | QUOTE_FIELD_SUFFIX | REWRITE | SLOP | TIE_BREAKER | TIME_ZONE + | TYPE | ZERO_TERMS_QUERY ; relevanceArgValue diff --git a/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java b/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java index e7cb22e8a2..48107b1744 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java @@ -9,7 +9,17 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import com.google.common.collect.Streams; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Random; +import java.util.stream.Stream; +import org.apache.commons.lang3.RandomStringUtils; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.opensearch.sql.common.antlr.SyntaxCheckException; class SQLSyntaxParserTest { @@ -144,4 +154,130 @@ public void canNotParseShowStatementWithoutFilterClause() { assertThrows(SyntaxCheckException.class, () -> parser.parse("SHOW TABLES")); } + @Test + public void canParseRelevanceFunctions() { + assertNotNull(parser.parse("SELECT * FROM test WHERE match(column, \"this is a test\")")); + assertNotNull(parser.parse("SELECT * FROM test WHERE match(column, 'this is a test')")); + assertNotNull(parser.parse("SELECT * FROM test WHERE match(`column`, \"this is a test\")")); + assertNotNull(parser.parse("SELECT * FROM test WHERE match(`column`, 'this is a test')")); + assertNotNull(parser.parse("SELECT * FROM test WHERE match(column, 100500)")); + + assertNotNull( + parser.parse("SELECT * FROM test WHERE match_phrase(column, \"this is a test\")")); + assertNotNull(parser.parse("SELECT * FROM test WHERE match_phrase(column, 'this is a test')")); + assertNotNull( + parser.parse("SELECT * FROM test WHERE match_phrase(`column`, \"this is a test\")")); + assertNotNull( + parser.parse("SELECT * FROM test WHERE match_phrase(`column`, 'this is a test')")); + assertNotNull(parser.parse("SELECT * FROM test WHERE match_phrase(column, 100500)")); + } + + @ParameterizedTest + @MethodSource({"matchPhraseComplexQueries", + "matchPhraseGeneratedQueries", "generateMatchPhraseQueries"}) + public void canParseComplexMatchPhraseArgsTest(String query) { + assertNotNull(parser.parse(query)); + } + + private static Stream matchPhraseComplexQueries() { + return Stream.of( + "SELECT * FROM t WHERE match_phrase(c, 3)", + "SELECT * FROM t WHERE match_phrase(c, 3, fuzziness=AUTO)", + "SELECT * FROM t WHERE match_phrase(c, 3, zero_terms_query=\"all\")", + "SELECT * FROM t WHERE match_phrase(c, 3, lenient=true)", + "SELECT * FROM t WHERE match_phrase(c, 3, lenient='true')", + "SELECT * FROM t WHERE match_phrase(c, 3, operator=xor)", + "SELECT * FROM t WHERE match_phrase(c, 3, cutoff_frequency=0.04)", + "SELECT * FROM t WHERE match_phrase(c, 3, cutoff_frequency=0.04, analyzer = english, " + + "prefix_length=34, fuzziness='auto', minimum_should_match='2<-25% 9<-3')", + "SELECT * FROM t WHERE match_phrase(c, 3, minimum_should_match='2<-25% 9<-3')", + "SELECT * FROM t WHERE match_phrase(c, 3, operator='AUTO')" + ); + } + + private static Stream matchPhraseGeneratedQueries() { + var matchArgs = new HashMap(); + matchArgs.put("fuzziness", new String[]{ "AUTO", "AUTO:1,5", "1" }); + matchArgs.put("fuzzy_transpositions", new Boolean[]{ true, false }); + matchArgs.put("operator", new String[]{ "and", "or" }); + matchArgs.put("minimum_should_match", + new String[]{ "3", "-2", "75%", "-25%", "3<90%", "2<-25% 9<-3" }); + matchArgs.put("analyzer", new String[]{ "standard", "stop", "english" }); + matchArgs.put("zero_terms_query", new String[]{ "none", "all" }); + matchArgs.put("lenient", new Boolean[]{ true, false }); + // deprecated + matchArgs.put("cutoff_frequency", new Double[]{ .0, 0.001, 1., 42. }); + matchArgs.put("prefix_length", new Integer[]{ 0, 2, 5 }); + matchArgs.put("max_expansions", new Integer[]{ 0, 5, 20 }); + matchArgs.put("boost", new Double[]{ .5, 1., 2.3 }); + + return generateQueries("match", matchArgs); + } + + private static Stream generateMatchPhraseQueries() { + var matchPhraseArgs = new HashMap(); + matchPhraseArgs.put("analyzer", new String[]{ "standard", "stop", "english" }); + matchPhraseArgs.put("max_expansions", new Integer[]{ 0, 5, 20 }); + matchPhraseArgs.put("slop", new Integer[]{ 0, 1, 2 }); + + return generateQueries("match_phrase", matchPhraseArgs); + } + + private static Stream generateQueries(String function, + HashMap functionArgs) { + var rand = new Random(0); + + class QueryGenerator implements Iterator { + + private int currentQuery = 0; + + private String randomIdentifier() { + return RandomStringUtils.random(10, 0, 0,true, false, null, rand); + } + + @Override + public boolean hasNext() { + int numQueries = 100; + return currentQuery < numQueries; + } + + @Override + public String next() { + currentQuery += 1; + + StringBuilder query = new StringBuilder(); + query.append(String.format("SELECT * FROM test WHERE %s(%s, %s", function, + randomIdentifier(), + randomIdentifier())); + var args = new ArrayList(); + for (var pair : functionArgs.entrySet()) { + if (rand.nextBoolean()) { + var arg = new StringBuilder(); + arg.append(rand.nextBoolean() ? "," : ", "); + arg.append(rand.nextBoolean() ? pair.getKey().toLowerCase() + : pair.getKey().toUpperCase()); + arg.append(rand.nextBoolean() ? "=" : " = "); + if (pair.getValue() instanceof String[] || rand.nextBoolean()) { + var quoteSymbol = rand.nextBoolean() ? '\'' : '"'; + arg.append(quoteSymbol); + arg.append(pair.getValue()[rand.nextInt(pair.getValue().length)]); + arg.append(quoteSymbol); + } else { + arg.append(pair.getValue()[rand.nextInt(pair.getValue().length)]); + } + args.add(arg.toString()); + } + } + Collections.shuffle(args, rand); + for (var arg : args) { + query.append(arg); + } + query.append(rand.nextBoolean() ? ")" : ");"); + return query.toString(); + } + } + + var it = new QueryGenerator(); + return Streams.stream(it); + } }