diff --git a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java index 1c78df704c..2d98be6bbb 100644 --- a/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java +++ b/core/src/main/java/org/opensearch/sql/executor/ExecutionEngine.java @@ -54,7 +54,7 @@ void execute(PhysicalPlan plan, ExecutionContext context, class QueryResponse { private final Schema schema; private final List results; - + private final long total; private final Cursor cursor; } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java index 5890b6f15f..c52e658e41 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlan.java @@ -50,6 +50,10 @@ public ExecutionEngine.Schema schema() { + "ProjectOperator, instead of %s", this.getClass().getSimpleName())); } + public long getTotalHits() { + return getChild().stream().mapToLong(PhysicalPlan::getTotalHits).max().orElse(0); + } + public String toCursor() { throw new IllegalStateException(String.format("%s is not compatible with cursor feature", this.getClass().getSimpleName())); diff --git a/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java b/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java index a0be4f8f2e..69c819398c 100644 --- a/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/QueryServiceTest.java @@ -133,7 +133,7 @@ Helper executeSuccess(Split split) { invocation -> { ResponseListener listener = invocation.getArgument(2); listener.onResponse( - new ExecutionEngine.QueryResponse(schema, Collections.emptyList(), + new ExecutionEngine.QueryResponse(schema, Collections.emptyList(), 0, Cursor.None)); return null; }) diff --git a/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java b/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java index 75a6238530..f97f2b5f91 100644 --- a/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/streaming/MicroBatchStreamingExecutionTest.java @@ -170,7 +170,7 @@ Helper executeSuccess(Long... offsets) { ResponseListener listener = invocation.getArgument(2); listener.onResponse( - new ExecutionEngine.QueryResponse(null, Collections.emptyList(), + new ExecutionEngine.QueryResponse(null, Collections.emptyList(), 0, Cursor.None)); PlanContext planContext = invocation.getArgument(1); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java index 0a93c96bbb..c6759b7f7c 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTest.java @@ -5,9 +5,16 @@ package org.opensearch.sql.planner.physical; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.CALLS_REAL_METHODS; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.util.List; + +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -16,6 +23,7 @@ import org.opensearch.sql.storage.split.Split; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class PhysicalPlanTest { @Mock Split split; @@ -46,8 +54,25 @@ public List getChild() { }; @Test - void addSplitToChildByDefault() { + void add_split_to_child_by_default() { testPlan.add(split); verify(child).add(split); } + + @Test + void get_total_hits_from_child() { + var plan = mock(PhysicalPlan.class); + when(child.getTotalHits()).thenReturn(42L); + when(plan.getChild()).thenReturn(List.of(child)); + when(plan.getTotalHits()).then(CALLS_REAL_METHODS); + assertEquals(42, plan.getTotalHits()); + verify(child).getTotalHits(); + } + + @Test + void get_total_hits_uses_default_value() { + var plan = mock(PhysicalPlan.class); + when(plan.getTotalHits()).then(CALLS_REAL_METHODS); + assertEquals(0, plan.getTotalHits()); + } } diff --git a/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java b/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java index 1805830271..00e02eb433 100644 --- a/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java +++ b/core/src/testFixtures/java/org/opensearch/sql/executor/DefaultExecutionEngine.java @@ -34,7 +34,7 @@ public void execute( result.add(plan.next()); } QueryResponse response = new QueryResponse(new Schema(new ArrayList<>()), new ArrayList<>(), - Cursor.None); + 0, Cursor.None); listener.onResponse(response); } catch (Exception e) { listener.onFailure(e); diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java index bd9387a68e..4f9fdd9a53 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSQLQueryAction.java @@ -172,7 +172,8 @@ private ResponseListener createQueryResponseListener( @Override public void onResponse(QueryResponse response) { sendResponse(channel, OK, - formatter.format(new QueryResult(response.getSchema(), response.getResults(), response.getCursor()))); + formatter.format(new QueryResult(response.getSchema(), response.getResults(), + response.getCursor(), response.getTotal()))); } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java index cec2864c11..b1b32821f2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/OpenSearchExecutionEngine.java @@ -53,7 +53,7 @@ public void execute(PhysicalPlan physicalPlan, ExecutionContext context, Cursor qc = paginatedPlanCache.convertToCursor(plan); - QueryResponse response = new QueryResponse(physicalPlan.schema(), result, qc); + QueryResponse response = new QueryResponse(physicalPlan.schema(), result, plan.getTotalHits(), qc); listener.onResponse(response); } catch (Exception e) { listener.onFailure(e); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java index 307b40dce7..3d880d82b9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/ResourceMonitorPlan.java @@ -83,6 +83,11 @@ public ExprValue next() { return delegate.next(); } + @Override + public long getTotalHits() { + return delegate.getTotalHits(); + } + @Override public String toCursor() { return delegate.toCursor(); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java index c2042d7c0e..74aa07fccb 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/OpenSearchResponse.java @@ -71,7 +71,10 @@ public OpenSearchResponse(SearchHits hits, OpenSearchExprValueFactory exprValueF */ public boolean isEmpty() { return (hits.getHits() == null) || (hits.getHits().length == 0) && aggregations == null; - // TODO TBD ^ ^ + } + + public long getTotalHits() { + return hits.getTotalHits().value; } public boolean isAggregationResponse() { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java index 3ae2e62cfd..f9a420332d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScan.java @@ -84,6 +84,12 @@ public ExprValue next() { return iterator.next(); } + @Override + public long getTotalHits() { + // TODO maybe store totalHits from `response` + return queryCount; + } + private void fetchNextBatch() { OpenSearchResponse response = client.search(request); if (!response.isEmpty()) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java index 1dc455cfd2..6626af6e9e 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScan.java @@ -27,6 +27,7 @@ public class OpenSearchPagedIndexScan extends TableScanOperator { private OpenSearchRequest request; private Iterator iterator; private boolean needClean = false; + private long totalHits = 0; public OpenSearchPagedIndexScan(OpenSearchClient client, PagedRequestBuilder requestBuilder) { @@ -56,6 +57,7 @@ public void open() { OpenSearchResponse response = client.search(request); if (!response.isEmpty()) { iterator = response.iterator(); + totalHits = response.getTotalHits(); } else { needClean = true; iterator = Collections.emptyIterator(); @@ -71,6 +73,11 @@ public void close() { } } + @Override + public long getTotalHits() { + return totalHits; + } + @Override public String toCursor() { // TODO this assumes exactly one index is scanned. diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java index d4d987a7df..b111047b6f 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/ResourceMonitorPlanTest.java @@ -107,4 +107,10 @@ void acceptSuccess() { monitorPlan.accept(visitor, context); verify(plan, times(1)).accept(visitor, context); } + + @Test + void getTotalHitsSuccess() { + monitorPlan.getTotalHits(); + verify(plan, times(1)).getTotalHits(); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java index 0a60503415..2d1d6145f3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchResponseTest.java @@ -74,20 +74,29 @@ void isEmpty() { new TotalHits(2L, TotalHits.Relation.EQUAL_TO), 1.0F)); - assertFalse(new OpenSearchResponse(searchResponse, factory).isEmpty()); + var response = new OpenSearchResponse(searchResponse, factory); + assertFalse(response.isEmpty()); + assertEquals(2L, response.getTotalHits()); when(searchResponse.getHits()).thenReturn(SearchHits.empty()); when(searchResponse.getAggregations()).thenReturn(null); - assertTrue(new OpenSearchResponse(searchResponse, factory).isEmpty()); + + response = new OpenSearchResponse(searchResponse, factory); + assertTrue(response.isEmpty()); + assertEquals(0L, response.getTotalHits()); when(searchResponse.getHits()) .thenReturn(new SearchHits(null, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0)); - OpenSearchResponse response3 = new OpenSearchResponse(searchResponse, factory); - assertTrue(response3.isEmpty()); + response = new OpenSearchResponse(searchResponse, factory); + assertTrue(response.isEmpty()); + assertEquals(0L, response.getTotalHits()); when(searchResponse.getHits()).thenReturn(SearchHits.empty()); when(searchResponse.getAggregations()).thenReturn(new Aggregations(emptyList())); - assertFalse(new OpenSearchResponse(searchResponse, factory).isEmpty()); + + response = new OpenSearchResponse(searchResponse, factory); + assertFalse(response.isEmpty()); + assertEquals(0L, response.getTotalHits()); } @Test @@ -104,7 +113,8 @@ void iterator() { when(factory.construct(any())).thenReturn(exprTupleValue1).thenReturn(exprTupleValue2); int i = 0; - for (ExprValue hit : new OpenSearchResponse(searchResponse, factory)) { + var response = new OpenSearchResponse(searchResponse, factory); + for (ExprValue hit : response) { if (i == 0) { assertEquals(exprTupleValue1, hit); } else if (i == 1) { @@ -114,6 +124,7 @@ void iterator() { } i++; } + assertEquals(2L, response.getTotalHits()); } @Test diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java index 90ad624135..d93f4729b8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanTest.java @@ -6,10 +6,12 @@ package org.opensearch.sql.opensearch.storage.scan; +import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -22,6 +24,8 @@ import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; @@ -47,6 +51,7 @@ import org.opensearch.sql.opensearch.response.OpenSearchResponse; @ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class OpenSearchIndexScanTest { @Mock @@ -64,19 +69,22 @@ void setup() { } @Test - void queryEmptyResult() { + void query_empty_result() { mockResponse(client); try (OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, new OpenSearchRequestBuilder("test", 3, settings, exprValueFactory))) { indexScan.open(); - assertFalse(indexScan.hasNext()); + assertAll( + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(0, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void queryAllResultsWithQuery() { + void query_all_results_with_query() { mockResponse(client, new ExprValue[]{ employee(1, "John", "IT"), employee(2, "Smith", "HR"), @@ -89,22 +97,25 @@ void queryAllResultsWithQuery() { new OpenSearchIndexScan(client, builder)) { indexScan.open(); - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - assertFalse(indexScan.hasNext()); + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void queryAllResultsWithScroll() { + void query_all_results_with_scroll() { mockResponse(client, new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, new ExprValue[]{employee(3, "Allen", "IT")}); @@ -114,22 +125,25 @@ void queryAllResultsWithScroll() { exprValueFactory))) { indexScan.open(); - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - assertFalse(indexScan.hasNext()); + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void querySomeResultsWithQuery() { + void query_some_results_with_query() { mockResponse(client, new ExprValue[]{ employee(1, "John", "IT"), employee(2, "Smith", "HR"), @@ -142,22 +156,25 @@ void querySomeResultsWithQuery() { indexScan.getRequestBuilder().pushDownLimit(3, 0); indexScan.open(); - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - assertFalse(indexScan.hasNext()); + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void querySomeResultsWithScroll() { + void query_some_results_with_scroll() { mockResponse(client, new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); @@ -168,22 +185,25 @@ void querySomeResultsWithScroll() { indexScan.getRequestBuilder().pushDownLimit(3, 0); indexScan.open(); - assertTrue(indexScan.hasNext()); - assertEquals(employee(1, "John", "IT"), indexScan.next()); + assertAll( + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(1, "John", "IT"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(2, "Smith", "HR"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(2, "Smith", "HR"), indexScan.next()), - assertTrue(indexScan.hasNext()); - assertEquals(employee(3, "Allen", "IT"), indexScan.next()); + () -> assertTrue(indexScan.hasNext()), + () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - assertFalse(indexScan.hasNext()); + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) + ); } verify(client).cleanup(any()); } @Test - void pushDownFilters() { + void push_down_filters() { assertThat() .pushDown(QueryBuilders.termQuery("name", "John")) .shouldQuery(QueryBuilders.termQuery("name", "John")) @@ -201,7 +221,7 @@ void pushDownFilters() { } @Test - void pushDownHighlight() { + void push_down_highlight() { Map args = new HashMap<>(); assertThat() .pushDown(QueryBuilders.termQuery("name", "John")) @@ -212,7 +232,7 @@ void pushDownHighlight() { } @Test - void pushDownHighlightWithArguments() { + void push_down_highlight_with_arguments() { Map args = new HashMap<>(); args.put("pre_tags", new Literal("", DataType.STRING)); args.put("post_tags", new Literal("", DataType.STRING)); @@ -227,7 +247,7 @@ void pushDownHighlightWithArguments() { } @Test - void pushDownHighlightWithRepeatingFields() { + void push_down_highlight_with_repeating_fields() { mockResponse(client, new ExprValue[]{employee(1, "John", "IT"), employee(2, "Smith", "HR")}, new ExprValue[]{employee(3, "Allen", "IT"), employee(4, "Bob", "HR")}); @@ -314,6 +334,9 @@ public OpenSearchResponse answer(InvocationOnMock invocation) { when(response.isEmpty()).thenReturn(false); ExprValue[] searchHit = searchHitBatches[batchNum]; when(response.iterator()).thenReturn(Arrays.asList(searchHit).iterator()); + // used in OpenSearchPagedIndexScanTest + lenient().when(response.getTotalHits()) + .thenReturn((long) searchHitBatches[batchNum].length); } else { when(response.isEmpty()).thenReturn(true); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java index c13e63f01a..fc77ad4ff7 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchPagedIndexScanTest.java @@ -82,7 +82,8 @@ void query_all_results_initial_scroll_request() { () -> assertTrue(indexScan.hasNext()), () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - () -> assertFalse(indexScan.hasNext()) + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) ); } // cleanup should be called on empty response only @@ -120,7 +121,8 @@ void query_all_results_continuation_scroll_request() { () -> assertTrue(indexScan.hasNext()), () -> assertEquals(employee(3, "Allen", "IT"), indexScan.next()), - () -> assertFalse(indexScan.hasNext()) + () -> assertFalse(indexScan.hasNext()), + () -> assertEquals(3, indexScan.getTotalHits()) ); } // cleanup should be called on empty response only diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java index d9dad9d535..a67e077ecc 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryAction.java @@ -140,7 +140,7 @@ private ResponseListener createListener( public void onResponse(ExecutionEngine.QueryResponse response) { String responseContent = formatter.format(new QueryResult(response.getSchema(), response.getResults(), - response.getCursor())); + response.getCursor(), response.getTotal())); listener.onResponse(new TransportPPLQueryResponse(responseContent)); } diff --git a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java index 3ea5846b87..90f422c81d 100644 --- a/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java +++ b/protocol/src/main/java/org/opensearch/sql/protocol/response/QueryResult.java @@ -36,9 +36,12 @@ public class QueryResult implements Iterable { @Getter private final Cursor cursor; + @Getter + private final long total; + public QueryResult(ExecutionEngine.Schema schema, Collection exprValues) { - this(schema, exprValues, Cursor.None); + this(schema, exprValues, Cursor.None, exprValues.size()); } /** diff --git a/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java b/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java index f52ee22246..eb4de495c1 100644 --- a/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java +++ b/protocol/src/main/java/org/opensearch/sql/protocol/response/format/JdbcResponseFormatter.java @@ -40,7 +40,7 @@ protected Object buildJsonObject(QueryResult response) { json.datarows(fetchDataRows(response)); // Populate other fields - json.total(response.size()) + json.total(response.getTotal()) .size(response.size()) .status(200); if (!response.getCursor().equals(Cursor.None)) { diff --git a/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java b/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java index 3db405339b..2bfbe55278 100644 --- a/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java +++ b/protocol/src/test/java/org/opensearch/sql/protocol/response/QueryResultTest.java @@ -36,7 +36,7 @@ void size() { tupleValue(ImmutableMap.of("name", "John", "age", 20)), tupleValue(ImmutableMap.of("name", "Allen", "age", 30)), tupleValue(ImmutableMap.of("name", "Smith", "age", 40)) - ), Cursor.None); + ), Cursor.None, 0); assertEquals(3, response.size()); } @@ -46,7 +46,7 @@ void columnNameTypes() { schema, Collections.singletonList( tupleValue(ImmutableMap.of("name", "John", "age", 20)) - ), Cursor.None); + ), Cursor.None, 0); assertEquals( ImmutableMap.of("name", "string", "age", "integer"), @@ -61,7 +61,7 @@ void columnNameTypesWithAlias() { QueryResult response = new QueryResult( schema, Collections.singletonList(tupleValue(ImmutableMap.of("n", "John"))), - Cursor.None); + Cursor.None, 0); assertEquals( ImmutableMap.of("n", "string"), @@ -73,7 +73,7 @@ void columnNameTypesWithAlias() { void columnNameTypesFromEmptyExprValues() { QueryResult response = new QueryResult( schema, - Collections.emptyList(), Cursor.None); + Collections.emptyList(), Cursor.None, 0); assertEquals( ImmutableMap.of("name", "string", "age", "integer"), response.columnNameTypes() @@ -102,7 +102,7 @@ void iterate() { Arrays.asList( tupleValue(ImmutableMap.of("name", "John", "age", 20)), tupleValue(ImmutableMap.of("name", "Allen", "age", 30)) - ), Cursor.None); + ), Cursor.None, 0); int i = 0; for (Object[] objects : response) {