Skip to content

Commit

Permalink
Addressed review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Jul 22, 2023
1 parent 77d6eec commit 08fc6d5
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 104 deletions.
17 changes: 2 additions & 15 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
Expand Down Expand Up @@ -98,22 +97,10 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) {
log.info(
String.format(
Locale.ROOT,
"Registering hybrid query phase searcher with feature flag [%s]",
NEURAL_SEARCH_HYBRID_SEARCH_ENABLED
)
);
log.info("Registering hybrid query phase searcher with feature flag [%]", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED);
return Optional.of(new HybridQueryPhaseSearcher());
}
log.info(
String.format(
Locale.ROOT,
"Not registering hybrid query phase searcher because feature flag [%s] is disabled",
NEURAL_SEARCH_HYBRID_SEARCH_ENABLED
)
);
log.info("Not registering hybrid query phase searcher because feature flag [%] is disabled", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED);
// we want feature be disabled by default due to risk of colliding and breaking concurrent search in core
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

package org.opensearch.neuralsearch.processor;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -55,7 +55,6 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
Expand Down Expand Up @@ -121,15 +120,10 @@ private boolean isNotHybridQuery(final Optional<SearchPhaseResult> maybeResult)
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQuerySearchResults(final SearchPhaseResults<Result> results) {
List<Result> resultsPerShard = results.getAtomicArray().asList();
List<QuerySearchResult> querySearchResults = new ArrayList<>();
for (Result shardResult : resultsPerShard) {
if (shardResult == null) {
querySearchResults.add(null);
continue;
}
querySearchResults.add(shardResult.queryResult());
}
return querySearchResults;
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.AllArgsConstructor;

import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
Expand All @@ -26,17 +25,11 @@
* Class abstracts steps required for score normalization and combination, this includes pre-processing of income data
* and post-processing for final results
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor
public class NormalizationProcessorWorkflow {

/**
* Return instance of workflow class. Making default constructor private for now
* as we may use singleton pattern here and share one instance among processors
* @return instance of NormalizationProcessorWorkflow
*/
public static NormalizationProcessorWorkflow create() {
return new NormalizationProcessorWorkflow();
}
private final ScoreNormalizer scoreNormalizer;
private final ScoreCombiner scoreCombiner;

/**
* Start execution of this workflow
Expand All @@ -53,19 +46,18 @@ public void execute(
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

// normalize
ScoreNormalizer scoreNormalizer = new ScoreNormalizer();
scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);

// combine
ScoreCombiner scoreCombiner = new ScoreCombiner();
List<Float> combinedMaxScores = scoreCombiner.combineScores(queryTopDocs, combinationTechnique);

// post-process data
updateOriginalQueryResults(querySearchResults, queryTopDocs, combinedMaxScores);
}

private List<CompoundTopDocs> getQueryTopDocs(List<QuerySearchResult> querySearchResults) {
private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
.filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
.filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs)
.map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs)
.collect(Collectors.toList());
Expand All @@ -74,34 +66,20 @@ private List<CompoundTopDocs> getQueryTopDocs(List<QuerySearchResult> querySearc

@VisibleForTesting
protected void updateOriginalQueryResults(
List<QuerySearchResult> querySearchResults,
final List<QuerySearchResult> querySearchResults,
final List<CompoundTopDocs> queryTopDocs,
List<Float> combinedMaxScores
final List<Float> combinedMaxScores
) {
TopDocsAndMaxScore[] topDocsAndMaxScores = new TopDocsAndMaxScore[querySearchResults.size()];
for (int idx = 0; idx < querySearchResults.size(); idx++) {
QuerySearchResult querySearchResult = querySearchResults.get(idx);
TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
if (!(topDocsAndMaxScore.topDocs instanceof CompoundTopDocs)) {
continue;
}
topDocsAndMaxScores[idx] = topDocsAndMaxScore;
}
for (int i = 0; i < querySearchResults.size(); i++) {
QuerySearchResult querySearchResult = querySearchResults.get(i);
CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
if (Objects.isNull(updatedTopDocs)) {
if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) {
continue;
}
CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f;
TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
if (querySearchResult == null) {
continue;
}
querySearchResult.topDocs(topDocsAndMaxScore, null);
if (topDocsAndMaxScores[i] != null) {
topDocsAndMaxScores[i].maxScore = combinedMaxScores.get(i);
}
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
querySearchResults.get(i).topDocs().maxScore = combinedMaxScores.get(i);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.neuralsearch.processor.combination;

import java.util.Objects;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;

Expand All @@ -16,12 +14,9 @@
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class ArithmeticMeanScoreCombinationMethod implements ScoreCombinationMethod {

private static ArithmeticMeanScoreCombinationMethod INSTANCE = new ArithmeticMeanScoreCombinationMethod();
private static final ArithmeticMeanScoreCombinationMethod INSTANCE = new ArithmeticMeanScoreCombinationMethod();

public static ArithmeticMeanScoreCombinationMethod getInstance() {
if (Objects.isNull(INSTANCE)) {
INSTANCE = new ArithmeticMeanScoreCombinationMethod();
}
return INSTANCE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ public List<Float> combineScores(final List<CompoundTopDocs> queryTopDocs, final
.collect(Collectors.toList());
}

private float combineShardScores(ScoreCombinationTechnique scoreCombinationTechnique, CompoundTopDocs compoundQueryTopDocs) {
private float combineShardScores(
final ScoreCombinationTechnique scoreCombinationTechnique,
final CompoundTopDocs compoundQueryTopDocs
) {
if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.totalHits.value == 0) {
return ZERO_SCORE;
}
Expand All @@ -66,7 +69,7 @@ private float combineShardScores(ScoreCombinationTechnique scoreCombinationTechn

// - sort documents by scores and take first "max number" of docs
// create a priority queue of doc ids that are sorted by their combined scores
PriorityQueue<Integer> scoreQueue = getPriorityQueueOfDocIds(normalizedScoresPerDoc, combinedNormalizedScoresByDocId);
PriorityQueue<Integer> scoreQueue = getPriorityQueueOfDocIds(combinedNormalizedScoresByDocId);
// store max score to resulting list, call it now as priority queue will change after combining scores
float maxScore = combinedNormalizedScoresByDocId.get(scoreQueue.peek());

Expand All @@ -75,15 +78,12 @@ private float combineShardScores(ScoreCombinationTechnique scoreCombinationTechn
return maxScore;
}

private PriorityQueue<Integer> getPriorityQueueOfDocIds(
Map<Integer, float[]> normalizedScoresPerDoc,
Map<Integer, Float> combinedNormalizedScoresByDocId
) {
private PriorityQueue<Integer> getPriorityQueueOfDocIds(final Map<Integer, Float> combinedNormalizedScoresByDocId) {
PriorityQueue<Integer> pq = new PriorityQueue<>(
(a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a))
);
// we're merging docs with normalized and combined scores. we need to have only maxHits results
pq.addAll(normalizedScoresPerDoc.keySet());
pq.addAll(combinedNormalizedScoresByDocId.keySet());
return pq;
}

Expand All @@ -96,32 +96,32 @@ private ScoreDoc[] getCombinedScoreDocs(
ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits];

int shardId = compoundQueryTopDocs.scoreDocs[0].shardIndex;
for (int j = 0; j < maxHits; j++) {
for (int j = 0; j < maxHits && !scoreQueue.isEmpty(); j++) {
int docId = scoreQueue.poll();
finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId);
}
return finalScoreDocs;
}

private Map<Integer, float[]> getNormalizedScoresPerDocument(List<TopDocs> topDocsPerSubQuery) {
private Map<Integer, float[]> getNormalizedScoresPerDocument(final List<TopDocs> topDocsPerSubQuery) {
Map<Integer, float[]> normalizedScoresPerDoc = new HashMap<>();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs topDocs = topDocsPerSubQuery.get(j);
for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
normalizedScoresPerDoc.putIfAbsent(scoreDoc.doc, normalizedScoresPerDoc.computeIfAbsent(scoreDoc.doc, key -> {
normalizedScoresPerDoc.computeIfAbsent(scoreDoc.doc, key -> {
float[] scores = new float[topDocsPerSubQuery.size()];
// we initialize with -1.0, as after normalization it's possible that score is 0.0
Arrays.fill(scores, -1.0f);
return scores;
}));
});
normalizedScoresPerDoc.get(scoreDoc.doc)[j] = scoreDoc.score;
}
}
return normalizedScoresPerDoc;
}

private Map<Integer, Float> combineScoresAndGetCombinedNormilizedScoresPerDocument(
Map<Integer, float[]> normalizedScoresPerDocument,
final Map<Integer, float[]> normalizedScoresPerDocument,
final ScoreCombinationTechnique scoreCombinationTechnique
) {
return normalizedScoresPerDocument.entrySet()
Expand All @@ -130,10 +130,10 @@ private Map<Integer, Float> combineScoresAndGetCombinedNormilizedScoresPerDocume
}

private void updateQueryTopDocsWithCombinedScores(
CompoundTopDocs compoundQueryTopDocs,
List<TopDocs> topDocsPerSubQuery,
Map<Integer, Float> combinedNormalizedScoresByDocId,
PriorityQueue<Integer> scoreQueue
final CompoundTopDocs compoundQueryTopDocs,
final List<TopDocs> topDocsPerSubQuery,
final Map<Integer, Float> combinedNormalizedScoresByDocId,
final PriorityQueue<Integer> scoreQueue
) {
// - count max number of hits among sub-queries
int maxHits = getMaxHits(topDocsPerSubQuery);
Expand All @@ -142,7 +142,7 @@ private void updateQueryTopDocsWithCombinedScores(
compoundQueryTopDocs.totalHits = getTotalHits(topDocsPerSubQuery, maxHits);
}

private int getMaxHits(List<TopDocs> topDocsPerSubQuery) {
private int getMaxHits(final List<TopDocs> topDocsPerSubQuery) {
int maxHits = 0;
for (TopDocs topDocs : topDocsPerSubQuery) {
int hits = topDocs.scoreDocs.length;
Expand All @@ -151,7 +151,7 @@ private int getMaxHits(List<TopDocs> topDocsPerSubQuery) {
return maxHits;
}

private TotalHits getTotalHits(List<TopDocs> topDocsPerSubQuery, int maxHits) {
private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, int maxHits) {
TotalHits.Relation totalHits = TotalHits.Relation.EQUAL_TO;
if (topDocsPerSubQuery.stream().anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) {
totalHits = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.neuralsearch.processor.factory;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap;

import java.util.Map;
Expand All @@ -15,14 +16,20 @@
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

/**
* Factory for query results normalization processor for search pipeline. Instantiates processor based on user provided input.
*/
public class NormalizationProcessorFactory implements Processor.Factory<SearchPhaseResultsProcessor> {
private final NormalizationProcessorWorkflow normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(
new ScoreNormalizer(),
new ScoreCombiner()
);

@Override
public SearchPhaseResultsProcessor create(
Expand Down Expand Up @@ -53,29 +60,49 @@ public SearchPhaseResultsProcessor create(
? ScoreCombinationTechnique.DEFAULT.name()
: (String) combinationClause.getOrDefault(NormalizationProcessor.TECHNIQUE, "");

validateParameters(normalizationTechnique, combinationTechnique);
validateParameters(normalizationTechnique, combinationTechnique, tag);

return new NormalizationProcessor(
tag,
description,
ScoreNormalizationTechnique.valueOf(normalizationTechnique),
ScoreCombinationTechnique.valueOf(combinationTechnique),
NormalizationProcessorWorkflow.create()
normalizationProcessorWorkflow
);
}

protected void validateParameters(final String normalizationTechniqueName, final String combinationTechniqueName) {
protected void validateParameters(final String normalizationTechniqueName, final String combinationTechniqueName, final String tag) {
if (StringUtils.isEmpty(normalizationTechniqueName)) {
throw new IllegalArgumentException("normalization technique cannot be empty");
throw newConfigurationException(
NormalizationProcessor.TYPE,
tag,
NormalizationProcessor.TECHNIQUE,
"normalization technique cannot be empty"
);
}
if (StringUtils.isEmpty(combinationTechniqueName)) {
throw new IllegalArgumentException("combination technique cannot be empty");
throw newConfigurationException(
NormalizationProcessor.TYPE,
tag,
NormalizationProcessor.TECHNIQUE,
"combination technique cannot be empty"
);
}
if (!EnumUtils.isValidEnum(ScoreNormalizationTechnique.class, normalizationTechniqueName)) {
throw new IllegalArgumentException("provided normalization technique is not supported");
throw newConfigurationException(
NormalizationProcessor.TYPE,
tag,
NormalizationProcessor.TECHNIQUE,
"provided normalization technique is not supported"
);
}
if (!EnumUtils.isValidEnum(ScoreCombinationTechnique.class, combinationTechniqueName)) {
throw new IllegalArgumentException("provided combination technique is not supported");
throw newConfigurationException(
NormalizationProcessor.TYPE,
tag,
NormalizationProcessor.TECHNIQUE,
"provided combination technique is not supported"
);
}
}
}
Loading

0 comments on commit 08fc6d5

Please sign in to comment.