Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for Nested Function Use In WHERE Clause Predicate Expresion #1657

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.data.model.ExprMissingValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
Expand Down Expand Up @@ -220,7 +219,6 @@ public LogicalPlan visitLimit(Limit node, AnalysisContext context) {
public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
Expression condition = expressionAnalyzer.analyze(node.getCondition(), context);
verifySupportsCondition(condition);

ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);
Expand All @@ -229,7 +227,7 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
}

/**
* Ensure NESTED function is not used in WHERE, GROUP BY, and HAVING clauses.
* Ensure NESTED function is not used in GROUP BY, and HAVING clauses.
* Fallback to legacy engine. Can remove when support is added for NESTED function in WHERE,
* GROUP BY, ORDER BY, and HAVING clauses.
* @param condition : Filter condition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1041,29 +1041,6 @@ public void nested_group_by_clause_throws_syntax_exception() {
exception.getMessage());
}

/**
* Ensure Nested function falls back to legacy engine when used in WHERE clause.
* TODO Remove this test when support is added.
*/
@Test
public void nested_where_clause_throws_syntax_exception() {
forestmvey marked this conversation as resolved.
Show resolved Hide resolved
SyntaxCheckException exception = assertThrows(SyntaxCheckException.class,
() -> analyze(
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.equalTo(
AstDSL.function("nested", qualifiedName("message", "info")),
AstDSL.stringLiteral("str")
)
)
)
);
assertEquals("Falling back to legacy engine. Nested function is not supported in WHERE,"
+ " GROUP BY, and HAVING clauses.",
exception.getMessage());
}


/**
* SELECT name, AVG(age) FROM test GROUP BY name.
*/
Expand Down
11 changes: 11 additions & 0 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4456,6 +4456,17 @@ Example with ``field`` and ``path`` parameters::
+---------------------------------+


Example with ``field`` and ``path`` parameters in the SELECT and WHERE clause::

os> SELECT nested(message.info, message) FROM nested WHERE nested(message.info, message) = 'b';
fetched rows / total rows = 1/1
+---------------------------------+
| nested(message.info, message) |
|---------------------------------|
| b |
+---------------------------------+


System Functions
================

Expand Down
77 changes: 65 additions & 12 deletions integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,6 @@ public void nested_function_mixed_with_non_nested_type_test() {
rows("zz", "a"));
}

@Test
public void nested_function_with_where_clause() {
String query =
"SELECT nested(message.info) FROM " + TEST_INDEX_NESTED_TYPE + " WHERE nested(message.info) = 'a'";
JSONObject result = executeJdbcRequest(query);

assertEquals(2, result.getInt("total"));
verifyDataRows(result,
rows("a"),
rows("a"));
}

@Test
public void nested_function_with_order_by_clause() {
String query =
Expand Down Expand Up @@ -313,4 +301,69 @@ public void nested_missing_path_argument() {
"}"
));
}

@Test
public void test_nested_where_with_and_conditional() {
String query = "SELECT nested(message.info), nested(message.author) FROM " + TEST_INDEX_NESTED_TYPE
+ " WHERE nested(message, message.info = 'a' AND message.author = 'e')";
JSONObject result = executeJdbcRequest(query);
assertEquals(1, result.getInt("total"));
verifyDataRows(result, rows("a", "e"));
}

@Test
public void test_nested_in_select_and_where_as_predicate_expression() {
String query = "SELECT nested(message.info) FROM " + TEST_INDEX_NESTED_TYPE
+ " WHERE nested(message.info) = 'a'";
JSONObject result = executeJdbcRequest(query);
assertEquals(3, result.getInt("total"));
verifyDataRows(
result,
rows("a"),
rows("c"),
rows("a")
);
}

@Test
public void test_nested_in_where_as_predicate_expression() {
String query = "SELECT message.info FROM " + TEST_INDEX_NESTED_TYPE
+ " WHERE nested(message.info) = 'a'";
JSONObject result = executeJdbcRequest(query);
assertEquals(2, result.getInt("total"));
// Only first index of array is returned. Second index has 'a'
verifyDataRows(result, rows("a"), rows("c"));
}

@Test
public void test_nested_in_where_as_predicate_expression_with_like() {
String query = "SELECT message.info FROM " + TEST_INDEX_NESTED_TYPE
+ " WHERE nested(message.info) LIKE 'a'";
JSONObject result = executeJdbcRequest(query);
assertEquals(2, result.getInt("total"));
// Only first index of array is returned. Second index has 'a'
verifyDataRows(result, rows("a"), rows("c"));
}

@Test
public void test_nested_in_where_as_predicate_expression_with_multiple_conditions() {
String query = "SELECT message.info, comment.data, message.dayOfWeek FROM " + TEST_INDEX_NESTED_TYPE
+ " WHERE nested(message.info) = 'zz' OR nested(comment.data) = 'ab' AND nested(message.dayOfWeek) >= 4";
JSONObject result = executeJdbcRequest(query);
assertEquals(2, result.getInt("total"));
verifyDataRows(
result,
rows("c", "ab", 4),
rows("zz", "aa", 6)
);
}

@Test
public void test_nested_in_where_as_predicate_expression_with_relevance_query() {
String query = "SELECT comment.likes, someField FROM " + TEST_INDEX_NESTED_TYPE
+ " WHERE nested(comment.likes) = 10 AND match(someField, 'a')";
JSONObject result = executeJdbcRequest(query);
assertEquals(1, result.getInt("total"));
verifyDataRows(result, rows(10, "a"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@

import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.toList;
import static org.opensearch.index.query.QueryBuilders.boolQuery;
import static org.opensearch.index.query.QueryBuilders.matchAllQuery;
import static org.opensearch.index.query.QueryBuilders.nestedQuery;
import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME;
import static org.opensearch.search.sort.SortOrder.ASC;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
Expand All @@ -44,7 +47,6 @@
import org.opensearch.sql.opensearch.data.type.OpenSearchDataType;
import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory;
import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser;
import org.opensearch.sql.planner.logical.LogicalNested;

/**
* OpenSearch search request builder.
Expand Down Expand Up @@ -257,13 +259,33 @@ private boolean isSortByDocOnly() {
*/
public void pushDownNested(List<Map<String, ReferenceExpression>> nestedArgs) {
initBoolQueryFilter();
List<NestedQueryBuilder> nestedQueries = extractNestedQueries(query());
groupFieldNamesByPath(nestedArgs).forEach(
(path, fieldNames) -> buildInnerHit(
fieldNames, createEmptyNestedQuery(path)
)
(path, fieldNames) ->
buildInnerHit(fieldNames, findNestedQueryWithSamePath(nestedQueries, path))
);
}

/**
* InnerHit must be added to the NestedQueryBuilder. We need to extract
* the nested queries currently in the query if there is already a filter
* push down with nested query.
* @param query : current query.
* @return : grouped nested queries currently in query.
*/
private List<NestedQueryBuilder> extractNestedQueries(QueryBuilder query) {
List<NestedQueryBuilder> result = new ArrayList<>();
if (query instanceof NestedQueryBuilder) {
result.add((NestedQueryBuilder) query);
} else if (query instanceof BoolQueryBuilder) {
BoolQueryBuilder boolQ = (BoolQueryBuilder) query;
Stream.of(boolQ.filter(), boolQ.must(), boolQ.should())
.flatMap(Collection::stream)
.forEach(q -> result.addAll(extractNestedQueries(q)));
}
return result;
}

/**
* Initialize bool query for push down.
*/
Expand Down Expand Up @@ -307,13 +329,41 @@ private void buildInnerHit(List<String> paths, NestedQueryBuilder query) {
));
}

/**
* We need to group nested queries with same path for adding new fields with same path of
* inner hits. If we try to add additional inner hits with same path we get an OS error.
* @param nestedQueries Current list of nested queries in query.
* @param path path comparing with current nested queries.
* @return Query with same path or new empty nested query.
*/
private NestedQueryBuilder findNestedQueryWithSamePath(
List<NestedQueryBuilder> nestedQueries, String path
) {
return nestedQueries.stream()
.filter(query -> isSamePath(path, query))
.findAny()
.orElseGet(createEmptyNestedQuery(path));
}

/**
* Check if is nested query is of the same path value.
* @param path Value of path to compare with nested query.
* @param query nested query builder to compare with path.
* @return true if nested query has same path.
*/
private boolean isSamePath(String path, NestedQueryBuilder query) {
return nestedQuery(path, query.query(), query.scoreMode()).equals(query);
}

/**
* Create a nested query with match all filter to place inner hits.
*/
private NestedQueryBuilder createEmptyNestedQuery(String path) {
NestedQueryBuilder nestedQuery = nestedQuery(path, matchAllQuery(), ScoreMode.None);
((BoolQueryBuilder) query().filter().get(0)).must(nestedQuery);
return nestedQuery;
private Supplier<NestedQueryBuilder> createEmptyNestedQuery(String path) {
return () -> {
NestedQueryBuilder nestedQuery = nestedQuery(path, matchAllQuery(), ScoreMode.None);
((BoolQueryBuilder) query().filter().get(0)).must(nestedQuery);
return nestedQuery;
};
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,23 @@
import java.util.Map;
import java.util.function.BiFunction;
import lombok.RequiredArgsConstructor;
import org.apache.lucene.search.join.ScoreMode;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.ScriptQueryBuilder;
import org.opensearch.script.Script;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ExpressionNodeVisitor;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.opensearch.storage.script.filter.lucene.LikeQuery;
import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery;
import org.opensearch.sql.opensearch.storage.script.filter.lucene.NestedQuery;
import org.opensearch.sql.opensearch.storage.script.filter.lucene.RangeQuery;
import org.opensearch.sql.opensearch.storage.script.filter.lucene.RangeQuery.Comparison;
import org.opensearch.sql.opensearch.storage.script.filter.lucene.TermQuery;
Expand Down Expand Up @@ -75,6 +80,7 @@ public class FilterQueryBuilder extends ExpressionNodeVisitor<QueryBuilder, Obje
.put(BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName(), new MatchPhrasePrefixQuery())
.put(BuiltinFunctionName.WILDCARD_QUERY.getName(), new WildcardQuery())
.put(BuiltinFunctionName.WILDCARDQUERY.getName(), new WildcardQuery())
.put(BuiltinFunctionName.NESTED.getName(), new NestedQuery())
.build();

/**
Expand All @@ -96,10 +102,20 @@ public QueryBuilder visitFunction(FunctionExpression func, Object context) {
return buildBoolQuery(func, context, BoolQueryBuilder::should);
case "not":
return buildBoolQuery(func, context, BoolQueryBuilder::mustNot);
case "nested":
// TODO Fill in case when adding support for syntax - nested(path, condition)
throw new SyntaxCheckException(
"Invalid syntax used for nested function in WHERE clause: "
+ "nested(field | field, path) OPERATOR LITERAL"
);
default: {
LuceneQuery query = luceneQueries.get(name);
if (query != null && query.canSupport(func)) {
return query.build(func);
} else if (query != null && query.isNestedPredicate(func)) {
NestedQuery nestedQuery = (NestedQuery) luceneQueries.get(
((FunctionExpression)func.getArguments().get(0)).getFunctionName());
return nestedQuery.buildNested(func, query);
}
return buildScriptQuery(func);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,15 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.WildcardQueryBuilder;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.opensearch.data.type.OpenSearchTextType;
import org.opensearch.sql.opensearch.storage.script.StringUtils;

public class LikeQuery extends LuceneQuery {
@Override
public QueryBuilder build(FunctionExpression func) {
ReferenceExpression ref = (ReferenceExpression) func.getArguments().get(0);
String field = OpenSearchTextType.convertTextToKeyword(ref.getAttr(), ref.type());
Expression expr = func.getArguments().get(1);
ExprValue literalValue = expr.valueOf();
return createBuilder(field, literalValue.stringValue());
public QueryBuilder doBuild(String fieldName, ExprType fieldType, ExprValue literal) {
String field = OpenSearchTextType.convertTextToKeyword(fieldName, fieldType);
return createBuilder(field, literal.stringValue());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.opensearch.data.type.OpenSearchTextType;

/**
* Lucene query abstraction that builds Lucene query from function expression.
Expand All @@ -56,6 +55,19 @@ public boolean canSupport(FunctionExpression func) {
|| isMultiParameterQuery(func);
}

/**
* Check if predicate expression has nested function on left side of predicate expression.
* Validation for right side being a `LiteralExpression` is done in NestedQuery.
* @param func function.
* @return return true if function has supported nested function expression.
*/
public boolean isNestedPredicate(FunctionExpression func) {
return ((func.getArguments().get(0) instanceof FunctionExpression
&& ((FunctionExpression)func.getArguments().get(0))
.getFunctionName().getFunctionName().equalsIgnoreCase(BuiltinFunctionName.NESTED.name()))
);
}

/**
* Check if the function expression has multiple named argument expressions as the parameters.
*
Expand Down
Loading