Skip to content

Commit

Permalink
Added Score Normalization and Combination feature (#241)
Browse files Browse the repository at this point in the history
* Added Score Normalization and Combination feature

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski authored Aug 3, 2023
1 parent 2ff58a9 commit 61e6e98
Show file tree
Hide file tree
Showing 57 changed files with 8,464 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.7...2.x)
### Features
* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
6 changes: 6 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,12 @@ testClusters.integTest {
// Increase heap size from default of 512mb to 1gb. When heap size is 512mb, our integ tests sporadically fail due
// to ml-commons memory circuit breaker exception
jvmArgs("-Xms1g", "-Xmx1g")

// enable features for testing
// hybrid search
systemProperty('neural_search_hybrid_search_enabled', 'true')
// search pipelines
systemProperty('opensearch.experimental.feature.search_pipeline.enabled', 'true')
}

// Remote Integration Tests
Expand Down
61 changes: 57 additions & 4 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,70 @@

package org.opensearch.neuralsearch.plugin;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import lombok.extern.log4j.Log4j2;

import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import com.google.common.annotations.VisibleForTesting;

/**
* Neural Search plugin class
*/
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin {

@Log4j2
public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, IngestPlugin, ExtensiblePlugin, SearchPipelinePlugin {
/**
* Gates the functionality of hybrid search
* Currently query phase searcher added with hybrid search will conflict with concurrent search in core.
* Once that problem is resolved this feature flag can be removed.
*/
@VisibleForTesting
public static final String NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = "neural_search_hybrid_search_enabled";
private MLCommonsClientAccessor clientAccessor;
private NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory();
private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();;

@Override
public Collection<Object> createComponents(
Expand All @@ -56,12 +85,15 @@ public Collection<Object> createComponents(
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
NeuralQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
}

@Override
public List<QuerySpec<?>> getQueries() {
return Collections.singletonList(
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent)
return Arrays.asList(
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent),
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent)
);
}

Expand All @@ -70,4 +102,25 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
}

@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED)) {
log.info("Registering hybrid query phase searcher with feature flag [%]", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED);
return Optional.of(new HybridQueryPhaseSearcher());
}
log.info("Not registering hybrid query phase searcher because feature flag [%] is disabled", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED);
// we want feature be disabled by default due to risk of colliding and breaking concurrent search in core
return Optional.empty();
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseResultsProcessor>> getSearchPhaseResultsProcessors(
Parameters parameters
) {
return Map.of(
NormalizationProcessor.TYPE,
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseName;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

/**
* Processor for score normalization and combination on post query search results. Updates query results with
* normalized and combined scores for next phase (typically it's FETCH)
*/
@Log4j2
@AllArgsConstructor
public class NormalizationProcessor implements SearchPhaseResultsProcessor {
public static final String TYPE = "normalization-processor";

private final String tag;
private final String description;
private final ScoreNormalizationTechnique normalizationTechnique;
private final ScoreCombinationTechnique combinationTechnique;
private final NormalizationProcessorWorkflow normalizationWorkflow;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
if (shouldRunProcessor(searchPhaseResult)) {
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique);
}

@Override
public SearchPhaseName getBeforePhase() {
return SearchPhaseName.QUERY;
}

@Override
public SearchPhaseName getAfterPhase() {
return SearchPhaseName.FETCH;
}

@Override
public String getType() {
return TYPE;
}

@Override
public String getTag() {
return tag;
}

@Override
public String getDescription() {
return description;
}

@Override
public boolean isIgnoreFailure() {
return false;
}

private <Result extends SearchPhaseResult> boolean shouldRunProcessor(SearchPhaseResults<Result> searchPhaseResult) {
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer)) {
return true;
}

QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
Optional<SearchPhaseResult> optionalSearchPhaseResult = queryPhaseResultConsumer.getAtomicArray()
.asList()
.stream()
.filter(Objects::nonNull)
.findFirst();
return isNotHybridQuery(optionalSearchPhaseResult);
}

private boolean isNotHybridQuery(final Optional<SearchPhaseResult> optionalSearchPhaseResult) {
return optionalSearchPhaseResult.isEmpty()
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult())
|| Objects.isNull(optionalSearchPhaseResult.get().queryResult().topDocs())
|| !(optionalSearchPhaseResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
final SearchPhaseResults<Result> results
) {
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

import lombok.AllArgsConstructor;

import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.search.CompoundTopDocs;
import org.opensearch.search.query.QuerySearchResult;

/**
* Class abstracts steps required for score normalization and combination, this includes pre-processing of incoming data
* and post-processing of final results
*/
@AllArgsConstructor
public class NormalizationProcessorWorkflow {

private final ScoreNormalizer scoreNormalizer;
private final ScoreCombiner scoreCombiner;

/**
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
// pre-process data
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

// normalize
scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);

// combine
scoreCombiner.combineScores(queryTopDocs, combinationTechnique);

// post-process data
updateOriginalQueryResults(querySearchResults, queryTopDocs);
}

/**
* Getting list of CompoundTopDocs from list of QuerySearchResult. Each CompoundTopDocs is for individual shard
* @param querySearchResults collection of QuerySearchResult for all shards
* @return collection of CompoundTopDocs, one object for each shard
*/
private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> querySearchResults) {
List<CompoundTopDocs> queryTopDocs = querySearchResults.stream()
.filter(searchResult -> Objects.nonNull(searchResult.topDocs()))
.filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs)
.map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs)
.collect(Collectors.toList());
return queryTopDocs;
}

private void updateOriginalQueryResults(final List<QuerySearchResult> querySearchResults, final List<CompoundTopDocs> queryTopDocs) {
for (int i = 0; i < querySearchResults.size(); i++) {
QuerySearchResult querySearchResult = querySearchResults.get(i);
if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) {
continue;
}
CompoundTopDocs updatedTopDocs = queryTopDocs.get(i);
float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f;
TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore);
querySearchResult.topDocs(updatedTopDocsAndMaxScore, null);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Map;
import java.util.Set;

/**
* Abstracts combination of scores based on arithmetic mean method
*/
public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;
private final ScoreCombinationUtil scoreCombinationUtil;

public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {
scoreCombinationUtil = combinationUtil;
scoreCombinationUtil.validateParams(params, SUPPORTED_PARAMS);
weights = scoreCombinationUtil.getWeights(params);
}

/**
* Arithmetic mean method for combining scores.
* score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN)
*
* Zero (0.0) scores are excluded from number of scores N
*/
@Override
public float combine(final float[] scores) {
float combinedScore = 0.0f;
float sumOfWeights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score >= 0.0) {
float weight = scoreCombinationUtil.getWeightForSubQuery(weights, indexOfSubQuery);
score = score * weight;
combinedScore += score;
sumOfWeights += weight;
}
}
if (sumOfWeights == 0.0f) {
return ZERO_SCORE;
}
return combinedScore / sumOfWeights;
}
}
Loading

0 comments on commit 61e6e98

Please sign in to comment.