diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 29ab9ec9f8..a0c4eb6eab 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -8,6 +8,7 @@ import static org.opensearch.sql.util.MatcherUtils.assertJsonEqualsIgnoreId; import java.io.IOException; +import org.junit.Assume; import org.junit.Ignore; import org.junit.Test; import org.opensearch.sql.ppl.ExplainIT; @@ -67,4 +68,30 @@ public void supportPushDownSortMergeJoin() throws IOException { String expected = loadExpectedPlan("explain_merge_join_sort_push.json"); assertJsonEqualsIgnoreId(expected, result); } + + // Only for Calcite + @Test + public void supportPartialPushDown() throws IOException { + Assume.assumeTrue("This test is only for push down enabled", isPushdownEnabled()); + // field `address` is text type without keyword subfield, so we cannot push it down. + String query = + "source=opensearch-sql_test_index_account | where (state = 'Seattle' or age < 10) and (age" + + " >= 1 and address = '880 Holmes Lane') | fields age, address"; + var result = explainQueryToString(query); + String expected = loadFromFile("expectedOutput/calcite/explain_partial_filter_push.json"); + assertJsonEqualsIgnoreId(expected, result); + } + + // Only for Calcite + @Test + public void supportPartialPushDown_NoPushIfAllFailed() throws IOException { + Assume.assumeTrue("This test is only for push down enabled", isPushdownEnabled()); + // field `address` is text type without keyword subfield, so we cannot push it down. + String query = + "source=opensearch-sql_test_index_account | where (address = '671 Bristol Street' or age <" + + " 10) and (age >= 10 or address = '880 Holmes Lane') | fields age, address"; + var result = explainQueryToString(query); + String expected = loadFromFile("expectedOutput/calcite/explain_partial_filter_push2.json"); + assertJsonEqualsIgnoreId(expected, result); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteRelevanceFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteRelevanceFunctionIT.java index 85ac2554e0..d95fa51405 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteRelevanceFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteRelevanceFunctionIT.java @@ -5,6 +5,10 @@ package org.opensearch.sql.calcite.remote; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BEER; + +import java.io.IOException; +import org.junit.Assume; import org.opensearch.sql.ppl.RelevanceFunctionIT; public class CalciteRelevanceFunctionIT extends RelevanceFunctionIT { @@ -14,4 +18,17 @@ public void init() throws Exception { enableCalcite(); disallowCalciteFallback(); } + + // For Calcite, this PPL won't throw exception since it supports partial pushdown and has + // optimization rule `FilterProjectTransposeRule` to push down the filter through the project. + @Override + public void not_pushdown_throws_exception() throws IOException { + Assume.assumeTrue("This test is only for push down enabled", isPushdownEnabled()); + String query1 = + "SOURCE=" + + TEST_INDEX_BEER + + " | EVAL answerId = AcceptedAnswerId + 1" + + " | WHERE simple_query_string(['Tags'], 'taste') and answerId > 200"; + assertEquals(5, executeQuery(query1).getInt("total")); + } } diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_partial_filter_push.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_partial_filter_push.json new file mode 100644 index 0000000000..642fe5cd51 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_partial_filter_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalProject(age=[$8], address=[$2])\n LogicalFilter(condition=[AND(OR(=($7, 'Seattle'), <($8, 10)), >=($8, 1), =($2, '880 Holmes Lane'))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableCalc(expr#0..1=[{inputs}], expr#2=['880 Holmes Lane':VARCHAR], expr#3=[=($t0, $t2)], age=[$t1], address=[$t0], $condition=[$t3])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[address, state, age], FILTER->AND(OR(=($1, 'Seattle'), <($2, 10)), >=($2, 1)), PROJECT->[address, age]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"query\":{\"bool\":{\"must\":[{\"bool\":{\"should\":[{\"term\":{\"state.keyword\":{\"value\":\"Seattle\",\"boost\":1.0}}},{\"range\":{\"age\":{\"from\":null,\"to\":10,\"include_lower\":true,\"include_upper\":false,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},{\"range\":{\"age\":{\"from\":1,\"to\":null,\"include_lower\":true,\"include_upper\":true,\"boost\":1.0}}}],\"adjust_pure_negative\":true,\"boost\":1.0}},\"_source\":{\"includes\":[\"address\",\"age\"],\"excludes\":[]},\"sort\":[{\"_doc\":{\"order\":\"asc\"}}]}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_partial_filter_push2.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_partial_filter_push2.json new file mode 100644 index 0000000000..087d2dfbea --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_partial_filter_push2.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalProject(age=[$8], address=[$2])\n LogicalFilter(condition=[AND(OR(=($2, '671 Bristol Street'), <($8, 10)), OR(>=($8, 10), =($2, '880 Holmes Lane')))])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]])\n", + "physical": "EnumerableCalc(expr#0..1=[{inputs}], expr#2=['671 Bristol Street':VARCHAR], expr#3=[=($t0, $t2)], expr#4=[10], expr#5=[<($t1, $t4)], expr#6=[OR($t3, $t5)], expr#7=[>=($t1, $t4)], expr#8=['880 Holmes Lane':VARCHAR], expr#9=[=($t0, $t8)], expr#10=[OR($t7, $t9)], expr#11=[AND($t6, $t10)], age=[$t1], address=[$t0], $condition=[$t11])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[address, age]], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"timeout\":\"1m\",\"_source\":{\"includes\":[\"address\",\"age\"],\"excludes\":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java index 234dbaf4e4..4ee5148b82 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchFilterIndexScanRule.java @@ -7,6 +7,7 @@ import java.util.function.Predicate; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.AbstractRelNode; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.logical.LogicalFilter; import org.immutables.value.Value; @@ -37,9 +38,9 @@ public void onMatch(RelOptRuleCall call) { } protected void apply(RelOptRuleCall call, Filter filter, CalciteLogicalIndexScan scan) { - CalciteLogicalIndexScan newScan = scan.pushDownFilter(filter); - if (newScan != null) { - call.transformTo(newScan); + AbstractRelNode newRel = scan.pushDownFilter(filter); + if (newRel != null) { + call.transformTo(newRel); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java index 45117c39a5..0d660df45b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java @@ -52,6 +52,7 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import lombok.Getter; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; @@ -144,17 +145,15 @@ private PredicateAnalyzer() {} public static QueryBuilder analyze( RexNode expression, List schema, Map filedTypes) throws ExpressionNotAnalyzableException { + return analyze_(expression, schema, filedTypes).builder(); + } + + public static QueryExpression analyze_( + RexNode expression, List schema, Map filedTypes) + throws ExpressionNotAnalyzableException { requireNonNull(expression, "expression"); try { - // visits expression tree - QueryExpression queryExpression = - (QueryExpression) expression.accept(new Visitor(schema, filedTypes)); - - if (queryExpression != null && queryExpression.isPartial()) { - throw new UnsupportedOperationException( - "Can't handle partial QueryExpression: " + queryExpression); - } - return queryExpression != null ? queryExpression.builder() : null; + return (QueryExpression) expression.accept(new Visitor(schema, filedTypes)); } catch (Throwable e) { Throwables.throwIfInstanceOf(e, UnsupportedOperationException.class); throw new ExpressionNotAnalyzableException("Can't convert " + expression, e); @@ -567,6 +566,7 @@ private QueryExpression andOr(RexCall call) { QueryExpression[] expressions = new QueryExpression[call.getOperands().size()]; PredicateAnalyzerException firstError = null; boolean partial = false; + int failedCount = 0; for (int i = 0; i < call.getOperands().size(); i++) { try { Expression expr = call.getOperands().get(i).accept(this); @@ -574,6 +574,9 @@ private QueryExpression andOr(RexCall call) { // nop currently } else { expressions[i] = (QueryExpression) call.getOperands().get(i).accept(this); + // Update or simplify the analyzed node list if it is not partial. + if (!expressions[i].isPartial()) + expressions[i].updateAnalyzedNodes(call.getOperands().get(i)); } partial |= expressions[i].isPartial(); } catch (PredicateAnalyzerException e) { @@ -581,6 +584,10 @@ private QueryExpression andOr(RexCall call) { firstError = e; } partial = true; + ++failedCount; + // If we cannot analyze the operand, wrap the RexNode with UnAnalyzableQueryExpression and + // record them in the array. We will reuse them later. + expressions[i] = new UnAnalyzableQueryExpression(call.getOperands().get(i)); } } @@ -596,6 +603,11 @@ private QueryExpression andOr(RexCall call) { } return CompoundQueryExpression.or(expressions); case AND: + if (failedCount == call.getOperands().size()) { + // If all operands failed, we cannot analyze the AND expression. + throw new PredicateAnalyzerException( + "All expressions in AND failed to analyze: " + call); + } return CompoundQueryExpression.and(partial, expressions); default: String message = format(Locale.ROOT, "Unable to handle call: [%s]", call); @@ -712,74 +724,138 @@ private static boolean isColumn( interface Expression {} /** Main expression operators (like {@code equals}, {@code gt}, {@code exists} etc.) */ - abstract static class QueryExpression implements Expression { + public abstract static class QueryExpression implements Expression { public abstract QueryBuilder builder(); + public abstract List getAnalyzedNodes(); + + public abstract void updateAnalyzedNodes(RexNode rexNode); + + public abstract List getUnAnalyzableNodes(); + public boolean isPartial() { return false; } - public abstract QueryExpression contains(LiteralExpression literal); - /** Negate {@code this} QueryExpression (not the next one). */ - public abstract QueryExpression not(); - - public abstract QueryExpression exists(); + QueryExpression not() { + throw new PredicateAnalyzerException("not cannot be applied to " + this.getClass()); + } - public abstract QueryExpression notExists(); + QueryExpression exists() { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['exists'] " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression like(LiteralExpression literal); + QueryExpression notExists() { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['notExists'] " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression notLike(LiteralExpression literal); + QueryExpression contains(LiteralExpression literal) { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['contains'] " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression equals(LiteralExpression literal); + QueryExpression between(Range literal, boolean isTimeStamp) { + throw new PredicateAnalyzerException("between cannot be applied to " + this.getClass()); + } - public abstract QueryExpression in(LiteralExpression literal); + QueryExpression like(LiteralExpression literal) { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['like'] " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression notIn(LiteralExpression literal); + QueryExpression notLike(LiteralExpression literal) { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['notLike'] " + "cannot be applied to " + this.getClass()); + } - public QueryExpression between(Range literal, boolean isTimeStamp) { - throw new PredicateAnalyzerException("between cannot be applied to " + this.getClass()); + QueryExpression equals(LiteralExpression literal) { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['='] " + "cannot be applied to " + this.getClass()); } - public QueryExpression equals(Object point, boolean isTimeStamp) { + QueryExpression equals(Object point, boolean isTimeStamp) { throw new PredicateAnalyzerException("equals cannot be applied to " + this.getClass()); } - public abstract QueryExpression notEquals(LiteralExpression literal); + QueryExpression notEquals(LiteralExpression literal) { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['not'] " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression gt(LiteralExpression literal); + QueryExpression gt(LiteralExpression literal) { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['>'] " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression gte(LiteralExpression literal); + QueryExpression gte(LiteralExpression literal) { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['>='] " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression lt(LiteralExpression literal); + QueryExpression lt(LiteralExpression literal) { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['<'] " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression lte(LiteralExpression literal); + QueryExpression lte(LiteralExpression literal) { + throw new PredicateAnalyzerException( + "SqlOperatorImpl ['<='] " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression match(String query, Map optionalArguments); + QueryExpression match(String query, Map optionalArguments) { + throw new PredicateAnalyzerException("Match " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression matchPhrase( - String query, Map optionalArguments); + QueryExpression matchPhrase(String query, Map optionalArguments) { + throw new PredicateAnalyzerException( + "MatchPhrase " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression matchBoolPrefix( - String query, Map optionalArguments); + QueryExpression matchBoolPrefix(String query, Map optionalArguments) { + throw new PredicateAnalyzerException( + "MatchBoolPrefix " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression matchPhrasePrefix( - String query, Map optionalArguments); + QueryExpression matchPhrasePrefix(String query, Map optionalArguments) { + throw new PredicateAnalyzerException( + "MatchPhrasePrefix " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression simpleQueryString( - RexCall fieldsRexCall, String query, Map optionalArguments); + QueryExpression simpleQueryString( + RexCall fieldsRexCall, String query, Map optionalArguments) { + throw new PredicateAnalyzerException( + "SimpleQueryString " + "cannot be applied to " + this.getClass()); + } - public abstract QueryExpression queryString( - RexCall fieldsRexCall, String query, Map optionalArguments); + QueryExpression queryString( + RexCall fieldsRexCall, String query, Map optionalArguments) { + throw new PredicateAnalyzerException( + "QueryString " + "cannot be applied to " + this.getClass()); + } + + QueryExpression multiMatch( + RexCall fieldsRexCall, String query, Map optionalArguments) { + throw new PredicateAnalyzerException( + "MultiMatch " + "cannot be applied to " + this.getClass()); + } + + QueryExpression isTrue() { + throw new PredicateAnalyzerException("isTrue cannot be applied to " + this.getClass()); + } - public abstract QueryExpression multiMatch( - RexCall fieldsRexCall, String query, Map optionalArguments); + QueryExpression in(LiteralExpression literal) { + throw new PredicateAnalyzerException("in cannot be applied to " + this.getClass()); + } - public abstract QueryExpression isTrue(); + QueryExpression notIn(LiteralExpression literal) { + throw new PredicateAnalyzerException("notIn cannot be applied to " + this.getClass()); + } - public static QueryExpression create(TerminalExpression expression) { + static QueryExpression create(TerminalExpression expression) { if (expression instanceof CastExpression) { expression = CastExpression.unpack(expression); } @@ -793,11 +869,43 @@ public static QueryExpression create(TerminalExpression expression) { } } + @Getter + static class UnAnalyzableQueryExpression extends QueryExpression { + final RexNode unAnalyzableRexNode; + + public UnAnalyzableQueryExpression(RexNode rexNode) { + this.unAnalyzableRexNode = requireNonNull(rexNode, "rexNode"); + } + + @Override + public QueryBuilder builder() { + return null; + } + + @Override + public List getUnAnalyzableNodes() { + return List.of(unAnalyzableRexNode); + } + + @Override + public List getAnalyzedNodes() { + return List.of(); + } + + @Override + public void updateAnalyzedNodes(RexNode rexNode) { + throw new IllegalStateException( + "UnAnalyzableQueryExpression does not support unAnalyzableNodes"); + } + } + /** Builds conjunctions / disjunctions based on existing expressions. */ - static class CompoundQueryExpression extends QueryExpression { + public static class CompoundQueryExpression extends QueryExpression { private final boolean partial; private final BoolQueryBuilder builder; + @Getter private List analyzedNodes = new ArrayList<>(); + @Getter private final List unAnalyzableNodes = new ArrayList<>(); public static CompoundQueryExpression or(QueryExpression... expressions) { CompoundQueryExpression bqe = new CompoundQueryExpression(false); @@ -817,7 +925,9 @@ public static CompoundQueryExpression or(QueryExpression... expressions) { public static CompoundQueryExpression and(boolean partial, QueryExpression... expressions) { CompoundQueryExpression bqe = new CompoundQueryExpression(partial); for (QueryExpression expression : expressions) { - if (expression != null) { // partial expressions have nulls for missing nodes + bqe.analyzedNodes.addAll(expression.getAnalyzedNodes()); + bqe.unAnalyzableNodes.addAll(expression.getUnAnalyzableNodes()); + if (!(expression instanceof UnAnalyzableQueryExpression)) { bqe.builder.must(expression.builder()); } } @@ -844,139 +954,20 @@ public QueryBuilder builder() { } @Override - public QueryExpression not() { - return new CompoundQueryExpression(partial, boolQuery().mustNot(builder())); - } - - @Override - public QueryExpression exists() { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['exists'] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression contains(LiteralExpression literal) { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['contains'] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression notExists() { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['notExists'] " + "cannot be applied to a compound expression"); + public void updateAnalyzedNodes(RexNode rexNode) { + this.analyzedNodes = List.of(rexNode); } @Override - public QueryExpression like(LiteralExpression literal) { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['like'] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression notLike(LiteralExpression literal) { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['notLike'] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression equals(LiteralExpression literal) { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['='] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression notEquals(LiteralExpression literal) { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['not'] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression gt(LiteralExpression literal) { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['>'] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression gte(LiteralExpression literal) { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['>='] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression lt(LiteralExpression literal) { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['<'] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression lte(LiteralExpression literal) { - throw new PredicateAnalyzerException( - "SqlOperatorImpl ['<='] " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression match(String query, Map optionalArguments) { - throw new PredicateAnalyzerException("Match " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression matchPhrase(String query, Map optionalArguments) { - throw new PredicateAnalyzerException( - "MatchPhrase " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression matchBoolPrefix(String query, Map optionalArguments) { - throw new PredicateAnalyzerException( - "MatchBoolPrefix " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression matchPhrasePrefix(String query, Map optionalArguments) { - throw new PredicateAnalyzerException( - "MatchPhrasePrefix " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression simpleQueryString( - RexCall fieldsRexCall, String query, Map optionalArguments) { - throw new PredicateAnalyzerException( - "SimpleQueryString " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression queryString( - RexCall fieldsRexCall, String query, Map optionalArguments) { - throw new PredicateAnalyzerException( - "QueryString " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression multiMatch( - RexCall fieldsRexCall, String query, Map optionalArguments) { - throw new PredicateAnalyzerException( - "MultiMatch " + "cannot be applied to a compound expression"); - } - - @Override - public QueryExpression isTrue() { - throw new PredicateAnalyzerException("isTrue cannot be applied to a compound expression"); - } - - @Override - public QueryExpression in(LiteralExpression literal) { - throw new PredicateAnalyzerException("in cannot be applied to a compound expression"); - } - - @Override - public QueryExpression notIn(LiteralExpression literal) { - throw new PredicateAnalyzerException("notIn cannot be applied to a compound expression"); + public QueryExpression not() { + return new CompoundQueryExpression(partial, boolQuery().mustNot(builder())); } } /** Usually basic expression of type {@code a = 'val'} or {@code b > 42}. */ static class SimpleQueryExpression extends QueryExpression { + private RexNode analyzedRexNode; private final NamedFieldExpression rel; private QueryBuilder builder; @@ -985,7 +976,13 @@ private String getFieldReference() { } private String getFieldReferenceForTermQuery() { - return rel.getReferenceForTermQuery(); + String reference = rel.getReferenceForTermQuery(); + // Throw exception in advance of method builder() to trigger partial push down. + if (reference == null) { + throw new PredicateAnalyzerException( + "Field reference for term query cannot be null for " + rel.getRootName()); + } + return reference; } private SimpleQueryExpression(NamedFieldExpression rel) { @@ -1005,6 +1002,21 @@ public QueryBuilder builder() { return builder; } + @Override + public List getUnAnalyzableNodes() { + return List.of(); + } + + @Override + public List getAnalyzedNodes() { + return List.of(analyzedRexNode); + } + + @Override + public void updateAnalyzedNodes(RexNode rexNode) { + this.analyzedRexNode = rexNode; + } + @Override public QueryExpression not() { builder = boolQuery().mustNot(builder()); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java index 7753b619f8..b2281bcb00 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/AbstractCalciteIndexScan.java @@ -288,6 +288,13 @@ protected enum PushDownType { // NESTED } + /** + * Represents a push down action that can be applied to an OpenSearchRequestBuilder. + * + * @param type PushDownType enum + * @param digest the digest of the pushed down operator + * @param action the lambda action to apply on the OpenSearchRequestBuilder + */ public record PushDownAction(PushDownType type, Object digest, AbstractAction action) { static PushDownAction of(PushDownType type, Object digest, AbstractAction action) { return new PushDownAction(type, digest, action); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java index 72025bfa52..66804c1b6c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/CalciteLogicalIndexScan.java @@ -17,6 +17,7 @@ import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelTraitSet; +import org.apache.calcite.rel.AbstractRelNode; import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelFieldCollation; @@ -27,6 +28,9 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -40,6 +44,7 @@ import org.opensearch.sql.opensearch.planner.physical.OpenSearchIndexRules; import org.opensearch.sql.opensearch.request.AggregateAnalyzer; import org.opensearch.sql.opensearch.request.PredicateAnalyzer; +import org.opensearch.sql.opensearch.request.PredicateAnalyzer.QueryExpression; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; import org.opensearch.sql.opensearch.storage.OpenSearchIndex; @@ -102,20 +107,32 @@ public void register(RelOptPlanner planner) { } } - public CalciteLogicalIndexScan pushDownFilter(Filter filter) { + public AbstractRelNode pushDownFilter(Filter filter) { try { - CalciteLogicalIndexScan newScan = this.copyWithNewSchema(filter.getRowType()); List schema = this.getRowType().getFieldNames(); Map filedTypes = this.osIndex.getFieldTypes(); - QueryBuilder filterBuilder = - PredicateAnalyzer.analyze(filter.getCondition(), schema, filedTypes); + QueryExpression queryExpression = + PredicateAnalyzer.analyze_(filter.getCondition(), schema, filedTypes); + QueryBuilder queryBuilder = queryExpression.builder(); + CalciteLogicalIndexScan newScan = this.copyWithNewSchema(filter.getRowType()); + // TODO: handle the case where condition contains a score function newScan.pushDownContext.add( PushDownAction.of( PushDownType.FILTER, - filter.getCondition(), - requestBuilder -> requestBuilder.pushDownFilter(filterBuilder))); + queryExpression.isPartial() + ? constructCondition( + queryExpression.getAnalyzedNodes(), getCluster().getRexBuilder()) + : filter.getCondition(), + requestBuilder -> requestBuilder.pushDownFilter(queryBuilder))); - // TODO: handle the case where condition contains a score function + // If the query expression is partial, we need to replace the input of the filter with the + // partial pushed scan and the filter condition with non-pushed-down conditions. + if (queryExpression.isPartial()) { + // Only CompoundQueryExpression could be partial. + List conditions = queryExpression.getUnAnalyzableNodes(); + RexNode newCondition = constructCondition(conditions, getCluster().getRexBuilder()); + return filter.copy(filter.getTraitSet(), newScan, newCondition); + } return newScan; } catch (Exception e) { if (LOG.isDebugEnabled()) { @@ -127,6 +144,12 @@ public CalciteLogicalIndexScan pushDownFilter(Filter filter) { return null; } + private static RexNode constructCondition(List conditions, RexBuilder rexBuilder) { + return conditions.size() > 1 + ? rexBuilder.makeCall(SqlStdOperatorTable.AND, conditions) + : conditions.get(0); + } + /** * When pushing down a project, we need to create a new CalciteLogicalIndexScan with the updated * schema since we cannot override getRowType() which is defined to be final.