diff --git a/build.gradle b/build.gradle index 6e4b4ada4..541d70354 100644 --- a/build.gradle +++ b/build.gradle @@ -253,8 +253,12 @@ testClusters.integTest { // Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due // to ml-commons memory circuit breaker exception jvmArgs("-Xms1g", "-Xmx1g") - // enable hybrid search for testing + + // enable features for testing + // hybrid search systemProperty('neural_search_hybrid_search_enabled', 'true') + // search pipelines + systemProperty('opensearch.experimental.feature.search_pipeline.enabled', 'true') } // Remote Integration Tests diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 83f8c396b..b656234f0 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -13,6 +13,8 @@ import java.util.Optional; import java.util.function.Supplier; +import lombok.extern.log4j.Log4j2; + import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -24,8 +26,15 @@ import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; @@ -33,9 +42,11 @@ import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.IngestPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.script.ScriptService; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; @@ -45,7 +56,8 @@ /** * Neural Search plugin class */ -public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin { +@Log4j2 +public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin { /** * Gates the functionality of hybrid search * Currently query phase searcher added with hybrid search will conflict with concurrent search in core. 
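Illustrative usage (not part of the patch): the normalization-processor registered below is attached to a search request through a search pipeline. This is a minimal sketch assembled from the pipeline JSON and fixtures in the test helpers later in this patch; the pipeline id, index name, and technique names are taken from those tests, the hybrid query body is elided, and the exact request shape is an assumption. It also presumes the experimental search-pipeline setting enabled above.

PUT /_search/pipeline/phase-results-pipeline
{
  "description": "Post processor pipeline",
  "phase_results_processors": [
    {
      "normalization-processor": {
        "normalization": { "technique": "min_max" },
        "combination": { "technique": "arithmetic_mean" }
      }
    }
  ]
}

GET /test-neural-multi-doc-one-shard-index/_search?search_pipeline=phase-results-pipeline
{
  "query": { "hybrid": { "queries": [ ... ] } }
}

The search_pipeline request parameter references a named search pipeline; the BaseNeuralSearchIT.search() helper below can pass it through its requestParams argument.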
@@ -54,6 +66,9 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, @VisibleForTesting public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled"; private MLCommonsClientAccessor clientAccessor; + private NormalizationProcessorWorkflow normalizationProcessorWorkflow; + private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); @Override public Collection createComponents( @@ -70,6 +85,7 @@ public Collection createComponents( final Supplier repositoriesServiceSupplier ) { NeuralQueryBuilder.initialize(clientAccessor); + normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()); return List.of(clientAccessor); } @@ -90,9 +106,21 @@ public Map getProcessors(Processor.Parameters paramet @Override public Optional getQueryPhaseSearcher() { if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) { + log.info("Registering hybrid query phase searcher with feature flag [{}]", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED); return Optional.of(new HybridQueryPhaseSearcher()); } + log.info("Not registering hybrid query phase searcher because feature flag [{}] is disabled", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED); // we want the feature to be disabled by default due to the risk of colliding with and breaking concurrent search in core return Optional.empty(); } + + @Override + public Map> getSearchPhaseResultsProcessors( + Parameters parameters + ) { + return Map.of( + NormalizationProcessor.TYPE, + new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory) + ); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java new file mode 100644 index 000000000..879b53703 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -0,0 +1,121 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.search.query.QuerySearchResult; + +/** + * Processor for score normalization and combination on post query search results.
Updates query results with + * normalized and combined scores for next phase (typically it's FETCH) + */ +@Log4j2 +@AllArgsConstructor +public class NormalizationProcessor implements SearchPhaseResultsProcessor { + public static final String TYPE = "normalization-processor"; + + private final String tag; + private final String description; + private final ScoreNormalizationTechnique normalizationTechnique; + private final ScoreCombinationTechnique combinationTechnique; + private final NormalizationProcessorWorkflow normalizationWorkflow; + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ) { + if (shouldSkipProcessor(searchPhaseResult)) { + return; + } + List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); + normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique); + } + + @Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.QUERY; + } + + @Override + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.FETCH; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public boolean isIgnoreFailure() { + return false; + } + + private boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { + if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer)) { + return true; + } + + QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult; + Optional optionalSearchPhaseResult = queryPhaseResultConsumer.getAtomicArray() + .asList() + .stream() + .filter(Objects::nonNull) + .findFirst(); + return isNotHybridQuery(optionalSearchPhaseResult); + } + + private boolean isNotHybridQuery(final Optional optionalSearchPhaseResult) { + return optionalSearchPhaseResult.isEmpty() + || Objects.isNull(optionalSearchPhaseResult.get().queryResult()) + || Objects.isNull(optionalSearchPhaseResult.get().queryResult().topDocs()) + || !(optionalSearchPhaseResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs); + } + + private List getQueryPhaseSearchResults( + final SearchPhaseResults results + ) { + return results.getAtomicArray() + .asList() + .stream() + .map(result -> result == null ?
null : result.queryResult()) + .collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java new file mode 100644 index 000000000..d56e47801 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import lombok.AllArgsConstructor; + +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.query.QuerySearchResult; + +/** + * Class abstracts steps required for score normalization and combination, this includes pre-processing of incoming data + * and post-processing of final results + */ +@AllArgsConstructor +public class NormalizationProcessorWorkflow { + + private final ScoreNormalizer scoreNormalizer; + private final ScoreCombiner scoreCombiner; + + /** + * Start execution of this workflow + * @param querySearchResults input data with QuerySearchResult from multiple shards + * @param normalizationTechnique technique for score normalization + * @param combinationTechnique technique for score combination + */ + public void execute( + final List querySearchResults, + final ScoreNormalizationTechnique normalizationTechnique, + final ScoreCombinationTechnique combinationTechnique + ) { + // pre-process data + List queryTopDocs = getQueryTopDocs(querySearchResults); + + // normalize + scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique); + + // combine + scoreCombiner.combineScores(queryTopDocs, combinationTechnique); + + // post-process data + updateOriginalQueryResults(querySearchResults, queryTopDocs); + } + + /** + * Getting list of CompoundTopDocs from list of QuerySearchResult. Each CompoundTopDocs is for individual shard + * @param querySearchResults collection of QuerySearchResult for all shards + * @return collection of CompoundTopDocs, one object for each shard + */ + private List getQueryTopDocs(final List querySearchResults) { + List queryTopDocs = querySearchResults.stream() + .filter(searchResult -> Objects.nonNull(searchResult.topDocs())) + .filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs) + .map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs) + .collect(Collectors.toList()); + return queryTopDocs; + } + + private void updateOriginalQueryResults(final List querySearchResults, final List queryTopDocs) { + for (int i = 0; i < querySearchResults.size(); i++) { + QuerySearchResult querySearchResult = querySearchResults.get(i); + if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) { + continue; + } + CompoundTopDocs updatedTopDocs = queryTopDocs.get(i); + float maxScore = updatedTopDocs.totalHits.value > 0 ? 
updatedTopDocs.scoreDocs[0].score : 0.0f; + TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore); + querySearchResult.topDocs(updatedTopDocsAndMaxScore, null); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java new file mode 100644 index 000000000..10c5533f5 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import lombok.NoArgsConstructor; + +/** + * Abstracts combination of scores based on arithmetic mean method + */ +@NoArgsConstructor +public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique { + + public static final String TECHNIQUE_NAME = "arithmetic_mean"; + private static final Float ZERO_SCORE = 0.0f; + + /** + * Arithmetic mean method for combining scores. + * cscore = (score1 + score2 +...+ scoreN)/N + * + * Scores that are missing for a sub-query (marked with a negative sentinel value) are excluded from the number of scores N + */ + @Override + public float combine(final float[] scores) { + float combinedScore = 0.0f; + int count = 0; + for (float score : scores) { + if (score >= 0.0) { + combinedScore += score; + count++; + } + } + if (count == 0) { + return ZERO_SCORE; + } + return combinedScore / count; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java new file mode 100644 index 000000000..bf55a8cc5 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.Map; +import java.util.Optional; + +import org.opensearch.OpenSearchParseException; + +/** + * Abstracts creation of exact score combination method based on technique name + */ +public class ScoreCombinationFactory { + + public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(); + + private final Map scoreCombinationMethodsMap = Map.of( + ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, + new ArithmeticMeanScoreCombinationTechnique() + ); + + /** + * Get score combination method by technique name + * @param technique name of technique + * @return instance of ScoreCombinationTechnique for technique name + */ + public ScoreCombinationTechnique createCombination(final String technique) { + return Optional.ofNullable(scoreCombinationMethodsMap.get(technique)) + .orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported")); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java new file mode 100644 index 000000000..21090b1ce --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationTechnique.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */
+package org.opensearch.neuralsearch.processor.combination; + +public interface ScoreCombinationTechnique { + /** + * Defines combination function specific to this technique + * @param scores array of collected original scores + * @return combined score + */ + float combine(final float[] scores); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java new file mode 100644 index 000000000..67e776d77 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import lombok.extern.log4j.Log4j2; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +/** + * Abstracts combination of scores in query search results. + */ +@Log4j2 +public class ScoreCombiner { + + private static final Float ZERO_SCORE = 0.0f; + + /** + * Performs score combination based on input combination technique. Mutates input object by updating combined scores + * Main steps we're doing for combination: + * - create map of normalized scores per doc id + * - using normalized scores create another map of combined scores per doc id + * - count max number of hits among sub-queries + * - sort documents by scores and take first "max number" of docs + * - update query search results with normalized scores + * Different score combination techniques are different in step 2, where we create map of "doc id" - "combined score", + * other steps are same for all techniques. + * @param queryTopDocs query results that need to be normalized, mutated by method execution + * @param scoreCombinationTechnique exact combination method that should be applied + */ + public void combineScores(final List queryTopDocs, final ScoreCombinationTechnique scoreCombinationTechnique) { + // iterate over results from each shard. 
Every CompoundTopDocs object has results from + // multiple sub queries, doc ids may repeat for each sub query results + queryTopDocs.forEach(compoundQueryTopDocs -> combineShardScores(scoreCombinationTechnique, compoundQueryTopDocs)); + } + + private void combineShardScores(final ScoreCombinationTechnique scoreCombinationTechnique, final CompoundTopDocs compoundQueryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.totalHits.value == 0) { + return; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + // - create map of normalized scores results returned from the single shard + Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(topDocsPerSubQuery); + + // - create map of combined scores per doc id + Map combinedNormalizedScoresByDocId = combineScoresAndGetCombinedNormalizedScoresPerDocument( + normalizedScoresPerDoc, + scoreCombinationTechnique + ); + + // - sort documents by scores and take first "max number" of docs + // create a collection of doc ids that are sorted by their combined scores + List sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId); + + // - update query search results with normalized scores + updateQueryTopDocsWithCombinedScores(compoundQueryTopDocs, topDocsPerSubQuery, combinedNormalizedScoresByDocId, sortedDocsIds); + } + + private List getSortedDocIds(final Map combinedNormalizedScoresByDocId) { + // we're merging docs with normalized and combined scores. we need to have only maxHits results + List sortedDocsIds = new ArrayList<>(combinedNormalizedScoresByDocId.keySet()); + sortedDocsIds.sort((a, b) -> Float.compare(combinedNormalizedScoresByDocId.get(b), combinedNormalizedScoresByDocId.get(a))); + return sortedDocsIds; + } + + private ScoreDoc[] getCombinedScoreDocs( + final CompoundTopDocs compoundQueryTopDocs, + final Map combinedNormalizedScoresByDocId, + final List sortedScores, + final int maxHits + ) { + ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits]; + + int shardId = compoundQueryTopDocs.scoreDocs[0].shardIndex; + for (int j = 0; j < maxHits && j < sortedScores.size(); j++) { + int docId = sortedScores.get(j); + finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId); + } + return finalScoreDocs; + } + + public Map getNormalizedScoresPerDocument(final List topDocsPerSubQuery) { + Map normalizedScoresPerDoc = new HashMap<>(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs topDocs = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc : topDocs.scoreDocs) { + normalizedScoresPerDoc.computeIfAbsent(scoreDoc.doc, key -> { + float[] scores = new float[topDocsPerSubQuery.size()]; + // we initialize with -1.0, as after normalization it's possible that score is 0.0 + Arrays.fill(scores, -1.0f); + return scores; + }); + normalizedScoresPerDoc.get(scoreDoc.doc)[j] = scoreDoc.score; + } + } + return normalizedScoresPerDoc; + } + + private Map combineScoresAndGetCombinedNormalizedScoresPerDocument( + final Map normalizedScoresPerDocument, + final ScoreCombinationTechnique scoreCombinationTechnique + ) { + return normalizedScoresPerDocument.entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue()))); + } + + private void updateQueryTopDocsWithCombinedScores( + final CompoundTopDocs compoundQueryTopDocs, + final List topDocsPerSubQuery, + final Map combinedNormalizedScoresByDocId, + final List sortedScores + ) { + // - count max number of hits among sub-queries + int 
maxHits = getMaxHits(topDocsPerSubQuery); + // - update query search results with normalized scores + compoundQueryTopDocs.scoreDocs = getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits); + compoundQueryTopDocs.totalHits = getTotalHits(topDocsPerSubQuery, maxHits); + } + + protected int getMaxHits(final List topDocsPerSubQuery) { + int maxHits = 0; + for (TopDocs topDocs : topDocsPerSubQuery) { + int hits = topDocs.scoreDocs.length; + maxHits = Math.max(maxHits, hits); + } + return maxHits; + } + + private TotalHits getTotalHits(final List topDocsPerSubQuery, int maxHits) { + TotalHits.Relation totalHits = TotalHits.Relation.EQUAL_TO; + if (topDocsPerSubQuery.stream().anyMatch(topDocs -> topDocs.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO)) { + totalHits = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + } + return new TotalHits(maxHits, totalHits); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java new file mode 100644 index 000000000..f31a5c6bc --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap; + +import java.util.Map; +import java.util.Objects; + +import lombok.AllArgsConstructor; + +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +/** + * Factory for query results normalization processor for search pipeline. Instantiates processor based on user provided input. 
+ */ +@AllArgsConstructor +public class NormalizationProcessorFactory implements Processor.Factory { + public static final String NORMALIZATION_CLAUSE = "normalization"; + public static final String COMBINATION_CLAUSE = "combination"; + public static final String TECHNIQUE = "technique"; + + private final NormalizationProcessorWorkflow normalizationProcessorWorkflow; + private ScoreNormalizationFactory scoreNormalizationFactory; + private ScoreCombinationFactory scoreCombinationFactory; + + @Override + public SearchPhaseResultsProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) throws Exception { + Map normalizationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, NORMALIZATION_CLAUSE); + ScoreNormalizationTechnique normalizationTechnique = ScoreNormalizationFactory.DEFAULT_METHOD; + if (Objects.nonNull(normalizationClause)) { + String normalizationTechniqueName = (String) normalizationClause.getOrDefault(TECHNIQUE, ""); + normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName); + } + + Map combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE); + + ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD; + if (Objects.nonNull(combinationClause)) { + String combinationTechnique = (String) combinationClause.getOrDefault(TECHNIQUE, ""); + scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique); + } + + return new NormalizationProcessor( + tag, + description, + normalizationTechnique, + scoreCombinationTechnique, + normalizationProcessorWorkflow + ); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java new file mode 100644 index 000000000..7d5317c14 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -0,0 +1,114 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +import com.google.common.primitives.Floats; + +/** + * Abstracts normalization of scores based on min-max method + */ +public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique { + + protected static final String TECHNIQUE_NAME = "min_max"; + private static final float MIN_SCORE = 0.001f; + private static final float SINGLE_RESULT_SCORE = 1.0f; + + /** + * Min-max normalization method. 
+ * nscore = (score - min_score)/(max_score - min_score) + * Main algorithm steps: + * - calculate min and max scores for each sub query + * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query + */ + @Override + public void normalize(final List queryTopDocs) { + int numOfSubqueries = queryTopDocs.stream() + .filter(Objects::nonNull) + .filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0) + .findAny() + .get() + .getCompoundTopDocs() + .size(); + // get min scores for each sub query + float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); + + // get max scores for each sub query + float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries); + + // do normalization using actual score and min and max scores for corresponding sub query + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { + scoreDoc.score = normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j]); + } + } + } + } + + private float[] getMaxScores(final List queryTopDocs, final int numOfSubqueries) { + float[] maxScores = new float[numOfSubqueries]; + Arrays.fill(maxScores, Float.MIN_VALUE); + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + maxScores[j] = Math.max( + maxScores[j], + Arrays.stream(topDocsPerSubQuery.get(j).scoreDocs) + .map(scoreDoc -> scoreDoc.score) + .max(Float::compare) + .orElse(Float.MIN_VALUE) + ); + } + } + return maxScores; + } + + private float[] getMinScores(final List queryTopDocs, final int numOfScores) { + float[] minScores = new float[numOfScores]; + Arrays.fill(minScores, Float.MAX_VALUE); + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + minScores[j] = Math.min( + minScores[j], + Arrays.stream(topDocsPerSubQuery.get(j).scoreDocs) + .map(scoreDoc -> scoreDoc.score) + .min(Float::compare) + .orElse(Float.MAX_VALUE) + ); + } + } + return minScores; + } + + private float normalizeSingleScore(final float score, final float minScore, final float maxScore) { + // edge case when there is only one score and min and max scores are same + if (Floats.compare(maxScore, minScore) == 0 && Floats.compare(maxScore, score) == 0) { + return SINGLE_RESULT_SCORE; + } + float normalizedScore = (score - minScore) / (maxScore - minScore); + return normalizedScore == 0.0f ? 
MIN_SCORE : normalizedScore; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java new file mode 100644 index 000000000..17bf8cb23 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.Map; +import java.util.Optional; + +import org.opensearch.OpenSearchParseException; + +/** + * Abstracts creation of exact score normalization method based on technique name + */ +public class ScoreNormalizationFactory { + + public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(); + + private final Map scoreNormalizationMethodsMap = Map.of( + MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, + new MinMaxScoreNormalizationTechnique() + ); + + /** + * Get score normalization method by technique name + * @param technique name of technique + * @return instance of ScoreNormalizationMethod for technique name + */ + public ScoreNormalizationTechnique createNormalization(final String technique) { + return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique)) + .orElseThrow(() -> new OpenSearchParseException("provided normalization technique is not supported")); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java new file mode 100644 index 000000000..fdaeb85d8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.List; + +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +/** + * Abstracts normalization of scores in query search results. + */ +public interface ScoreNormalizationTechnique { + + /** + * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores. + * @param queryTopDocs original query results from multiple shards and multiple sub-queries + */ + void normalize(final List queryTopDocs); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java new file mode 100644 index 000000000..5b8b7b1ca --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.normalization; + +import java.util.List; +import java.util.Objects; + +import org.opensearch.neuralsearch.search.CompoundTopDocs; + +public class ScoreNormalizer { + + /** + * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores. 
@param queryTopDocs original query results from multiple shards and multiple sub-queries + * @param scoreNormalizationTechnique exact normalization technique that should be applied + */ + public void normalizeScores(final List queryTopDocs, final ScoreNormalizationTechnique scoreNormalizationTechnique) { + if (canQueryResultsBeNormalized(queryTopDocs)) { + scoreNormalizationTechnique.normalize(queryTopDocs); + } + } + + private boolean canQueryResultsBeNormalized(final List queryTopDocs) { + return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getCompoundTopDocs().size() > 0); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java b/src/test/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java index 5a4f6d381..07f9e53d7 100644 --- a/src/test/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java @@ -36,8 +36,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.rest.OpenSearchRestTestCase; @@ -150,7 +150,7 @@ protected boolean preserveIndicesUponCompletion() { @After public void deleteExternalIndices() throws IOException { final Response response = client().performRequest(new Request("GET", "/_cat/indices?format=json" + "&expand_wildcards=all")); - final XContentType xContentType = XContentType.fromMediaType(response.getEntity().getContentType()); + final MediaType xContentType = MediaType.fromMediaType(response.getEntity().getContentType()); try ( final XContentParser parser = xContentType.xContent() .createParser( diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index 257545132..97a289bcd 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -5,13 +5,18 @@ package org.opensearch.neuralsearch; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import static org.opensearch.test.OpenSearchTestCase.randomFloat; +import java.util.Arrays; +import java.util.List; import java.util.Map; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.query.QuerySearchResult; public class TestUtils { @@ -51,4 +56,42 @@ public static float[] createRandomVector(int dimension) { } return vector; } + + /** + * Assert results of hybrid query after score normalization and combination + * @param querySearchResults collection of query search results after they are processed by the normalization processor + */ + public static void assertQueryResultScores(List querySearchResults) { + assertNotNull(querySearchResults); + float maxScore = querySearchResults.stream() + .map(searchResult -> searchResult.topDocs().maxScore) + .max(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(1.0f, maxScore, 0.0f); + float totalMaxScore = querySearchResults.stream() + .map(searchResult -> searchResult.getMaxScore()) +
.max(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(1.0f, totalMaxScore, 0.0f); + float maxScoreScoreFromScoreDocs = querySearchResults.stream() + .map( + searchResult -> Arrays.stream(searchResult.topDocs().topDocs.scoreDocs) + .map(scoreDoc -> scoreDoc.score) + .max(Float::compare) + .orElse(Float.MAX_VALUE) + ) + .max(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(1.0f, maxScoreScoreFromScoreDocs, 0.0f); + float minScoreScoreFromScoreDocs = querySearchResults.stream() + .map( + searchResult -> Arrays.stream(searchResult.topDocs().topDocs.scoreDocs) + .map(scoreDoc -> scoreDoc.score) + .min(Float::compare) + .orElse(Float.MAX_VALUE) + ) + .min(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(0.001f, minScoreScoreFromScoreDocs, 0.0f); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index e1f907a5c..bf56ab92d 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -57,6 +57,8 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; private static final String DEFAULT_USER_AGENT = "Kibana"; + private static final String NORMALIZATION_METHOD = "min_max"; + private static final String COMBINATION_METHOD = "arithmetic_mean"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); @@ -281,6 +283,27 @@ protected Map search(String index, QueryBuilder queryBuilder, in */ @SneakyThrows protected Map search(String index, QueryBuilder queryBuilder, QueryBuilder rescorer, int resultSize) { + return search(index, queryBuilder, rescorer, resultSize, Map.of()); + } + + /** + * Execute a search request initialized from a neural query builder that can add a rescore query to the request + * + * @param index Index to search against + * @param queryBuilder queryBuilder to produce source of query + * @param rescorer used for rescorer query builder + * @param resultSize number of results to return in the search + * @param requestParams additional request params for search + * @return Search results represented as a map + */ + @SneakyThrows + protected Map search( + String index, + QueryBuilder queryBuilder, + QueryBuilder rescorer, + int resultSize, + Map requestParams + ) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("query"); queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -294,6 +317,9 @@ protected Map search(String index, QueryBuilder queryBuilder, Qu Request request = new Request("POST", "/" + index + "/_search"); request.addParameter("size", Integer.toString(resultSize)); + if (requestParams != null && !requestParams.isEmpty()) { + requestParams.forEach(request::addParameter); + } request.setJsonEntity(Strings.toString(builder)); Response response = client().performRequest(request); @@ -386,7 +412,12 @@ protected int getHitCount(Map searchResponseAsMap) { */ @SneakyThrows protected void prepareKnnIndex(String indexName, List knnFieldConfigs) { - createIndexWithConfiguration(indexName, buildIndexConfiguration(knnFieldConfigs), ""); + prepareKnnIndex(indexName, knnFieldConfigs, 3); + } + + @SneakyThrows + protected void prepareKnnIndex(String indexName, List knnFieldConfigs, int numOfShards) { + createIndexWithConfiguration(indexName, buildIndexConfiguration(knnFieldConfigs, 
numOfShards), ""); } /** @@ -425,11 +456,11 @@ protected boolean checkComplete(Map node) { } @SneakyThrows - private String buildIndexConfiguration(List knnFieldConfigs) { + private String buildIndexConfiguration(List knnFieldConfigs, int numberOfShards) { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .startObject("settings") - .field("number_of_shards", 3) + .field("number_of_shards", numberOfShards) .field("index.knn", true) .endObject() .startObject("mappings") @@ -524,4 +555,73 @@ protected void deleteModel(String modelId) { public boolean isUpdateClusterSettings() { return true; } + + @SneakyThrows + protected void createSearchPipelineWithResultsPostProcessor(final String pipelineId) { + createSearchPipeline(pipelineId, NORMALIZATION_METHOD, COMBINATION_METHOD, Map.of()); + } + + @SneakyThrows + protected void createSearchPipeline( + final String pipelineId, + final String normalizationMethod, + String combinationMethod, + final Map combinationParams + ) { + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", pipelineId), + null, + toHttpEntity( + String.format( + LOCALE, + "{\"description\": \"Post processor pipeline\"," + + "\"phase_results_processors\": [{ " + + "\"normalization-processor\": {" + + "\"normalization\": {" + + "\"technique\": \"%s\"" + + "}," + + "\"combination\": {" + + "\"technique\": \"%s\"" + + "}" + + "}}]}", + normalizationMethod, + combinationMethod + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + + @SneakyThrows + protected void createSearchPipelineWithDefaultResultsPostProcessor(final String pipelineId) { + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", pipelineId), + null, + toHttpEntity( + String.format( + LOCALE, + "{\"description\": \"Post processor pipeline\"," + + "\"phase_results_processors\": [{ " + + "\"normalization-processor\": {}}]}" + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + + @SneakyThrows + protected void deleteSearchPipeline(final String pipelineId) { + makeRequest( + client(), + "DELETE", + String.format(LOCALE, "/_search/pipeline/%s", pipelineId), + null, + toHttpEntity(""), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index c4b1d49f7..7918126c5 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -12,12 +12,16 @@ import java.util.Optional; import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; +import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QueryPhaseSearcher; public class NeuralSearchTests extends 
OpenSearchQueryTestCase { @@ -55,4 +59,18 @@ public void testProcessors() { assertNotNull(processors); assertNotNull(processors.get(TextEmbeddingProcessor.TYPE)); } + + public void testSearchPhaseResultsProcessors() { + NeuralSearch plugin = new NeuralSearch(); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map> searchPhaseResultsProcessors = plugin + .getSearchPhaseResultsProcessors(parameters); + assertNotNull(searchPhaseResultsProcessors); + assertEquals(1, searchPhaseResultsProcessors.size()); + assertTrue(searchPhaseResultsProcessors.containsKey("normalization-processor")); + org.opensearch.search.pipeline.Processor.Factory scoringProcessor = searchPhaseResultsProcessors.get( + NormalizationProcessor.TYPE + ); + assertTrue(scoringProcessor instanceof NormalizationProcessorFactory); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java new file mode 100644 index 000000000..5c9d4b2b7 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -0,0 +1,252 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.After; +import org.junit.Before; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseController; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchProgressListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.breaker.CircuitBreaker; +import org.opensearch.common.breaker.NoopCircuitBreaker; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.BigArrays; +import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.util.concurrent.OpenSearchThreadPoolExecutor; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.pipeline.PipelineAggregator; +import org.opensearch.search.query.QuerySearchResult; +import 
org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; + +public class NormalizationProcessorTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + private static final String INDEX_NAME = "index1"; + private static final String NORMALIZATION_METHOD = "min_max"; + private static final String COMBINATION_METHOD = "arithmetic_mean"; + private SearchPhaseController searchPhaseController; + private ThreadPool threadPool; + private OpenSearchThreadPoolExecutor executor; + + @Before + public void setup() { + searchPhaseController = new SearchPhaseController(writableRegistry(), s -> new InternalAggregation.ReduceContextBuilder() { + @Override + public InternalAggregation.ReduceContext forPartialReduction() { + return InternalAggregation.ReduceContext.forPartialReduction( + BigArrays.NON_RECYCLING_INSTANCE, + null, + () -> PipelineAggregator.PipelineTree.EMPTY + ); + } + + public InternalAggregation.ReduceContext forFinalReduction() { + return InternalAggregation.ReduceContext.forFinalReduction( + BigArrays.NON_RECYCLING_INSTANCE, + null, + b -> {}, + PipelineAggregator.PipelineTree.EMPTY + ); + }; + }); + threadPool = new TestThreadPool(NormalizationProcessorTests.class.getName()); + executor = OpenSearchExecutors.newFixed( + "test", + 1, + 10, + OpenSearchExecutors.daemonThreadFactory("test"), + threadPool.getThreadContext() + ); + } + + @After + public void cleanup() { + executor.shutdownNow(); + terminate(threadPool); + } + + public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(COMBINATION_METHOD), + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + assertEquals(DESCRIPTION, normalizationProcessor.getDescription()); + assertEquals(PROCESSOR_TAG, normalizationProcessor.getTag()); + assertEquals(SearchPhaseName.FETCH, normalizationProcessor.getAfterPhase()); + assertEquals(SearchPhaseName.QUERY, normalizationProcessor.getBeforePhase()); + assertFalse(normalizationProcessor.isIgnoreFailure()); + } + + public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombination() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(COMBINATION_METHOD), + normalizationProcessorWorkflow + ); + + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.setBatchedReduceSize(4); + AtomicReference onPartialMergeFailure = new AtomicReference<>(); + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( + searchRequest, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + searchPhaseController, + SearchProgressListener.NOOP, + writableRegistry(), + 10, + e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + curr.addSuppressed(prev); + return curr; + }) + ); + CountDownLatch partialReduceLatch = new CountDownLatch(5); + for 
(int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + CompoundTopDocs topDocs = new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } + ) + ) + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + + queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown); + } + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); + + List querySearchResults = queryPhaseResultConsumer.getAtomicArray() + .asList() + .stream() + .map(result -> result == null ? null : result.queryResult()) + .collect(Collectors.toList()); + + TestUtils.assertQueryResultScores(querySearchResults); + } + + public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkflow() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME), + normalizationProcessorWorkflow + ); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + normalizationProcessor.process(null, searchPhaseContext); + + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any()); + } + + public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(COMBINATION_METHOD), + normalizationProcessorWorkflow + ); + + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.setBatchedReduceSize(4); + AtomicReference onPartialMergeFailure = new AtomicReference<>(); + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( + searchRequest, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + searchPhaseController, + SearchProgressListener.NOOP, + writableRegistry(), + 10, + e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + curr.addSuppressed(prev); + return curr; + }) + ); + CountDownLatch partialReduceLatch = new CountDownLatch(5); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(4, 
TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + + queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown); + } + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); + + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java new file mode 100644 index 000000000..453725a0d --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.Mockito.spy; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +public class NormalizationProcessorWorkflowTests extends OpenSearchTestCase { + + public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombination() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + for (int shardId = 0; shardId < 4; shardId++) { + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + CompoundTopDocs topDocs = new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } + ) + ) + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + querySearchResults.add(querySearchResult); + } + + normalizationProcessorWorkflow.execute( + querySearchResults, + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD 
+ ); + + TestUtils.assertQueryResultScores(querySearchResults); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..1a7f895cd --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -0,0 +1,86 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.test.OpenSearchTestCase; + +public class ScoreCombinationTechniqueTests extends OpenSearchTestCase { + + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { + ScoreCombiner scoreCombiner = new ScoreCombiner(); + scoreCombiner.combineScores(List.of(), ScoreCombinationFactory.DEFAULT_METHOD); + } + + public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenScoresCombined() { + ScoreCombiner scoreCombiner = new ScoreCombiner(); + + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 1.0f), new ScoreDoc(2, .25f), new ScoreDoc(4, 0.001f) } + ), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(5, 0.001f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.9f), new ScoreDoc(4, 0.6f), new ScoreDoc(7, 0.5f), new ScoreDoc(9, 0.01f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) + ) + ) + ); + + scoreCombiner.combineScores(queryTopDocs, ScoreCombinationFactory.DEFAULT_METHOD); + + assertNotNull(queryTopDocs); + assertEquals(3, queryTopDocs.size()); + + assertEquals(3, queryTopDocs.get(0).scoreDocs.length); + assertEquals(1.0, queryTopDocs.get(0).scoreDocs[0].score, 0.001f); + assertEquals(1, queryTopDocs.get(0).scoreDocs[0].doc); + assertEquals(1.0, queryTopDocs.get(0).scoreDocs[1].score, 0.001f); + assertEquals(3, queryTopDocs.get(0).scoreDocs[1].doc); + assertEquals(0.25, queryTopDocs.get(0).scoreDocs[2].score, 0.001f); + assertEquals(2, queryTopDocs.get(0).scoreDocs[2].doc); + + assertEquals(4, queryTopDocs.get(1).scoreDocs.length); + assertEquals(0.9, queryTopDocs.get(1).scoreDocs[0].score, 0.001f); + assertEquals(2, queryTopDocs.get(1).scoreDocs[0].doc); + assertEquals(0.6, queryTopDocs.get(1).scoreDocs[1].score, 0.001f); + assertEquals(4, queryTopDocs.get(1).scoreDocs[1].doc); + assertEquals(0.5, queryTopDocs.get(1).scoreDocs[2].score, 0.001f); + assertEquals(7, queryTopDocs.get(1).scoreDocs[2].doc); + assertEquals(0.01, 
queryTopDocs.get(1).scoreDocs[3].score, 0.001f); + assertEquals(9, queryTopDocs.get(1).scoreDocs[3].doc); + + assertEquals(0, queryTopDocs.get(2).scoreDocs.length); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java new file mode 100644 index 000000000..d72ee6f7f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationCombinationIT.java @@ -0,0 +1,405 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.TestUtils.createRandomVector; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.IntStream; + +import lombok.SneakyThrows; + +import org.apache.commons.lang3.Range; +import org.junit.After; +import org.junit.Before; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; + +import com.google.common.primitives.Floats; + +public class ScoreNormalizationCombinationIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME = "test-neural-multi-doc-one-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME = "test-neural-multi-doc-three-shards-index"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "place"; + private static final String TEST_QUERY_TEXT5 = "welcome"; + private static final String TEST_QUERY_TEXT6 = "notexistingword"; + private static final String TEST_QUERY_TEXT7 = "notexistingwordtwo"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_DOC_TEXT5 = "Say hello and enter my friend"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final int TEST_DIMENSION = 768; + private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; + private static final AtomicReference modelId = new AtomicReference<>(); + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private final float[] testVector4 = createRandomVector(TEST_DIMENSION); + private final static String RELATION_EQUAL_TO = "eq"; + private final static String RELATION_GREATER_OR_EQUAL_TO = "gte"; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + modelId.compareAndSet(null, prepareModel()); + } + + @After + @SneakyThrows + public 
void tearDown() { + super.tearDown(); + deleteSearchPipeline(SEARCH_PIPELINE); + } + + @Override + public boolean isUpdateClusterSettings() { + return false; + } + + @Override + protected boolean preserveClusterUponCompletion() { + return true; + } + + /** + * Using search pipelines with result processor configs like below: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * "normalization": { + * "technique": "min-max" + * }, + * "combination": { + * "technique": "sum", + * "parameters": { + * "weights": [ + * 0.4, 0.7 + * ] + * } + * } + * } + * } + * ] + * } + */ + @SneakyThrows + public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 5, false); + } + + /** + * Using search pipelines with default result processor configs: + * { + * "description": "Post processor for hybrid search", + * "phase_results_processors": [ + * { + * "normalization-processor": { + * } + * } + * ] + * } + */ + @SneakyThrows + public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME); + createSearchPipelineWithDefaultResultsPostProcessor(SEARCH_PIPELINE); + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 5, false); + } + + @SneakyThrows + public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + int totalExpectedDocQty = 6; + + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 6, null, null); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 6, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + 
assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(Range.between(.75f, 1.0f).contains(getMaxScore(searchResponseAsMap).get())); + List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap); + List<String> ids = new ArrayList<>(); + List<Double> scores = new ArrayList<>(); + for (Map<String, Object> oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + + // verify the scores are normalized. We need special assert logic because the combined score may vary: the neural search query + // is based on random vectors and may return results for every doc, which in some cases can push the term query's 1.0 score + // lower. + float highestScore = scores.stream().max(Double::compare).get().floatValue(); + assertTrue(Range.between(.75f, 1.0f).contains(highestScore)); + float lowestScore = scores.stream().min(Double::compare).get().floatValue(); + assertTrue(Range.between(.0f, .5f).contains(lowestScore)); + + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + } + + @SneakyThrows + public void testResultProcessor_whenMultipleShardsAndNoMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT6)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map<String, Object> searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 0, true); + } + + @SneakyThrows + public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map<String, Object> searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); + assertQueryResults(searchResponseAsMap, 4, true); + } + + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 1 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), +
Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + assertEquals(5, getDocCount(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)); + } + + if (TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)) { + prepareKnnIndex( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + 3 + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "4", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "5", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4) + ); + addKnnDoc( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + "6", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector4).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5) + ); + assertEquals(6, getDocCount(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME)); + } + } + + private List> getNestedHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (List>) hitsMap.get("hits"); + } + + private Map getTotalHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (Map) hitsMap.get("total"); + } + + private Optional getMaxScore(Map 
searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); + } + + private void assertQueryResults(Map searchResponseAsMap, int totalExpectedDocQty, boolean assertMinScore) { + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + if (totalExpectedDocQty > 0) { + assertEquals(1.0, getMaxScore(searchResponseAsMap).get(), 0.001f); + } else { + assertEquals(0.0, getMaxScore(searchResponseAsMap).get(), 0.001f); + } + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify the scores are normalized + if (totalExpectedDocQty > 0) { + assertEquals(1.0, (double) scores.stream().max(Double::compare).get(), 0.001); + if (assertMinScore) { + assertEquals(0.001, (double) scores.stream().min(Double::compare).get(), 0.001); + } + } + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java new file mode 100644 index 000000000..6188a7ef5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; + +import lombok.SneakyThrows; + +import org.apache.commons.lang3.Range; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.test.OpenSearchTestCase; + +public class ScoreNormalizationTechniqueTests extends OpenSearchTestCase { + + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + scoreNormalizationMethod.normalizeScores(List.of(), ScoreNormalizationFactory.DEFAULT_METHOD); + } + + @SneakyThrows + public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + List.of(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) })) + ) + ); + scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + assertNotNull(queryTopDocs); + assertEquals(1, queryTopDocs.size()); + CompoundTopDocs resultDoc = 
queryTopDocs.get(0); + assertNotNull(resultDoc.getCompoundTopDocs()); + assertEquals(1, resultDoc.getCompoundTopDocs().size()); + TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); + assertEquals(1, topDoc.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation); + assertNotNull(topDoc.scoreDocs); + assertEquals(1, topDoc.scoreDocs.length); + ScoreDoc scoreDoc = topDoc.scoreDocs[0]; + assertEquals(1.0, scoreDoc.score, 0.001f); + assertEquals(1, scoreDoc.doc); + } + + @SneakyThrows + public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } + ) + ) + ) + ); + scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + assertNotNull(queryTopDocs); + assertEquals(1, queryTopDocs.size()); + CompoundTopDocs resultDoc = queryTopDocs.get(0); + assertNotNull(resultDoc.getCompoundTopDocs()); + assertEquals(1, resultDoc.getCompoundTopDocs().size()); + TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); + assertEquals(3, topDoc.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation); + assertNotNull(topDoc.scoreDocs); + assertEquals(3, topDoc.scoreDocs.length); + ScoreDoc highScoreDoc = topDoc.scoreDocs[0]; + assertEquals(1.0, highScoreDoc.score, 0.001f); + assertEquals(1, highScoreDoc.doc); + ScoreDoc lowScoreDoc = topDoc.scoreDocs[topDoc.scoreDocs.length - 1]; + assertEquals(0.0, lowScoreDoc.score, 0.001f); + assertEquals(4, lowScoreDoc.doc); + } + + public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } + ), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } + ) + ) + ) + ); + scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + + assertNotNull(queryTopDocs); + assertEquals(1, queryTopDocs.size()); + CompoundTopDocs resultDoc = queryTopDocs.get(0); + assertNotNull(resultDoc.getCompoundTopDocs()); + assertEquals(2, resultDoc.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocSubqueryOne = resultDoc.getCompoundTopDocs().get(0); + assertEquals(3, topDocSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation); + assertNotNull(topDocSubqueryOne.scoreDocs); + assertEquals(3, topDocSubqueryOne.scoreDocs.length); + ScoreDoc highScoreDoc = topDocSubqueryOne.scoreDocs[0]; + assertEquals(1.0, highScoreDoc.score, 0.001f); + assertEquals(1, highScoreDoc.doc); + ScoreDoc lowScoreDoc = topDocSubqueryOne.scoreDocs[topDocSubqueryOne.scoreDocs.length - 1]; + assertEquals(0.0, lowScoreDoc.score, 0.001f); + assertEquals(4, lowScoreDoc.doc); + // sub-query two + TopDocs topDocSubqueryTwo = resultDoc.getCompoundTopDocs().get(1); + 
assertEquals(2, topDocSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation); + assertNotNull(topDocSubqueryTwo.scoreDocs); + assertEquals(2, topDocSubqueryTwo.scoreDocs.length); + assertEquals(1.0, topDocSubqueryTwo.scoreDocs[0].score, 0.001f); + assertEquals(3, topDocSubqueryTwo.scoreDocs[0].doc); + assertEquals(0.0, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].score, 0.001f); + assertEquals(5, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].doc); + } + + public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAndDefaultMethod_thenScoreNormalized() { + ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); + final List queryTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } + ), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 2.2f), new ScoreDoc(4, 1.8f), new ScoreDoc(7, 0.9f), new ScoreDoc(9, 0.01f) } + ) + ) + ), + new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) + ) + ) + ); + scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + assertNotNull(queryTopDocs); + assertEquals(3, queryTopDocs.size()); + // shard one + CompoundTopDocs resultDocShardOne = queryTopDocs.get(0); + assertEquals(2, resultDocShardOne.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocSubqueryOne = resultDocShardOne.getCompoundTopDocs().get(0); + assertEquals(3, topDocSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation); + assertNotNull(topDocSubqueryOne.scoreDocs); + assertEquals(3, topDocSubqueryOne.scoreDocs.length); + ScoreDoc highScoreDoc = topDocSubqueryOne.scoreDocs[0]; + assertEquals(1.0, highScoreDoc.score, 0.001f); + assertEquals(1, highScoreDoc.doc); + ScoreDoc lowScoreDoc = topDocSubqueryOne.scoreDocs[topDocSubqueryOne.scoreDocs.length - 1]; + assertEquals(0.0, lowScoreDoc.score, 0.001f); + assertEquals(4, lowScoreDoc.doc); + // sub-query two + TopDocs topDocSubqueryTwo = resultDocShardOne.getCompoundTopDocs().get(1); + assertEquals(2, topDocSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation); + assertNotNull(topDocSubqueryTwo.scoreDocs); + assertEquals(2, topDocSubqueryTwo.scoreDocs.length); + assertTrue(Range.between(0.0f, 1.0f).contains(topDocSubqueryTwo.scoreDocs[0].score)); + assertEquals(3, topDocSubqueryTwo.scoreDocs[0].doc); + assertTrue(Range.between(0.0f, 1.0f).contains(topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].score)); + assertEquals(5, topDocSubqueryTwo.scoreDocs[topDocSubqueryTwo.scoreDocs.length - 1].doc); + + // shard two + CompoundTopDocs resultDocShardTwo = queryTopDocs.get(1); + assertEquals(2, 
resultDocShardTwo.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocShardTwoSubqueryOne = resultDocShardTwo.getCompoundTopDocs().get(0); + assertEquals(0, topDocShardTwoSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryOne.totalHits.relation); + assertNotNull(topDocShardTwoSubqueryOne.scoreDocs); + assertEquals(0, topDocShardTwoSubqueryOne.scoreDocs.length); + // sub-query two + TopDocs topDocShardTwoSubqueryTwo = resultDocShardTwo.getCompoundTopDocs().get(1); + assertEquals(4, topDocShardTwoSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryTwo.totalHits.relation); + assertNotNull(topDocShardTwoSubqueryTwo.scoreDocs); + assertEquals(4, topDocShardTwoSubqueryTwo.scoreDocs.length); + assertEquals(1.0, topDocShardTwoSubqueryTwo.scoreDocs[0].score, 0.001f); + assertEquals(2, topDocShardTwoSubqueryTwo.scoreDocs[0].doc); + assertEquals(0.0, topDocShardTwoSubqueryTwo.scoreDocs[topDocShardTwoSubqueryTwo.scoreDocs.length - 1].score, 0.001f); + assertEquals(9, topDocShardTwoSubqueryTwo.scoreDocs[topDocShardTwoSubqueryTwo.scoreDocs.length - 1].doc); + + // shard three + CompoundTopDocs resultDocShardThree = queryTopDocs.get(2); + assertEquals(2, resultDocShardThree.getCompoundTopDocs().size()); + // sub-query one + TopDocs topDocShardThreeSubqueryOne = resultDocShardThree.getCompoundTopDocs().get(0); + assertEquals(0, topDocShardThreeSubqueryOne.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryOne.totalHits.relation); + assertEquals(0, topDocShardThreeSubqueryOne.scoreDocs.length); + // sub-query two + TopDocs topDocShardThreeSubqueryTwo = resultDocShardThree.getCompoundTopDocs().get(1); + assertEquals(0, topDocShardThreeSubqueryTwo.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryTwo.totalHits.relation); + assertEquals(0, topDocShardThreeSubqueryTwo.scoreDocs.length); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java new file mode 100644 index 000000000..babeed214 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -0,0 +1,176 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; + +import java.util.HashMap; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.opensearch.OpenSearchParseException; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.test.OpenSearchTestCase; + +public class NormalizationProcessorFactoryTests extends OpenSearchTestCase { + + private static final String 
NORMALIZATION_METHOD = "min_max"; + private static final String COMBINATION_METHOD = "arithmetic_mean"; + + @SneakyThrows + public void testNormalizationProcessor_whenNoParams_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenWithParams_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put("normalization", Map.of("technique", "min_max")); + config.put("combination", Map.of("technique", "arithmetic_mean")); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } + + public void testInputValidation_whenInvalidParameters_thenFail() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + + expectThrows( + OpenSearchParseException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + Map.of(NormalizationProcessorFactory.TECHNIQUE, ""), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + Map.of(NormalizationProcessorFactory.TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME) + ) + ), + pipelineContext + ) + ); + + expectThrows( + OpenSearchParseException.class, + () -> 
normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + Map.of(NormalizationProcessorFactory.TECHNIQUE, NORMALIZATION_METHOD), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + Map.of(NormalizationProcessorFactory.TECHNIQUE, "") + ) + ), + pipelineContext + ) + ); + + expectThrows( + OpenSearchParseException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + Map.of(NormalizationProcessorFactory.TECHNIQUE, "random_name_for_normalization"), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + Map.of(NormalizationProcessorFactory.TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME) + ) + ), + pipelineContext + ) + ); + + expectThrows( + OpenSearchParseException.class, + () -> normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + new HashMap<>( + Map.of( + NormalizationProcessorFactory.NORMALIZATION_CLAUSE, + Map.of(NormalizationProcessorFactory.TECHNIQUE, NORMALIZATION_METHOD), + NormalizationProcessorFactory.COMBINATION_CLAUSE, + Map.of(NormalizationProcessorFactory.TECHNIQUE, "random_name_for_combination") + ) + ), + pipelineContext + ) + ); + } +}
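The tests above pin down the behavior of the default techniques: min-max normalization maps each sub-query's scores into [0, 1] (a single hit normalizes to 1.0), and arithmetic-mean combination averages a document's normalized scores across the sub-queries in which it appears. The sketch below is not the plugin's implementation (ScoreNormalizer and ScoreCombiner operate on CompoundTopDocs and also handle result trimming, which is ignored here); it is a minimal, standalone illustration of that arithmetic under the stated assumptions.

// Minimal sketch, assuming: scores that are all equal normalize to 1.0, and a document
// absent from a sub-query is skipped when averaging. Not the plugin's API.
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class ScoreFusionSketch {

    /** Min-max normalize one sub-query's doc -> score map into [0, 1]. */
    static Map<Integer, Float> minMaxNormalize(Map<Integer, Float> scores) {
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (float s : scores.values()) {
            min = Math.min(min, s);
            max = Math.max(max, s);
        }
        Map<Integer, Float> normalized = new LinkedHashMap<>();
        for (Map.Entry<Integer, Float> e : scores.entrySet()) {
            // Degenerate range (single hit or identical scores) maps to 1.0 by assumption.
            float value = (max == min) ? 1.0f : (e.getValue() - min) / (max - min);
            normalized.put(e.getKey(), value);
        }
        return normalized;
    }

    /** Arithmetic mean per document, averaged over the sub-queries in which it appears. */
    static Map<Integer, Float> arithmeticMeanCombine(List<Map<Integer, Float>> subQueryScores) {
        Map<Integer, Float> sums = new HashMap<>();
        Map<Integer, Integer> counts = new HashMap<>();
        for (Map<Integer, Float> subQuery : subQueryScores) {
            for (Map.Entry<Integer, Float> e : subQuery.entrySet()) {
                sums.merge(e.getKey(), e.getValue(), Float::sum);
                counts.merge(e.getKey(), 1, Integer::sum);
            }
        }
        Map<Integer, Float> combined = new HashMap<>();
        sums.forEach((doc, sum) -> combined.put(doc, sum / counts.get(doc)));
        return combined;
    }

    public static void main(String[] args) {
        // Two hypothetical sub-query result sets with overlapping doc ids: normalize each,
        // then combine; the highest combined score ends up at 1.0, as the tests expect.
        Map<Integer, Float> subQueryOne = Map.of(1, 10.0f, 2, 2.5f, 4, 0.1f);
        Map<Integer, Float> subQueryTwo = Map.of(2, 0.8f, 5, 0.5f);
        System.out.println(arithmeticMeanCombine(List.of(minMaxNormalize(subQueryOne), minMaxNormalize(subQueryTwo))));
    }
}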