Skip to content

Commit

Permalink
Initial version for rescorer
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Sep 30, 2024
1 parent d03e69b commit fcf7f51
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 41 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.17...2.x)
### Features
### Enhancements
- Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.opensearch.index.query.MatchQueryBuilder;
import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER;
import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS;
Expand All @@ -17,6 +19,8 @@
import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD;
import static org.opensearch.neuralsearch.util.TestUtils.getModelId;

import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
Expand All @@ -31,6 +35,8 @@ public class HybridSearchIT extends AbstractRollingUpgradeTestCase {
private static final String TEXT_UPGRADED = "Hi earth";
private static final String QUERY = "Hi world";
private static final int NUM_DOCS_PER_ROUND = 1;
private static final String VECTOR_EMBEDDING_FIELD = "passage_embedding";
protected static final String RESCORE_QUERY = "hi";
private static String modelId = "";

// Test rolling-upgrade normalization processor when index with multiple shards
Expand Down Expand Up @@ -62,12 +68,13 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
if (isFirstMixedRound()) {
totalDocsCountMixed = NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, rescorer);
addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null);
} else {
totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND;
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder);
validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, hybridQueryBuilder, null);
}
break;
case UPGRADED:
Expand All @@ -77,9 +84,10 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
loadModel(modelId);
addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null);
HybridQueryBuilder hybridQueryBuilder = getQueryBuilder(modelId, null, null);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
QueryBuilder rescorer = QueryBuilders.matchQuery(TEST_FIELD, RESCORE_QUERY).boost(0.3f);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
hybridQueryBuilder = getQueryBuilder(modelId, Map.of("ef_search", 100), RescoreContext.getDefault());
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder);
validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, hybridQueryBuilder, rescorer);
} finally {
wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, SEARCH_PIPELINE_NAME);
}
Expand All @@ -89,15 +97,19 @@ public void testNormalizationProcessor_whenIndexWithMultipleShards_E2EFlow() thr
}
}

private void validateTestIndexOnUpgrade(final int numberOfDocs, final String modelId, HybridQueryBuilder hybridQueryBuilder)
throws Exception {
private void validateTestIndexOnUpgrade(
final int numberOfDocs,
final String modelId,
HybridQueryBuilder hybridQueryBuilder,
QueryBuilder rescorer
) throws Exception {
int docCount = getDocCount(getIndexNameForTest());
assertEquals(numberOfDocs, docCount);
loadModel(modelId);
Map<String, Object> searchResponseAsMap = search(
getIndexNameForTest(),
hybridQueryBuilder,
null,
rescorer,
1,
Map.of("search_pipeline", SEARCH_PIPELINE_NAME)
);
Expand All @@ -113,18 +125,18 @@ private void validateTestIndexOnUpgrade(final int numberOfDocs, final String mod
private HybridQueryBuilder getQueryBuilder(
final String modelId,
final Map<String, ?> methodParameters,
final RescoreContext rescoreContext
final RescoreContext rescoreContextForNeuralQuery
) {
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder();
neuralQueryBuilder.fieldName("passage_embedding");
neuralQueryBuilder.fieldName(VECTOR_EMBEDDING_FIELD);
neuralQueryBuilder.modelId(modelId);
neuralQueryBuilder.queryText(QUERY);
neuralQueryBuilder.k(5);
if (methodParameters != null) {
neuralQueryBuilder.methodParameters(methodParameters);
}
if (rescoreContext != null) {
neuralQueryBuilder.rescoreContext(rescoreContext);
if (Objects.nonNull(rescoreContextForNeuralQuery)) {
neuralQueryBuilder.rescoreContext(rescoreContextForNeuralQuery);
}

MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", QUERY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.util.Locale;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
Expand All @@ -18,6 +19,7 @@
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.FieldDoc;
import org.opensearch.OpenSearchException;
import org.opensearch.common.Nullable;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
Expand All @@ -33,6 +35,7 @@
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;

import java.io.IOException;
Expand All @@ -55,6 +58,7 @@
* In most cases it will be wrapped in MultiCollectorManager.
*/
@RequiredArgsConstructor
@Log4j2
public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {

private final int numHits;
Expand All @@ -67,6 +71,7 @@ public abstract class HybridCollectorManager implements CollectorManager<Collect
private final TopDocsMerger topDocsMerger;
@Nullable
private final FieldDoc after;
private final SearchContext searchContext;

/**
* Create new instance of HybridCollectorManager depending on the concurrent search beeing enabled or disabled.
Expand Down Expand Up @@ -101,17 +106,15 @@ public static CollectorManager createHybridCollectorManager(final SearchContext
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight,
searchContext.searchAfter()
searchContext
)
: new HybridCollectorNonConcurrentManager(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
trackTotalHitsUpTo,
searchContext.sort(),
filteringWeight,
searchContext.searchAfter()
searchContext
);
}

Expand Down Expand Up @@ -161,28 +164,83 @@ private List<ReduceableSearchResult> getSearchResults(final List<HybridSearchCol
List<ReduceableSearchResult> results = new ArrayList<>();
DocValueFormat[] docValueFormats = getSortValueFormats(sortAndFormats);
for (HybridSearchCollector collector : hybridSearchCollectors) {
TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndAndMaxScore(collector, docValueFormats);
boolean isSortEnabled = docValueFormats != null;
TopDocsAndMaxScore topDocsAndMaxScore = getTopDocsAndAndMaxScore(collector, isSortEnabled);
results.add((QuerySearchResult result) -> reduceCollectorResults(result, topDocsAndMaxScore, docValueFormats));
}
return results;
}

private TopDocsAndMaxScore getTopDocsAndAndMaxScore(
final HybridSearchCollector hybridSearchCollector,
final DocValueFormat[] docValueFormats
) {
TopDocs newTopDocs;
private TopDocsAndMaxScore getTopDocsAndAndMaxScore(final HybridSearchCollector hybridSearchCollector, final boolean isSortEnabled) {
List topDocs = hybridSearchCollector.topDocs();
if (docValueFormats != null) {
newTopDocs = getNewTopFieldDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()),
topDocs,
sortAndFormats.sort.getSort()
);
} else {
newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()), topDocs);
if (isSortEnabled) {
return getSortedTopDocsAndMaxScore(topDocs, hybridSearchCollector);
}
return getTopDocsAndMaxScore(topDocs, hybridSearchCollector);
}

private TopDocsAndMaxScore getSortedTopDocsAndMaxScore(List<TopFieldDocs> topDocs, HybridSearchCollector hybridSearchCollector) {
TopDocs sortedTopDocs = getNewTopFieldDocs(
getTotalHits(this.trackTotalHitsUpTo, topDocs, hybridSearchCollector.getTotalHits()),
topDocs,
sortAndFormats.sort.getSort()
);
return new TopDocsAndMaxScore(sortedTopDocs, hybridSearchCollector.getMaxScore());
}

private TopDocsAndMaxScore getTopDocsAndMaxScore(List<TopDocs> topDocs, HybridSearchCollector hybridSearchCollector) {
List<TopDocs> rescoredTopDocs = rescore(topDocs);
float maxScore = calculateMaxScore(rescoredTopDocs, hybridSearchCollector.getMaxScore());
TopDocs finalTopDocs = getNewTopDocs(
getTotalHits(this.trackTotalHitsUpTo, rescoredTopDocs, hybridSearchCollector.getTotalHits()),
rescoredTopDocs
);
return new TopDocsAndMaxScore(finalTopDocs, maxScore);
}

private List<TopDocs> rescore(List<TopDocs> topDocs) {
List<RescoreContext> rescoreContexts = searchContext.rescore();
boolean shouldRescore = Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty();
if (!shouldRescore) {
return topDocs;
}
List<TopDocs> rescoredTopDocs = topDocs;
for (RescoreContext ctx : rescoreContexts) {
rescoredTopDocs = rescoredTopDocs(ctx, rescoredTopDocs);
}
return new TopDocsAndMaxScore(newTopDocs, hybridSearchCollector.getMaxScore());
return rescoredTopDocs;
}

/**
* Rescores the top documents using the provided context. The input topDocs may be modified during this process.
*/
private List<TopDocs> rescoredTopDocs(final RescoreContext ctx, final List<TopDocs> topDocs) {
List<TopDocs> result = new ArrayList<>(topDocs.size());
for (TopDocs topDoc : topDocs) {
try {
result.add(ctx.rescorer().rescore(topDoc, searchContext.searcher(), ctx));
} catch (IOException exception) {
log.error("rescore failed for hybrid query", exception);
throw new OpenSearchException("rescore failed", exception);
}
}
return result;
}

/**
* Calculates the maximum score from the provided TopDocs, considering rescoring.
*/
private float calculateMaxScore(List<TopDocs> topDocsList, float initialMaxScore) {
List<RescoreContext> rescoreContexts = searchContext.rescore();
if (Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty()) {
for (TopDocs topDocs : topDocsList) {
if (Objects.nonNull(topDocs.scoreDocs) && topDocs.scoreDocs.length > 0) {
// first top doc for each sub-query has the max score because top docs are sorted by score desc
initialMaxScore = Math.max(initialMaxScore, topDocs.scoreDocs[0].score);
}
}
}
return initialMaxScore;
}

private List<HybridSearchCollector> getHybridSearchCollectors(final Collection<Collector> collectors) {
Expand Down Expand Up @@ -415,18 +473,18 @@ public HybridCollectorNonConcurrentManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight,
ScoreDoc searchAfter
SearchContext searchContext
) {
super(
numHits,
hitsThresholdChecker,
trackTotalHitsUpTo,
sortAndFormats,
searchContext.sort(),
filteringWeight,
new TopDocsMerger(sortAndFormats),
(FieldDoc) searchAfter
new TopDocsMerger(searchContext.sort()),
searchContext.searchAfter(),
searchContext
);
scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
}
Expand All @@ -453,18 +511,18 @@ public HybridCollectorConcurrentSearchManager(
int numHits,
HitsThresholdChecker hitsThresholdChecker,
int trackTotalHitsUpTo,
SortAndFormats sortAndFormats,
Weight filteringWeight,
ScoreDoc searchAfter
SearchContext searchContext
) {
super(
numHits,
hitsThresholdChecker,
trackTotalHitsUpTo,
sortAndFormats,
searchContext.sort(),
filteringWeight,
new TopDocsMerger(sortAndFormats),
(FieldDoc) searchAfter
new TopDocsMerger(searchContext.sort()),
searchContext.searchAfter(),
searchContext
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ public boolean searchWith(
}
Query hybridQuery = extractHybridQuery(searchContext, query);
QueryPhaseSearcher queryPhaseSearcher = getQueryPhaseSearcher(searchContext);
return queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
queryPhaseSearcher.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout);
// we decide on rescore later in collector manager
return false;
}
}

Expand Down
Loading

0 comments on commit fcf7f51

Please sign in to comment.