Skip to content

Commit

Permalink
Removing code to cut search results of hybrid search in the priority queue (#867)
Browse files Browse the repository at this point in the history

* Removing code to cut results in the priority queue

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Throw an exception when the `from` parameter is not equal to 0

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Adding TODO check

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Move the pagination validation of the `from` parameter to the query phase searcher

Signed-off-by: Varun Jain <varunudr@amazon.com>

* Adding integration test

Signed-off-by: Varun Jain <varunudr@amazon.com>

---------

Signed-off-by: Varun Jain <varunudr@amazon.com>
  • Loading branch information
vibrantvarun authored Sep 4, 2024
1 parent 3a6bdc7 commit b8e2b35
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,6 @@ private TopDocs topDocsPerQuery(int start, int howMany, PriorityQueue<ScoreDoc>

int size = howMany - start;
ScoreDoc[] results = new ScoreDoc[size];
// pq's pop() returns the 'least' element in the queue, therefore need
// to discard the first ones, until we reach the requested range.
for (int i = pq.size() - start - size; i > 0; i--) {
pq.pop();
}

// Get the requested results from pq.
populateResults(results, size, pq);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ public boolean searchWith(
validateQuery(searchContext, query);
return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
} else {
// TODO: remove this check once https://github.com/opensearch-project/neural-search/issues/280 is resolved.
if (searchContext.from() != 0) {
throw new IllegalArgumentException("In the current OpenSearch version pagination is not supported with hybrid query");
}
Query hybridQuery = extractHybridQuery(searchContext, query);
QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean
rangeFilterQuery,
null,
false,
null
null,
0
);

assertHitResultsFromQuery(1, searchResponseAsMap);
Expand All @@ -230,7 +231,8 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean
null,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(2, searchResponseAsMap);
} else if (!isSingleShard && hasPostFilterQuery) {
Expand All @@ -244,7 +246,8 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean
rangeFilterQuery,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(2, searchResponseAsMap);
} else {
Expand All @@ -258,7 +261,8 @@ private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean
null,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(3, searchResponseAsMap);
}
Expand Down Expand Up @@ -319,7 +323,8 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean
rangeFilterQuery,
null,
false,
null
null,
0
);

assertHitResultsFromQuery(1, searchResponseAsMap);
Expand All @@ -334,7 +339,8 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean
null,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(2, searchResponseAsMap);
} else if (!isSingleShard && hasPostFilterQuery) {
Expand All @@ -348,7 +354,8 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean
rangeFilterQuery,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(4, searchResponseAsMap);
} else {
Expand All @@ -362,7 +369,8 @@ private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean
null,
null,
false,
null
null,
0
);
assertHitResultsFromQuery(3, searchResponseAsMap);
}
Expand Down
40 changes: 40 additions & 0 deletions src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,46 @@ public void testConcurrentSearchWithMultipleSlices_whenMultipleShardsIndex_thenS
}
}

/**
 * Verifies that a hybrid query search request carrying a non-zero {@code from} value is rejected.
 * The request is expected to fail with an {@code illegal_argument_exception} whose message states
 * that pagination is not supported with hybrid query.
 *
 * TODO: remove this test once https://github.com/opensearch-project/neural-search/issues/280 is resolved.
 */
@SneakyThrows
public void testHybridQuery_whenFromIsSetInSearchRequest_thenFail() {
    try {
        initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD);
        createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);

        // Build a hybrid query that wraps a single match sub-query.
        MatchQueryBuilder textMatch = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
        HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
        hybridQuery.add(textMatch);

        // The final argument is the "from" offset; any non-zero value should make the request fail.
        ResponseException paginationFailure = expectThrows(
            ResponseException.class,
            () -> search(
                TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD,
                hybridQuery,
                null,
                10,
                Map.of("search_pipeline", SEARCH_PIPELINE),
                null,
                null,
                null,
                false,
                null,
                10
            )
        );

        // Both the human-readable reason and the exception type must appear in the response.
        String failureMessage = paginationFailure.getMessage();
        org.hamcrest.MatcherAssert.assertThat(
            failureMessage,
            allOf(
                containsString("In the current OpenSearch version pagination is not supported with hybrid query"),
                containsString("illegal_argument_exception")
            )
        );
    } finally {
        // Always clean up the index and search pipeline, even when an assertion fails.
        wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD, null, null, SEARCH_PIPELINE);
    }
}

@SneakyThrows
private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ private void testPostFilterRangeQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 1, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
}
Expand Down Expand Up @@ -262,7 +263,8 @@ private void testPostFilterBoolQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
// Case 2 A Query with a combination of hybrid query (Match Query, Term Query, Range Query), aggregation (Average stock price
Expand All @@ -278,7 +280,8 @@ private void testPostFilterBoolQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
Map<String, Object> aggregations = getAggregations(searchResponseAsMap);
Expand All @@ -303,7 +306,8 @@ private void testPostFilterBoolQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
// Case 4 A Query with a combination of hybrid query (Match Query, Range Query) and a post filter query (Bool Query with a should
Expand All @@ -324,7 +328,8 @@ private void testPostFilterBoolQuery(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
}
Expand Down Expand Up @@ -382,7 +387,8 @@ private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 4, 3, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);

Expand All @@ -399,7 +405,8 @@ private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) {
postFilterQuery,
null,
false,
null
null,
0
);
assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ private void testSingleFieldSort_whenMultipleSubQueries_thenSuccessful(String in
null,
createSortBuilders(fieldSortOrderMap, false),
false,
null
null,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6);
assertStockValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, LARGEST_STOCK_VALUE_IN_QUERY_RESULT, true, true);
Expand Down Expand Up @@ -168,7 +169,8 @@ private void testMultipleFieldSort_whenMultipleSubQueries_thenSuccessful(String
null,
createSortBuilders(fieldSortOrderMap, false),
false,
null
null,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6);
assertStockValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, LARGEST_STOCK_VALUE_IN_QUERY_RESULT, true, false);
Expand Down Expand Up @@ -200,7 +202,8 @@ public void testSingleFieldSort_whenTrackScoresIsEnabled_thenFail() {
null,
createSortBuilders(fieldSortOrderMap, false),
true,
null
null,
0
)
);
} finally {
Expand Down Expand Up @@ -234,7 +237,8 @@ public void testSingleFieldSort_whenSortCriteriaIsByScoreAndField_thenFail() {
null,
createSortBuilders(fieldSortOrderMap, false),
true,
null
null,
0
)
);
} finally {
Expand Down Expand Up @@ -312,7 +316,8 @@ private void testSearchAfter_whenSingleFieldSort_thenSuccessful(String indexName
null,
createSortBuilders(fieldSortOrderMap, false),
false,
searchAfter
searchAfter,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 3, 6);
assertStockValueWithSortOrderInHybridQueryResults(
Expand Down Expand Up @@ -348,7 +353,8 @@ private void testSearchAfter_whenMultipleFieldSort_thenSuccessful(String indexNa
null,
createSortBuilders(fieldSortOrderMap, false),
false,
searchAfter
searchAfter,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 5, 6);
assertStockValueWithSortOrderInHybridQueryResults(
Expand Down Expand Up @@ -381,7 +387,8 @@ private void testScoreSort_whenSingleFieldSort_thenSuccessful(String indexName)
null,
createSortBuilders(fieldSortOrderMap, false),
false,
null
null,
0
);
List<Map<String, Object>> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6);
assertScoreWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, 1.0);
Expand Down Expand Up @@ -415,7 +422,8 @@ public void testSort_whenSortFieldsSizeNotEqualToSearchAfterSize_thenFail() {
null,
createSortBuilders(fieldSortOrderMap, false),
true,
searchAfter
searchAfter,
0
)
);
} finally {
Expand Down Expand Up @@ -450,7 +458,8 @@ public void testSearchAfter_whenAfterFieldIsNotPassed_thenFail() {
null,
createSortBuilders(fieldSortOrderMap, false),
true,
searchAfter
searchAfter,
0
)
);
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,8 @@ private void testSumAggsAndRangePostFilter() throws IOException {
rangeFilterQuery,
null,
false,
null
null,
0
);

Map<String, Object> aggregations = getAggregations(searchResponseAsMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ protected Map<String, Object> search(
Map<String, String> requestParams,
List<Object> aggs
) {
return search(index, queryBuilder, rescorer, resultSize, requestParams, aggs, null, null, false, null);
return search(index, queryBuilder, rescorer, resultSize, requestParams, aggs, null, null, false, null, 0);
}

@SneakyThrows
Expand All @@ -542,10 +542,11 @@ protected Map<String, Object> search(
QueryBuilder postFilterBuilder,
List<SortBuilder<?>> sortBuilders,
boolean trackScores,
List<Object> searchAfter
List<Object> searchAfter,
int from
) {
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();

builder.field("from", from);
if (queryBuilder != null) {
builder.field("query");
queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand Down

0 comments on commit b8e2b35

Please sign in to comment.