diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index b0e0f3d89fd..cb7264113ec 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -305,6 +305,25 @@ public void testExplainBinWithAligntime() throws IOException { + " head 5")); } + @Test + public void testExplainCountEval() throws IOException { + String query = + "source=opensearch-sql_test_index_bank | stats count(eval(age > 30)) as mature_count"; + var result = explainQueryToString(query); + String expected = loadExpectedPlan("explain_count_eval_push.json"); + assertJsonEqualsIgnoreId(expected, result); + } + + @Test + public void testExplainCountEvalComplex() throws IOException { + String query = + "source=opensearch-sql_test_index_bank | stats count(eval(age > 30 and age < 50)) as" + + " mature_count"; + var result = explainQueryToString(query); + String expected = loadExpectedPlan("explain_count_eval_complex_push.json"); + assertJsonEqualsIgnoreId(expected, result); + } + public void testEventstatsDistinctCountExplain() throws IOException { Assume.assumeTrue("This test is only for push down enabled", isPushdownEnabled()); String query = diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_count_eval_complex_push.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_count_eval_complex_push.json new file mode 100644 index 00000000000..8e429a7f610 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_count_eval_complex_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalAggregate(group=[{}], mature_count=[COUNT($0)])\n LogicalProject($f1=[CASE(SEARCH($10, Sarg[(30..50)]), 1, null:NULL)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={},mature_count=COUNT() FILTER $0)], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"mature_count\":{\"filter\":{\"range\":{\"age\":{\"from\":30.0,\"to\":50.0,\"include_lower\":false,\"include_upper\":false,\"boost\":1.0}}},\"aggregations\":{\"mature_count\":{\"value_count\":{\"field\":\"_index\"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_count_eval_push.json b/integ-test/src/test/resources/expectedOutput/calcite/explain_count_eval_push.json new file mode 100644 index 00000000000..f0b75595a56 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_count_eval_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalAggregate(group=[{}], mature_count=[COUNT($0)])\n LogicalProject($f1=[CASE(>($10, 30), 1, null:NULL)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]], PushDownContext=[[AGGREGATION->rel#:LogicalAggregate.NONE.[](input=RelSubset#,group={},mature_count=COUNT() FILTER $0)], OpenSearchRequestBuilder(sourceBuilder={\"from\":0,\"size\":0,\"timeout\":\"1m\",\"aggregations\":{\"mature_count\":{\"filter\":{\"range\":{\"age\":{\"from\":30,\"to\":null,\"include_lower\":false,\"include_upper\":true,\"boost\":1.0}}},\"aggregations\":{\"mature_count\":{\"value_count\":{\"field\":\"_index\"}}}}}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_count_eval_complex_push.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_count_eval_complex_push.json new file mode 100644 index 00000000000..1c4cb0bf63e --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_count_eval_complex_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalAggregate(group=[{}], mature_count=[COUNT($0)])\n LogicalProject($f1=[CASE(SEARCH($10, Sarg[(30..50)]), 1, null:NULL)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableAggregate(group=[{}], mature_count=[COUNT() FILTER $0])\n EnumerableCalc(expr#0..18=[{inputs}], expr#19=[Sarg[(30..50)]], expr#20=[SEARCH($t10, $t19)], expr#21=[IS TRUE($t20)], $f1=[$t21])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n" + } +} \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_count_eval_push.json b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_count_eval_push.json new file mode 100644 index 00000000000..879da821403 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_count_eval_push.json @@ -0,0 +1,6 @@ +{ + "calcite": { + "logical": "LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT])\n LogicalAggregate(group=[{}], mature_count=[COUNT($0)])\n LogicalProject($f1=[CASE(>($10, 30), 1, null:NULL)])\n CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n", + "physical": "EnumerableLimit(fetch=[10000])\n EnumerableAggregate(group=[{}], mature_count=[COUNT() FILTER $0])\n EnumerableCalc(expr#0..18=[{inputs}], expr#19=[30], expr#20=[>($t10, $t19)], expr#21=[IS TRUE($t20)], $f1=[$t21])\n CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_bank]])\n" + } +} \ No newline at end of file diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java index a93fecd2df1..eeb6c671283 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/physical/OpenSearchAggregateIndexScanRule.java @@ -7,7 +7,6 @@ import java.util.function.Predicate; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelRule; -import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.sql.SqlKind; @@ -65,10 +64,6 @@ public interface Config extends RelRule.Config { .withOperandSupplier( b0 -> b0.operand(LogicalAggregate.class) - .predicate( - agg -> - // Cannot push down aggregation with inner filter - agg.getAggCallList().stream().noneMatch(AggregateCall::hasFilter)) .oneInput( b1 -> b1.operand(LogicalProject.class) @@ -100,8 +95,7 @@ public interface Config extends RelRule.Config { .allMatch( call -> call.getAggregation().kind == SqlKind.COUNT - && call.getArgList().isEmpty() - && !call.hasFilter())) + && call.getArgList().isEmpty())) .oneInput( b1 -> b1.operand(CalciteLogicalIndexScan.class) diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java index 2cd37e53690..fff96077332 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateAnalyzer.java @@ -117,10 +117,10 @@ public static class ExpressionNotAnalyzableException extends Exception { private AggregateAnalyzer() {} @RequiredArgsConstructor - private static class AggregateBuilderHelper { - private final RelDataType rowType; - private final Map fieldTypes; - private final RelOptCluster cluster; + static class AggregateBuilderHelper { + final RelDataType rowType; + final Map fieldTypes; + final RelOptCluster cluster; > T build(RexNode node, T aggBuilder) { return build(node, aggBuilder::field, aggBuilder::script); @@ -205,9 +205,11 @@ private static Pair> processAggregateCalls( List aggFieldNames, List aggCalls, Project project, - AggregateBuilderHelper helper) { + AggregateBuilderHelper helper) + throws PredicateAnalyzer.ExpressionNotAnalyzableException { Builder metricBuilder = new AggregatorFactories.Builder(); List metricParserList = new ArrayList<>(); + AggregateFilterAnalyzer aggFilterAnalyzer = new AggregateFilterAnalyzer(helper, project); for (int i = 0; i < aggCalls.size(); i++) { AggregateCall aggCall = aggCalls.get(i); @@ -216,6 +218,7 @@ private static Pair> processAggregateCalls( Pair builderAndParser = createAggregationBuilderAndParser(aggCall, args, aggFieldName, helper); + builderAndParser = aggFilterAnalyzer.analyze(builderAndParser, aggCall, aggFieldName); metricBuilder.addAggregator(builderAndParser.getLeft()); metricParserList.add(builderAndParser.getRight()); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateFilterAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateFilterAnalyzer.java new file mode 100644 index 00000000000..11fa9117693 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/AggregateFilterAnalyzer.java @@ -0,0 +1,72 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.request; + +import lombok.RequiredArgsConstructor; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexNode; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.sql.opensearch.request.PredicateAnalyzer.QueryExpression; +import org.opensearch.sql.opensearch.response.agg.FilterParser; +import org.opensearch.sql.opensearch.response.agg.MetricParser; + +/** Analyzer for converting aggregate filter conditions into OpenSearch filter aggregations. */ +@RequiredArgsConstructor +public class AggregateFilterAnalyzer { + + /** Helper containing row type, field types, and cluster context for analysis. */ + private final AggregateAnalyzer.AggregateBuilderHelper helper; + + /** Project containing filter expressions referenced by aggregate calls. */ + private final Project project; + + /** + * Analyzes and applies filter to aggregation if the AggregateCall has a filter condition. + * + * @param aggResult the base aggregation and parser to potentially wrap with filter + * @param aggCall the aggregate call which may contain filter information + * @param aggFieldName name for the filtered aggregation + * @return wrapped aggregation with filter if present, otherwise the original result + * @throws PredicateAnalyzer.ExpressionNotAnalyzableException if filter condition cannot be + * analyzed + */ + public Pair analyze( + Pair aggResult, AggregateCall aggCall, String aggFieldName) + throws PredicateAnalyzer.ExpressionNotAnalyzableException { + if (project == null || !aggCall.hasFilter()) { + return aggResult; + } + + QueryExpression queryExpression = analyzeAggregateFilter(aggCall); + return Pair.of( + buildFilterAggregation(aggResult.getLeft(), aggFieldName, queryExpression), + buildFilterParser(aggResult.getRight(), aggFieldName)); + } + + private QueryExpression analyzeAggregateFilter(AggregateCall aggCall) + throws PredicateAnalyzer.ExpressionNotAnalyzableException { + RexNode filterCondition = project.getProjects().get(aggCall.filterArg); + return PredicateAnalyzer.analyzeExpression( + filterCondition, + helper.rowType.getFieldNames(), + helper.fieldTypes, + helper.rowType, + helper.cluster); + } + + private AggregationBuilder buildFilterAggregation( + AggregationBuilder aggBuilder, String aggFieldName, QueryExpression queryExpression) { + return AggregationBuilders.filter(aggFieldName, queryExpression.builder()) + .subAggregation(aggBuilder); + } + + private MetricParser buildFilterParser(MetricParser aggParser, String aggFieldName) { + return FilterParser.builder().name(aggFieldName).metricsParser(aggParser).build(); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java index f3c284e4bbe..0c5ccdf580c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/PredicateAnalyzer.java @@ -294,6 +294,7 @@ private static boolean supportedRexCall(RexCall call) { return true; case POSTFIX: switch (call.getKind()) { + case IS_TRUE: case IS_NOT_NULL: case IS_NULL: return true; @@ -559,12 +560,20 @@ private QueryExpression prefix(RexCall call) { } private QueryExpression postfix(RexCall call) { - checkArgument(call.getKind() == SqlKind.IS_NULL || call.getKind() == SqlKind.IS_NOT_NULL); + checkArgument( + call.getKind() == SqlKind.IS_TRUE + || call.getKind() == SqlKind.IS_NULL + || call.getKind() == SqlKind.IS_NOT_NULL); if (call.getOperands().size() != 1) { String message = format(Locale.ROOT, "Unsupported operator: [%s]", call); throw new PredicateAnalyzerException(message); } + if (call.getKind() == SqlKind.IS_TRUE) { + Expression qe = call.getOperands().get(0).accept(this); + return ((QueryExpression) qe).isTrue(); + } + // OpenSearch DSL does not handle IS_NULL / IS_NOT_NULL on nested fields correctly checkForNestedFieldOperands(call); @@ -1381,7 +1390,8 @@ public QueryExpression multiMatch( @Override public QueryExpression isTrue() { - builder = termQuery(getFieldReferenceForTermQuery(), true); + // Ignore istrue if ISTRUE(predicate) and will support ISTRUE(field) later. + // builder = termQuery(getFieldReferenceForTermQuery(), true); return this; } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java index 470784d4142..d5d512f17be 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/AggregateAnalyzerTest.java @@ -16,7 +16,10 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Function; import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.Project; @@ -25,9 +28,15 @@ import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.runtime.Hook; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.schema.impl.AbstractTable; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.SqlTypeFactoryImpl; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.tools.Frameworks; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.Holder; import org.apache.calcite.util.ImmutableBitSet; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Test; @@ -37,6 +46,7 @@ import org.opensearch.sql.opensearch.data.type.OpenSearchDataType.MappingType; import org.opensearch.sql.opensearch.request.AggregateAnalyzer.ExpressionNotAnalyzableException; import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.FilterParser; import org.opensearch.sql.opensearch.response.agg.MetricParserHelper; import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; @@ -46,13 +56,14 @@ class AggregateAnalyzerTest { private final RelDataTypeFactory typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT); - private final List schema = List.of("a", "b", "c"); + private final List schema = List.of("a", "b", "c", "d"); private final RelDataType rowType = typeFactory.createStructType( ImmutableList.of( typeFactory.createSqlType(SqlTypeName.INTEGER), typeFactory.createSqlType(SqlTypeName.VARCHAR), - typeFactory.createSqlType(SqlTypeName.VARCHAR)), + typeFactory.createSqlType(SqlTypeName.VARCHAR), + typeFactory.createSqlType(SqlTypeName.BOOLEAN)), schema); final Map fieldTypes = Map.of( @@ -62,7 +73,9 @@ class AggregateAnalyzerTest { OpenSearchDataType.of( MappingType.Text, Map.of("fields", Map.of("keyword", Map.of("type", "keyword")))), "c", - OpenSearchDataType.of(MappingType.Text)); // Text without keyword cannot be push down + OpenSearchDataType.of(MappingType.Text), // Text without keyword cannot be push down + "d", + OpenSearchDataType.of(MappingType.Boolean)); // Boolean field for script filter test @Test void analyze_aggCall_simple() throws ExpressionNotAnalyzableException { @@ -332,6 +345,214 @@ void analyze_groupBy_TextWithoutKeyword() { assertEquals("[field] must not be null", exception.getCause().getMessage()); } + @Test + void analyze_aggCall_simpleFilter() throws ExpressionNotAnalyzableException { + buildAggregation("filter_cnt") + .withAggCall( + b -> + b.aggregateCall( + SqlStdOperatorTable.COUNT, + false, + b.call( + SqlStdOperatorTable.IS_TRUE, + b.call(SqlStdOperatorTable.GREATER_THAN, b.field("a"), b.literal(0))), + "filter_cnt")) + .expectDslQuery( + "[{\"filter_cnt\":{\"filter\":{\"range\":{\"a\":{" + + "\"from\":0," + + "\"to\":null," + + "\"include_lower\":false," + + "\"include_upper\":true," + + "\"boost\":1.0}}}," + + "\"aggregations\":{\"filter_cnt\":{\"value_count\":{\"field\":\"_index\"}}}}}]") + .expectResponseParser( + new MetricParserHelper( + List.of( + FilterParser.builder() + .name("filter_cnt") + .metricsParser(new SingleValueParser("filter_cnt")) + .build()))) + .verify(); + } + + @Test + void analyze_aggCall_simpleFilter_distinct() throws ExpressionNotAnalyzableException { + buildAggregation("filter_distinct_cnt") + .withAggCall( + b -> + b.aggregateCall( + SqlStdOperatorTable.COUNT, + true, // distinct = true + b.call( + SqlStdOperatorTable.IS_TRUE, + b.call(SqlStdOperatorTable.EQUALS, b.field("a"), b.literal(10))), + "filter_distinct_cnt", + b.field("a"))) + .expectDslQuery( + "[{\"filter_distinct_cnt\":{\"filter\":{\"term\":{\"a\":{\"value\":10,\"boost\":1.0}}}," + + "\"aggregations\":{\"filter_distinct_cnt\":{\"cardinality\":{\"field\":\"a\"}}}}}]") + .expectResponseParser( + new MetricParserHelper( + List.of( + FilterParser.builder() + .name("filter_distinct_cnt") + .metricsParser(new SingleValueParser("filter_distinct_cnt")) + .build()))) + .verify(); + } + + @Test + void analyze_aggCall_complexFilter() throws ExpressionNotAnalyzableException { + buildAggregation("filter_count_range") + .withAggCall( + b -> + b.aggregateCall( + SqlStdOperatorTable.COUNT, + false, + b.call( + SqlStdOperatorTable.IS_TRUE, + b.call( + SqlStdOperatorTable.AND, + b.call( + SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, + b.field("a"), + b.literal(30)), + b.call( + SqlStdOperatorTable.LESS_THAN_OR_EQUAL, + b.field("a"), + b.literal(50)))), + "filter_count_range")) + .expectDslQuery( + "[{\"filter_count_range\":{\"filter\":{\"range\":{\"a\":{\"from\":30.0,\"to\":50.0," + + "\"include_lower\":true,\"include_upper\":true,\"boost\":1.0}}}," + + "\"aggregations\":{\"filter_count_range\":{\"value_count\":{\"field\":\"_index\"}}}}}]") + .expectResponseParser( + new MetricParserHelper( + List.of( + FilterParser.builder() + .name("filter_count_range") + .metricsParser(new SingleValueParser("filter_count_range")) + .build()))) + .verify(); + } + + @Test + void analyze_aggCall_complexScriptFilter() throws ExpressionNotAnalyzableException { + buildAggregation("filter_bool_count", "filter_complex_count") + .withAggCall( + b -> + b.aggregateCall( + SqlStdOperatorTable.COUNT, + false, + b.call(SqlStdOperatorTable.IS_TRUE, b.field("d")), // bool field + "filter_bool_count")) + .withAggCall( + b -> + b.aggregateCall( + SqlStdOperatorTable.COUNT, + false, + b.call( + SqlStdOperatorTable.IS_TRUE, + b.call( + SqlStdOperatorTable.OR, + b.call(SqlStdOperatorTable.MOD, b.field("a"), b.literal(3)), + b.call(SqlStdOperatorTable.LIKE, b.field("c"), b.literal("%test%")))), + "filter_complex_count")) + .expectDslTemplate( + "[{\"filter_bool_count\":{\"filter\":{\"script\":{\"script\":{\"source\":\"{\\\"langType\\\":\\\"calcite\\\",\\\"script\\\":\\\"*\\\"}\"," + + "\"lang\":\"opensearch_compounded_script\",\"params\":{\"utcTimestamp\":0}},\"boost\":1.0}}," + + "\"aggregations\":{\"filter_bool_count\":{\"value_count\":{\"field\":\"_index\"}}}}}," + + " {\"filter_complex_count\":{\"filter\":{\"script\":{\"script\":{\"source\":\"{\\\"langType\\\":\\\"calcite\\\",\\\"script\\\":\\\"*\\\"}\"," + + "\"lang\":\"opensearch_compounded_script\",\"params\":{\"utcTimestamp\":0}},\"boost\":1.0}}," + + "\"aggregations\":{\"filter_complex_count\":{\"value_count\":{\"field\":\"_index\"}}}}}]") + .expectResponseParser( + new MetricParserHelper( + List.of( + FilterParser.builder() + .name("filter_bool_count") + .metricsParser(new SingleValueParser("filter_bool_count")) + .build(), + FilterParser.builder() + .name("filter_complex_count") + .metricsParser(new SingleValueParser("filter_complex_count")) + .build()))) + .verify(); + } + + @Test + void analyze_aggCall_multipleWithFilter() throws ExpressionNotAnalyzableException { + buildAggregation("filter_avg", "filter_sum", "filter_min", "filter_max") + .withAggCall( + b -> + b.aggregateCall( + SqlStdOperatorTable.AVG, + false, + b.call( + SqlStdOperatorTable.IS_TRUE, + b.call(SqlStdOperatorTable.EQUALS, b.field("a"), b.literal(10))), + "filter_avg", + b.field("a"))) + .withAggCall( + b -> + b.aggregateCall( + SqlStdOperatorTable.SUM, + false, + b.call( + SqlStdOperatorTable.IS_TRUE, + b.call(SqlStdOperatorTable.EQUALS, b.field("a"), b.literal(20))), + "filter_sum", + b.field("a"))) + .withAggCall( + b -> + b.aggregateCall( + SqlStdOperatorTable.MIN, + false, + b.call( + SqlStdOperatorTable.IS_TRUE, + b.call(SqlStdOperatorTable.EQUALS, b.field("b"), b.literal("test1"))), + "filter_min", + b.field("a"))) + .withAggCall( + b -> + b.aggregateCall( + SqlStdOperatorTable.MAX, + false, + b.call( + SqlStdOperatorTable.IS_TRUE, + b.call(SqlStdOperatorTable.EQUALS, b.field("b"), b.literal("test2"))), + "filter_max", + b.field("a"))) + .expectDslQuery( + "[{\"filter_avg\":{\"filter\":{\"term\":{\"a\":{\"value\":10,\"boost\":1.0}}}," + + "\"aggregations\":{\"filter_avg\":{\"avg\":{\"field\":\"a\"}}}}}," + + " {\"filter_sum\":{\"filter\":{\"term\":{\"a\":{\"value\":20,\"boost\":1.0}}}," + + "\"aggregations\":{\"filter_sum\":{\"sum\":{\"field\":\"a\"}}}}}," + + " {\"filter_min\":{\"filter\":{\"term\":{\"b.keyword\":{\"value\":\"test1\",\"boost\":1.0}}}," + + "\"aggregations\":{\"filter_min\":{\"min\":{\"field\":\"a\"}}}}}," + + " {\"filter_max\":{\"filter\":{\"term\":{\"b.keyword\":{\"value\":\"test2\",\"boost\":1.0}}}," + + "\"aggregations\":{\"filter_max\":{\"max\":{\"field\":\"a\"}}}}}]") + .expectResponseParser( + new MetricParserHelper( + List.of( + FilterParser.builder() + .name("filter_avg") + .metricsParser(new SingleValueParser("filter_avg")) + .build(), + FilterParser.builder() + .name("filter_sum") + .metricsParser(new SingleValueParser("filter_sum")) + .build(), + FilterParser.builder() + .name("filter_min") + .metricsParser(new SingleValueParser("filter_min")) + .build(), + FilterParser.builder() + .name("filter_max") + .metricsParser(new SingleValueParser("filter_max")) + .build()))) + .verify(); + } + private Aggregate createMockAggregate(List calls, ImmutableBitSet groups) { Aggregate agg = mock(Aggregate.class); when(agg.getGroupSet()).thenReturn(groups); @@ -352,4 +573,114 @@ private Project createMockProject(List refIndex) { when(project.getRowType()).thenReturn(rowType); return project; } + + private AggregationTestBuilder buildAggregation(String... outputFields) { + return new AggregationTestBuilder(List.of(outputFields)); + } + + /** Fluent API builder for creating aggregate filter tests */ + private class AggregationTestBuilder { + private final List outputFields; + private final List aggCalls = new ArrayList<>(); + private final RelBuilder relBuilder; + private String expectedDsl; + private String expectedDslTemplate; + private MetricParserHelper expectedParser; + + AggregationTestBuilder(List outputFields) { + this.outputFields = new ArrayList<>(outputFields); + this.relBuilder = createRelBuilder(); + } + + private RelBuilder createRelBuilder() { + String tableName = "test"; + SchemaPlus root = Frameworks.createRootSchema(true); + root.add( + tableName, + new AbstractTable() { + @Override + public RelDataType getRowType(RelDataTypeFactory tf) { + return rowType; + } + }); + return RelBuilder.create(Frameworks.newConfigBuilder().defaultSchema(root).build()) + .scan(tableName); + } + + AggregationTestBuilder withAggCall(Function aggCallBuilder) { + aggCalls.add(aggCallBuilder.apply(relBuilder)); + return this; + } + + AggregationTestBuilder expectDslQuery(String expectedDsl) { + this.expectedDsl = expectedDsl; + return this; + } + + AggregationTestBuilder expectDslTemplate(String expectedTemplate) { + this.expectedDslTemplate = expectedTemplate; + return this; + } + + AggregationTestBuilder expectResponseParser(MetricParserHelper expectedParser) { + this.expectedParser = expectedParser; + return this; + } + + private boolean matchesTemplate(String actual, String template) { + // Split template by * and escape each part separately + String[] parts = template.split("\\*", -1); + StringBuilder regexBuilder = new StringBuilder(); + + for (int i = 0; i < parts.length; i++) { + // Quote each literal part + regexBuilder.append(java.util.regex.Pattern.quote(parts[i])); + + // Add wildcard regex between parts (except after the last part) + if (i < parts.length - 1) { + regexBuilder.append(".*?"); + } + } + + String regexPattern = regexBuilder.toString(); + return actual.matches(regexPattern); + } + + void verify() throws ExpressionNotAnalyzableException { + // Set up time hook for script queries + Hook.CURRENT_TIME.addThread((Consumer>) h -> h.set(0L)); + + // Create test RelNode plan + RelNode rel = + relBuilder + .aggregate(relBuilder.groupKey(), aggCalls.toArray(new RelBuilder.AggCall[0])) + .build(); + + // Run analyzer + Aggregate agg = (Aggregate) rel; + Project project = (Project) agg.getInput(0); + Pair, OpenSearchAggregationResponseParser> result = + AggregateAnalyzer.analyze( + agg, project, rowType, fieldTypes, outputFields, agg.getCluster()); + + if (expectedDsl != null) { + assertEquals(expectedDsl, result.getLeft().toString()); + } + + if (expectedDslTemplate != null) { + assertTrue( + matchesTemplate(result.getLeft().toString(), expectedDslTemplate), + "DSL should match template.\nExpected: " + + expectedDslTemplate + + "\nActual: " + + result.getLeft().toString()); + } + + if (expectedParser != null) { + assertInstanceOf(NoBucketAggregationParser.class, result.getRight()); + assertEquals( + expectedParser, ((NoBucketAggregationParser) result.getRight()).getMetricsParser()); + } + } + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/PredicateAnalyzerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/PredicateAnalyzerTest.java index b2d67b0ebea..0ed865705a7 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/PredicateAnalyzerTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/PredicateAnalyzerTest.java @@ -750,6 +750,28 @@ void equals_scriptPushDown_Struct() throws ExpressionNotAnalyzableException { assert (builder.toString().contains("\"lang\" : \"opensearch_compounded_script\"")); } + @Test + void isTrue_predicate() throws ExpressionNotAnalyzableException { + RexNode call = + builder.makeCall( + SqlStdOperatorTable.IS_TRUE, + builder.makeCall(SqlStdOperatorTable.EQUALS, field2, stringLiteral)); + QueryBuilder result = PredicateAnalyzer.analyze(call, schema, fieldTypes); + + assertInstanceOf(TermQueryBuilder.class, result); + assertEquals( + """ + { + "term" : { + "b.keyword" : { + "value" : "Hi", + "boost" : 1.0 + } + } + }""", + result.toString()); + } + @Test void isNullOr_ScriptPushDown() throws ExpressionNotAnalyzableException { final RelDataType rowType =