From dbb03dbc40a7824f513abdcbcddebbf61f159884 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Wed, 20 Aug 2025 16:00:07 -0700 Subject: [PATCH 1/7] Rewrite count(eval) expression to support filtered counting Signed-off-by: Chen Dai --- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 5 +++++ .../sql/ppl/parser/AstExpressionBuilder.java | 18 ++++++++++++++++++ .../ppl/calcite/CalcitePPLAggregationTest.java | 17 +++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index dad56443436..c0c6b6833c3 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -426,6 +426,7 @@ statsAggTerm statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall + | COUNT LT_PRTHS evalExpression RT_PRTHS # countEvalFunctionCall | (DISTINCT_COUNT | DC | DISTINCT_COUNT_APPROX) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall | takeAggFunction # takeAggFunctionCall | percentileApproxFunction # percentileApproxFunctionCall @@ -494,6 +495,10 @@ valueExpression | LT_PRTHS logicalExpression RT_PRTHS # nestedValueExpr ; +evalExpression + : EVAL LT_PRTHS logicalExpression RT_PRTHS + ; + functionCall : evalFunctionCall | dataTypeFunctionCall diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 212616ac813..b3482b899f5 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -12,6 +12,7 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext; import static 
org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; +import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountEvalFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DecimalLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DistinctCountFunctionCallContext; @@ -62,6 +63,7 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalExpressionContext; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParserBaseVisitor; import org.opensearch.sql.ppl.utils.ArgumentFactory; @@ -229,12 +231,28 @@ public UnresolvedExpression visitCountAllFunctionCall(CountAllFunctionCallContex return new AggregateFunction("count", AllFields.of()); } + @Override + public UnresolvedExpression visitCountEvalFunctionCall(CountEvalFunctionCallContext ctx) { + return new AggregateFunction("count", visit(ctx.evalExpression())); + } + @Override public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunctionCallContext ctx) { String funcName = ctx.DISTINCT_COUNT_APPROX() != null ? "distinct_count_approx" : "count"; return new AggregateFunction(funcName, visit(ctx.valueExpression()), true); } + @Override + public UnresolvedExpression visitEvalExpression(EvalExpressionContext ctx) { + /* + * Rewrite "eval(p)" as "CASE WHEN p THEN 1 ELSE NULL END" so that COUNT or DISTINCT_COUNT + * can correctly perform filtered counting. + * Note: at present only eval() inside counting functions is supported. 
+ */ + UnresolvedExpression predicate = visit(ctx.logicalExpression()); + return AstDSL.caseWhen(null, AstDSL.when(predicate, AstDSL.intLiteral(1))); + } + @Override public UnresolvedExpression visitPercentileApproxFunctionCall( OpenSearchPPLParser.PercentileApproxFunctionCallContext ctx) { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java index dc3abdc3007..88b14d35e29 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java @@ -32,6 +32,23 @@ public void testSimpleCount() { verifyPPLToSparkSQL(root, expectedSparkSql); } + @Test + public void testEvalCount() { + String ppl = "source=EMP | stats count(eval(SAL > 2000)) as c"; + RelNode root = getRelNode(ppl); + String expectedLogical = + "LogicalAggregate(group=[{}], c=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(>($5, 2000), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + String expectedResult = "c=6\n"; + verifyResult(root, expectedResult); + + String expectedSparkSql = + "SELECT COUNT(CASE WHEN `SAL` > 2000 THEN 1 ELSE NULL END) `c`\nFROM `scott`.`EMP`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + @Test public void testTakeAgg() { String ppl = "source=EMP | stats take(JOB, 2) as c"; From 1aa23e7d3dff5967542663661c925ba68e9dbe7c Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 21 Aug 2025 10:41:48 -0700 Subject: [PATCH 2/7] Refactor count eval UTs Signed-off-by: Chen Dai --- .../ppl/calcite/CalcitePPLAbstractTest.java | 31 +++++++ .../calcite/CalcitePPLAggregationTest.java | 17 ---- .../ppl/calcite/CalcitePPLCountEvalTest.java | 85 +++++++++++++++++++ 3 files changed, 116 insertions(+), 17 deletions(-) create mode 100644 ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java 
diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java index 2e70a210d6e..f40e90ff0e6 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAbstractTest.java @@ -100,6 +100,37 @@ private Node plan(PPLSyntaxParser parser, String query) { return builder.visit(parser.parse(query)); } + /** + * Fluent API for building count(eval) test cases. Provides a clean and readable way to define PPL + * queries and their expected outcomes. + */ + protected PPLQueryTestBuilder withPPLQuery(String ppl) { + return new PPLQueryTestBuilder(ppl); + } + + protected class PPLQueryTestBuilder { + private final RelNode relNode; + + public PPLQueryTestBuilder(String ppl) { + this.relNode = getRelNode(ppl); + } + + public PPLQueryTestBuilder expectLogical(String expectedLogical) { + verifyLogical(relNode, expectedLogical); + return this; + } + + public PPLQueryTestBuilder expectResult(String expectedResult) { + verifyResult(relNode, expectedResult); + return this; + } + + public PPLQueryTestBuilder expectSparkSQL(String expectedSparkSql) { + verifyPPLToSparkSQL(relNode, expectedSparkSql); + return this; + } + } + /** Verify the logical plan of the given RelNode */ public void verifyLogical(RelNode rel, String expectedLogical) { assertThat(rel, hasTree(expectedLogical)); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java index 88b14d35e29..dc3abdc3007 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLAggregationTest.java @@ -32,23 +32,6 @@ public void testSimpleCount() { verifyPPLToSparkSQL(root, expectedSparkSql); } - @Test - public void testEvalCount() 
{ - String ppl = "source=EMP | stats count(eval(SAL > 2000)) as c"; - RelNode root = getRelNode(ppl); - String expectedLogical = - "LogicalAggregate(group=[{}], c=[COUNT($0)])\n" - + " LogicalProject($f1=[CASE(>($5, 2000), 1, null:NULL)])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; - verifyLogical(root, expectedLogical); - String expectedResult = "c=6\n"; - verifyResult(root, expectedResult); - - String expectedSparkSql = - "SELECT COUNT(CASE WHEN `SAL` > 2000 THEN 1 ELSE NULL END) `c`\nFROM `scott`.`EMP`"; - verifyPPLToSparkSQL(root, expectedSparkSql); - } - @Test public void testTakeAgg() { String ppl = "source=EMP | stats take(JOB, 2) as c"; diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java new file mode 100644 index 00000000000..2a07c78a91a --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.calcite; + +import org.apache.calcite.test.CalciteAssert; +import org.junit.Test; + +/** + * Unit tests for count(eval) functionality in CalcitePPL engine. Tests various scenarios of + * filtered count aggregations. 
+ */ +public class CalcitePPLCountEvalTest extends CalcitePPLAbstractTest { + + public CalcitePPLCountEvalTest() { + super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL); + } + + @Test + public void testCountEvalSimpleCondition() { + withPPLQuery("source=EMP | stats count(eval(SAL > 2000)) as c") + .expectLogical( + "LogicalAggregate(group=[{}], c=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(>($5, 2000), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("c=6\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `SAL` > 2000 THEN 1 ELSE NULL END) `c`\nFROM `scott`.`EMP`"); + } + + @Test + public void testCountEvalComplexCondition() { + withPPLQuery("source=EMP | stats count(eval(SAL > 2000 and DEPTNO < 30)) as c") + .expectLogical( + "LogicalAggregate(group=[{}], c=[COUNT($0)])\n" + + " LogicalProject($f2=[CASE(AND(>($5, 2000), <($7, 30)), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("c=5\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `SAL` > 2000 AND `DEPTNO` < 30 THEN 1 ELSE NULL END) `c`\n" + + "FROM `scott`.`EMP`"); + } + + @Test + public void testCountEvalStringComparison() { + withPPLQuery("source=EMP | stats count(eval(JOB = 'MANAGER')) as manager_count") + .expectLogical( + "LogicalAggregate(group=[{}], manager_count=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(=($2, 'MANAGER':VARCHAR), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("manager_count=3\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `JOB` = 'MANAGER' THEN 1 ELSE NULL END) `manager_count`\n" + + "FROM `scott`.`EMP`"); + } + + @Test + public void testCountEvalArithmeticExpression() { + withPPLQuery("source=EMP | stats count(eval(SAL / COMM > 10)) as high_ratio") + .expectLogical( + "LogicalAggregate(group=[{}], high_ratio=[COUNT($0)])\n" + + " LogicalProject($f2=[CASE(>(DIVIDE($5, $6), 10), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + 
.expectResult("high_ratio=0\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `DIVIDE`(`SAL`, `COMM`) > 10 THEN 1 ELSE NULL END)" + + " `high_ratio`\n" + + "FROM `scott`.`EMP`"); + } + + @Test + public void testCountEvalWithNullHandling() { + withPPLQuery("source=EMP | stats count(eval(isnotnull(MGR))) as non_null_mgr") + .expectLogical( + "LogicalAggregate(group=[{}], non_null_mgr=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(IS NOT NULL($3), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("non_null_mgr=13\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `MGR` IS NOT NULL THEN 1 ELSE NULL END) `non_null_mgr`\n" + + "FROM `scott`.`EMP`"); + } +} From 06afcde6f9786ba3c6cfe6dbd62de2bc14ed0560 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 21 Aug 2025 11:13:14 -0700 Subject: [PATCH 3/7] Add count eval ITs Signed-off-by: Chen Dai --- .../remote/CalcitePPLAggregationIT.java | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java index eb725c4eea4..1ac2f07539f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java @@ -811,4 +811,52 @@ public void testAggByByteNumberWithScript() throws IOException { TEST_INDEX_DATATYPE_NUMERIC)); verifyDataRows(response, rows(1, 4)); } + + @Test + public void testCountEvalSimpleCondition() throws IOException { + JSONObject actual = + executeQuery( + String.format("source=%s | stats count(eval(age > 30)) as c", TEST_INDEX_BANK)); + verifySchema(actual, schema("c", "bigint")); + verifyDataRows(actual, rows(6)); + } + + @Test + public void testCountEvalComplexCondition() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats 
count(eval(balance > 20000 and age < 35)) as c", + TEST_INDEX_BANK)); + verifySchema(actual, schema("c", "bigint")); + verifyDataRows(actual, rows(3)); + } + + @Test + public void testCountEvalGroupBy() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats count(eval(balance > 25000)) as high_balance by gender", + TEST_INDEX_BANK)); + verifySchema(actual, schema("gender", "string"), schema("high_balance", "bigint")); + verifyDataRows(actual, rows(3, "F"), rows(1, "M")); + } + + @Test + public void testCountEvalWithMultipleAggregations() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats count(eval(age > 30)) as mature_count, " + + "count(eval(balance > 25000)) as high_balance_count, " + + "count() as total_count", + TEST_INDEX_BANK)); + verifySchema( + actual, + schema("mature_count", "bigint"), + schema("high_balance_count", "bigint"), + schema("total_count", "bigint")); + verifyDataRows(actual, rows(6, 4, 7)); + } } From 06f3080e13cd7f6fa2d86740423741c5fe0c22e0 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 21 Aug 2025 13:41:33 -0700 Subject: [PATCH 4/7] Add count eval doctest Signed-off-by: Chen Dai --- docs/user/ppl/cmd/stats.rst | 22 ++++++++++++++++++- .../ppl/calcite/CalcitePPLCountEvalTest.java | 18 +++++++-------- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index a2a6885c744..b92c5c5d2cd 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -80,7 +80,7 @@ COUNT Description >>>>>>>>>>> -Usage: Returns a count of the number of expr in the rows retrieved by a SELECT statement. +Usage: Returns a count of the number of expr in the rows retrieved. To perform filtered counting, wrap the condition in an `eval` expression.
Example:: @@ -92,6 +92,26 @@ Example:: | 4 | +---------+ +Example of filtered counting:: + + os> source=accounts | stats count(eval(age > 30)) as mature_users; + fetched rows / total rows = 1/1 + +--------------+ + | mature_users | + |--------------| + | 3 | + +--------------+ + +Example of filtered counting with complex conditions:: + + os> source=accounts | stats count(eval(age > 30 and balance > 25000)) as high_value_users; + fetched rows / total rows = 1/1 + +------------------+ + | high_value_users | + |------------------| + | 2 | + +------------------+ + SUM --- diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java index 2a07c78a91a..1579408e171 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java @@ -59,15 +59,15 @@ public void testCountEvalStringComparison() { @Test public void testCountEvalArithmeticExpression() { withPPLQuery("source=EMP | stats count(eval(SAL / COMM > 10)) as high_ratio") - .expectLogical( - "LogicalAggregate(group=[{}], high_ratio=[COUNT($0)])\n" - + " LogicalProject($f2=[CASE(>(DIVIDE($5, $6), 10), 1, null:NULL)])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n") - .expectResult("high_ratio=0\n") - .expectSparkSQL( - "SELECT COUNT(CASE WHEN `DIVIDE`(`SAL`, `COMM`) > 10 THEN 1 ELSE NULL END)" - + " `high_ratio`\n" - + "FROM `scott`.`EMP`"); + .expectLogical( + "LogicalAggregate(group=[{}], high_ratio=[COUNT($0)])\n" + + " LogicalProject($f2=[CASE(>(DIVIDE($5, $6), 10), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("high_ratio=0\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `DIVIDE`(`SAL`, `COMM`) > 10 THEN 1 ELSE NULL END)" + + " `high_ratio`\n" + + "FROM `scott`.`EMP`"); } @Test From ed488aaebf882e94b1502cf09bda309e6c40a660 Mon Sep 17 00:00:00 2001 From: Chen 
Dai Date: Thu, 21 Aug 2025 15:52:08 -0700 Subject: [PATCH 5/7] Fix doctest failure Signed-off-by: Chen Dai --- docs/user/ppl/cmd/stats.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index b92c5c5d2cd..1d7c0ba157c 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -109,7 +109,7 @@ Example of filtered counting with complex conditions:: +------------------+ | high_value_users | |------------------| - | 2 | + | 1 | +------------------+ SUM From 93ad059869859dad574d3cf274e9ed9871261ed9 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 21 Aug 2025 17:07:53 -0700 Subject: [PATCH 6/7] Add more UT for AST builder Signed-off-by: Chen Dai --- .../ppl/parser/AstExpressionBuilderTest.java | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index 28a941b8238..3a6600f20a4 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -12,6 +12,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.and; import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.caseWhen; import static org.opensearch.sql.ast.dsl.AstDSL.cast; import static org.opensearch.sql.ast.dsl.AstDSL.compare; import static org.opensearch.sql.ast.dsl.AstDSL.decimalLiteral; @@ -41,6 +42,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.sort; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.unresolvedArg; +import static org.opensearch.sql.ast.dsl.AstDSL.when; import static org.opensearch.sql.ast.dsl.AstDSL.xor; import 
com.google.common.collect.ImmutableMap; @@ -546,6 +548,24 @@ public void testCountFuncCallExpr() { defaultStatsArgs())); } + @Test + public void testCountEvalFuncCallExpr() { + assertEqual( + "source=t | stats count(eval(a > 0)) by b", + agg( + relation("t"), + exprList( + alias( + "count(eval(a > 0))", + aggregate( + "count", + caseWhen( + null, when(compare(">", field("a"), intLiteral(0)), intLiteral(1)))))), + emptyList(), + exprList(alias("b", field("b"))), + defaultStatsArgs())); + } + @Test public void testDistinctCount() { assertEqual( From dcb8f8567fdf82d7ea7c574a2683287de40d811c Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Thu, 28 Aug 2025 11:48:15 -0700 Subject: [PATCH 7/7] Resolve conflicts and more changes for shortcut c Signed-off-by: Chen Dai --- docs/user/ppl/cmd/stats.rst | 10 ++++---- .../remote/CalcitePPLAggregationIT.java | 19 ++++++++++++++ ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../ppl/calcite/CalcitePPLCountEvalTest.java | 25 +++++++++++++++++++ 4 files changed, 50 insertions(+), 6 deletions(-) diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index a24a3c9109e..c041281c6f8 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -86,11 +86,11 @@ Example:: os> source=accounts | stats count(), c(); fetched rows / total rows = 1/1 - +---------+ - | count() | - |---------| - | 4 | - +---------+ + +---------+-----+ + | count() | c() | + |---------+-----| + | 4 | 4 | + +---------+-----+ Example of filtered counting:: diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java index 5ed535b081c..681206c2590 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalcitePPLAggregationIT.java @@ -864,6 +864,25 @@ public void 
testCountEvalWithMultipleAggregations() throws IOException { verifyDataRows(actual, rows(6, 4, 7)); } + @Test + public void testShortcutCEvalSimpleCondition() throws IOException { + JSONObject actual = + executeQuery(String.format("source=%s | stats c(eval(age > 30)) as c", TEST_INDEX_BANK)); + verifySchema(actual, schema("c", "bigint")); + verifyDataRows(actual, rows(6)); + } + + @Test + public void testShortcutCEvalComplexCondition() throws IOException { + JSONObject actual = + executeQuery( + String.format( + "source=%s | stats c(eval(balance > 20000 and age < 35)) as c", TEST_INDEX_BANK)); + verifySchema(actual, schema("c", "bigint")); + verifyDataRows(actual, rows(3)); + } + + @Test public void testPercentileShortcuts() throws IOException { JSONObject actual = executeQuery( diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index e6fc38e4d7d..16de4c7eb9c 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -459,7 +459,7 @@ statsAggTerm // aggregation functions statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall - | COUNT LT_PRTHS evalExpression RT_PRTHS # countEvalFunctionCall + | (COUNT | C) LT_PRTHS evalExpression RT_PRTHS # countEvalFunctionCall | (COUNT | C) LT_PRTHS RT_PRTHS # countAllFunctionCall | PERCENTILE_SHORTCUT LT_PRTHS valueExpression RT_PRTHS # percentileShortcutFunctionCall | (DISTINCT_COUNT | DC | DISTINCT_COUNT_APPROX) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java index 1579408e171..91d9a8de09d 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLCountEvalTest.java @@ -82,4 +82,29 @@ public void 
testCountEvalWithNullHandling() { "SELECT COUNT(CASE WHEN `MGR` IS NOT NULL THEN 1 ELSE NULL END) `non_null_mgr`\n" + "FROM `scott`.`EMP`"); } + + @Test + public void testShortcutCEvalSimpleCondition() { + withPPLQuery("source=EMP | stats c(eval(SAL > 2000)) as c") + .expectLogical( + "LogicalAggregate(group=[{}], c=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(>($5, 2000), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("c=6\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `SAL` > 2000 THEN 1 ELSE NULL END) `c`\nFROM `scott`.`EMP`"); + } + + @Test + public void testShortcutCEvalComplexCondition() { + withPPLQuery("source=EMP | stats c(eval(JOB = 'MANAGER')) as manager_count") + .expectLogical( + "LogicalAggregate(group=[{}], manager_count=[COUNT($0)])\n" + + " LogicalProject($f1=[CASE(=($2, 'MANAGER':VARCHAR), 1, null:NULL)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n") + .expectResult("manager_count=3\n") + .expectSparkSQL( + "SELECT COUNT(CASE WHEN `JOB` = 'MANAGER' THEN 1 ELSE NULL END) `manager_count`\n" + + "FROM `scott`.`EMP`"); + } }