From 78fa7e4a5872b6d3e42aeecd95b1e88e92c338a4 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 13 Mar 2024 19:56:04 -0700 Subject: [PATCH] Support for post filter in hybrid query (#633) (#636) * Post Filter for hybrid query Signed-off-by: Varun Jain * Add changelog Signed-off-by: Varun Jain * Addressing martin comments Signed-off-by: Varun Jain * Addressing martin comments Signed-off-by: Varun Jain * Addressing navneet comments Signed-off-by: Varun Jain * Addressing navneet comments Signed-off-by: Varun Jain * Addressing navneet comments Signed-off-by: Varun Jain * Adding Coverage Signed-off-by: Varun Jain --------- Signed-off-by: Varun Jain (cherry picked from commit d2d4cc68422e37d65285579d5579517531aec813) Co-authored-by: Varun Jain --- CHANGELOG.md | 1 + .../search/query/HybridCollectorManager.java | 47 +++- .../query/HybridQueryAggregationsIT.java | 239 ++++++++++++++++++ .../query/HybridCollectorManagerTests.java | 111 ++++++++ 4 files changed, 390 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c9d7110f..f240141ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements - Adding aggregations in hybrid query ([#630](https://github.com/opensearch-project/neural-search/pull/630)) +- Support for post filter in hybrid query ([#633](https://github.com/opensearch-project/neural-search/pull/633)) ### Bug Fixes - Fix typo for sparse encoding processor factory([#600](https://github.com/opensearch-project/neural-search/pull/600)) - Add non-null check for queryBuilder in NeuralQueryEnricherProcessor ([#619](https://github.com/opensearch-project/neural-search/pull/619)) diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java index a5de898ab..e9d97c3b3 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -11,10 +11,16 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; +import org.opensearch.common.Nullable; +import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.neuralsearch.search.HitsThresholdChecker; import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; import org.opensearch.search.DocValueFormat; +import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.MultiCollectorWrapper; import org.opensearch.search.query.QuerySearchResult; @@ -46,6 +52,9 @@ public abstract class HybridCollectorManager implements CollectorManager collectors) { } } else if (collector instanceof HybridTopScoreDocCollector) { hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector); - } + } else if (collector instanceof FilteredCollector + && ((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector) { + hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) ((FilteredCollector) collector).getCollector()); + } } if (!hybridTopScoreDocCollectors.isEmpty()) { @@ -216,9 +245,10 @@ public HybridCollectorNonConcurrentManager( HitsThresholdChecker hitsThresholdChecker, boolean isSingleShard, int trackTotalHitsUpTo, - SortAndFormats sortAndFormats + SortAndFormats sortAndFormats, + Weight filteringWeight ) { - super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight); scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); } @@ -245,9 +275,10 @@ public HybridCollectorConcurrentSearchManager( HitsThresholdChecker hitsThresholdChecker, boolean isSingleShard, int trackTotalHitsUpTo, - SortAndFormats sortAndFormats + SortAndFormats sortAndFormats, + Weight filteringWeight ) { - super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight); } } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java index e51a4562d..4647ebf5f 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java @@ -7,6 +7,7 @@ import lombok.SneakyThrows; import org.junit.Before; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.TermQueryBuilder; @@ -46,6 +47,7 @@ public class HybridQueryAggregationsIT extends BaseNeuralSearchIT { "test-neural-aggs-pipeline-multi-doc-index-multiple-shards"; private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-neural-aggs-multi-doc-index-single-shard"; private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "everyone"; private static final String TEST_QUERY_TEXT5 = "welcome"; private static final String TEST_DOC_TEXT1 = "Hello world"; private static final String TEST_DOC_TEXT2 = "Hi to this place"; @@ -182,6 +184,204 @@ public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSu } } + @SneakyThrows + public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + testPostFilterWithSimpleHybridQuery(false, true); + testPostFilterWithComplexHybridQuery(false, true); + } + + @SneakyThrows + public void testPostFilterOnIndexWithMultipleShards_WhenConcurrentSearchEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + testPostFilterWithSimpleHybridQuery(false, true); + testPostFilterWithComplexHybridQuery(false, true); + } + + @SneakyThrows + private void testPostFilterWithSimpleHybridQuery(boolean isSingleShard, boolean hasPostFilterQuery) { + try { + if (isSingleShard) { + prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE); + } else { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + } + + HybridQueryBuilder simpleHybridQueryBuilder = createHybridQueryBuilder(false); + + QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(2000).lte(5000); + + Map searchResponseAsMap; + + if (isSingleShard && hasPostFilterQuery) { + searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, + simpleHybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + rangeFilterQuery + ); + + assertHitResultsFromQuery(1, searchResponseAsMap); + } else if (isSingleShard && !hasPostFilterQuery) { + searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, + simpleHybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null + ); + assertHitResultsFromQuery(2, searchResponseAsMap); + } else if (!isSingleShard && hasPostFilterQuery) { + searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + simpleHybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + rangeFilterQuery + ); + assertHitResultsFromQuery(2, searchResponseAsMap); + } else { + searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + simpleHybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null + ); + assertHitResultsFromQuery(3, searchResponseAsMap); + } + + // assert post-filter + List> hitsNestedList = getNestedHits(searchResponseAsMap); + + List docIndexes = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + assertNotNull(oneHit.get("_source")); + Map source = (Map) oneHit.get("_source"); + int docIndex = (int) source.get(INTEGER_FIELD_1); + docIndexes.add(docIndex); + } + if (isSingleShard && hasPostFilterQuery) { + assertEquals(0, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count()); + + } else if (isSingleShard && !hasPostFilterQuery) { + assertEquals(1, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count()); + + } else if (!isSingleShard && hasPostFilterQuery) { + assertEquals(0, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count()); + } else { + assertEquals(1, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count()); + } + } finally { + if (isSingleShard) { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } else { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + } + + @SneakyThrows + private void testPostFilterWithComplexHybridQuery(boolean isSingleShard, boolean hasPostFilterQuery) { + try { + if (isSingleShard) { + prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE); + } else { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + } + + HybridQueryBuilder complexHybridQueryBuilder = createHybridQueryBuilder(true); + + QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(2000).lte(5000); + + Map searchResponseAsMap; + + if (isSingleShard && hasPostFilterQuery) { + searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, + complexHybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + rangeFilterQuery + ); + + assertHitResultsFromQuery(1, searchResponseAsMap); + } else if (isSingleShard && !hasPostFilterQuery) { + searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, + complexHybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null + ); + assertHitResultsFromQuery(2, searchResponseAsMap); + } else if (!isSingleShard && hasPostFilterQuery) { + searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + complexHybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + rangeFilterQuery + ); + assertHitResultsFromQuery(4, searchResponseAsMap); + } else { + searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + complexHybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + null + ); + assertHitResultsFromQuery(3, searchResponseAsMap); + } + + // assert post-filter + List> hitsNestedList = getNestedHits(searchResponseAsMap); + + List docIndexes = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + assertNotNull(oneHit.get("_source")); + Map source = (Map) oneHit.get("_source"); + int docIndex = (int) source.get(INTEGER_FIELD_1); + docIndexes.add(docIndex); + } + if (isSingleShard && hasPostFilterQuery) { + assertEquals(0, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count()); + + } else if (isSingleShard && !hasPostFilterQuery) { + assertEquals(1, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count()); + + } else if (!isSingleShard && hasPostFilterQuery) { + assertEquals(0, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count()); + } else { + assertEquals(1, docIndexes.stream().filter(docIndex -> docIndex < 2000 || docIndex > 5000).count()); + } + } finally { + if (isSingleShard) { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } else { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + } + @SneakyThrows private void testAvgSumMinMaxAggs() { try { @@ -227,6 +427,20 @@ private void testAvgSumMinMaxAggs() { } } + @SneakyThrows + public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchNotEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + testPostFilterWithSimpleHybridQuery(true, true); + testPostFilterWithComplexHybridQuery(true, true); + } + + @SneakyThrows + public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + testPostFilterWithSimpleHybridQuery(true, true); + testPostFilterWithComplexHybridQuery(true, true); + } + private void testMaxAggsOnSingleShardCluster() throws Exception { try { prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE); @@ -594,4 +808,29 @@ private void assertHitResultsFromQuery(int expected, Map searchR assertNotNull(total.get("relation")); assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + + private HybridQueryBuilder createHybridQueryBuilder(boolean isComplex) { + if (isComplex) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should().add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + + QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1).gte(2000).lte(5000); + + QueryBuilder matchQuery = QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(boolQueryBuilder).add(rangeFilterQuery).add(matchQuery); + return hybridQueryBuilder; + + } else { + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder3); + return hybridQueryBuilderNeuralThenTerm; + } + } + } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index 65d6f3d8a..1fd67a7ae 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -22,14 +22,18 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.Query; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.common.lucene.search.FilteredCollector; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.BoostingQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.ParsedQuery; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.neuralsearch.query.HybridQueryWeight; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; @@ -120,6 +124,94 @@ public void testNewCollector_whenConcurrentSearch_thenSuccessful() { assertNotSame(collector, secondCollector); } + @SneakyThrows + public void testPostFilter_whenNotConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); + ParsedQuery parsedQuery = new ParsedQuery(postFilterQuery.toQuery(mockQueryShardContext)); + searchContext.parsedQuery(parsedQuery); + + Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); + when(searchContext.parsedPostFilter()).thenReturn(parsedQuery); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + when(indexSearcher.rewrite(pfQuery)).thenReturn(pfQuery); + Weight weight = mock(Weight.class); + when(indexSearcher.createWeight(pfQuery, ScoreMode.COMPLETE_NO_SCORES, 1f)).thenReturn(weight); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof FilteredCollector); + assertTrue(((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + assertTrue(((FilteredCollector) secondCollector).getCollector() instanceof HybridTopScoreDocCollector); + } + + @SneakyThrows + public void testPostFilter_whenConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, "world"); + Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); + ParsedQuery parsedQuery = new ParsedQuery(pfQuery); + searchContext.parsedQuery(parsedQuery); + + when(searchContext.parsedPostFilter()).thenReturn(parsedQuery); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + when(indexSearcher.rewrite(pfQuery)).thenReturn(pfQuery); + Weight weight = mock(Weight.class); + when(indexSearcher.createWeight(pfQuery, ScoreMode.COMPLETE_NO_SCORES, 1f)).thenReturn(weight); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorConcurrentSearchManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof FilteredCollector); + assertTrue(((FilteredCollector) collector).getCollector() instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertNotSame(collector, secondCollector); + assertTrue(((FilteredCollector) secondCollector).getCollector() instanceof HybridTopScoreDocCollector); + } + @SneakyThrows public void testReduce_whenMatchedDocs_thenSuccessful() { SearchContext searchContext = mock(SearchContext.class); @@ -166,17 +258,36 @@ public void testReduce_whenMatchedDocs_thenSuccessful() { CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + QueryBuilder postFilterQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + + Query pfQuery = postFilterQuery.toQuery(mockQueryShardContext); + ParsedQuery parsedQuery = new ParsedQuery(pfQuery); + searchContext.parsedQuery(parsedQuery); + when(searchContext.parsedPostFilter()).thenReturn(parsedQuery); + when(indexSearcher.rewrite(pfQuery)).thenReturn(pfQuery); + Weight postFilterWeight = mock(Weight.class); + when(indexSearcher.createWeight(pfQuery, ScoreMode.COMPLETE_NO_SCORES, 1f)).thenReturn(postFilterWeight); + + CollectorManager hybridCollectorManager1 = HybridCollectorManager.createHybridCollectorManager(searchContext); + FilteredCollector collector1 = (FilteredCollector) hybridCollectorManager1.newCollector(); + Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); collector.setWeight(weight); + collector1.setWeight(weight); LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext); + LeafCollector leafCollector1 = collector1.getLeafCollector(leafReaderContext); BulkScorer scorer = weight.bulkScorer(leafReaderContext); scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); leafCollector.finish(); + scorer.score(leafCollector1, leafReaderContext.reader().getLiveDocs()); + leafCollector1.finish(); Object results = hybridCollectorManager.reduce(List.of()); + Object results1 = hybridCollectorManager1.reduce(List.of()); assertNotNull(results); + assertNotNull(results1); ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); QuerySearchResult querySearchResult = new QuerySearchResult(); reduceableSearchResult.reduce(querySearchResult);