Skip to content

Commit

Permalink
Adding L2 norm technique (#236)
Browse files Browse the repository at this point in the history
* Adding L2 norm technique

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski authored Jul 31, 2023
1 parent fe72dbc commit 6ad641a
Show file tree
Hide file tree
Showing 8 changed files with 559 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Abstracts combination of scores based on arithmetic mean method
*/
public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;

public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
weights = getWeights(params);
}

private List<Float> getWeights(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
}

/**
* Arithmetic mean method for combining scores.
* score = (weight1*score1 + weight2*score2 +...+ weightN*scoreN)/(weight1 + weight2 + ... + weightN)
*
* Zero (0.0) scores are excluded from number of scores N
*/
@Override
public float combine(final float[] scores) {
float combinedScore = 0.0f;
float weights = 0;
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score >= 0.0) {
float weight = getWeightForSubQuery(indexOfSubQuery);
score = score * weight;
combinedScore += score;
weights += weight;
}
}
if (weights == 0.0f) {
return ZERO_SCORE;
}
return combinedScore / weights;
}

private void validateParams(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return;
}
// check if only supported params are passed
Optional<String> optionalNotSupportedParam = params.keySet()
.stream()
.filter(paramName -> !SUPPORTED_PARAMS.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"provided parameter for combination technique is not supported. supported parameters are [%s]",
SUPPORTED_PARAMS.stream().collect(Collectors.joining(","))
)
);
}

// check param types
if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
);
}
}
}

/**
* Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
* @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
* @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
*/
private float getWeightForSubQuery(int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.normalization;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.search.CompoundTopDocs;

/**
* Abstracts normalization of scores based on L2 method
*/
public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique {

public static final String TECHNIQUE_NAME = "l2";
private static final float MIN_SCORE = 0.001f;

/**
* L2 normalization method.
* n_score_i = score_i/sqrt(score1^2 + score2^2 + ... + scoren^2)
* Main algorithm steps:
* - calculate sum of squares of all scores
* - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query
*/
@Override
public void normalize(final List<CompoundTopDocs> queryTopDocs) {
// get l2 norms for each sub-query
List<Float> normsPerSubquery = getL2Norm(queryTopDocs);

// do normalization using actual score and l2 norm
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
scoreDoc.score = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
}
}
}
}

private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
// find any non-empty compound top docs, it's either empty if shard does not have any results for all of sub-queries,
// or it has results for all the sub-queries. In edge case of shard having results only for one sub-query, there will be TopDocs for
// rest of sub-queries with zero total hits
int numOfSubqueries = queryTopDocs.stream()
.filter(Objects::nonNull)
.filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0)
.findAny()
.get()
.getCompoundTopDocs()
.size();
float[] l2Norms = new float[numOfSubqueries];
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
int bound = topDocsPerSubQuery.size();
for (int index = 0; index < bound; index++) {
for (ScoreDoc scoreDocs : topDocsPerSubQuery.get(index).scoreDocs) {
l2Norms[index] += scoreDocs.score * scoreDocs.score;
}
}
}
for (int index = 0; index < l2Norms.length; index++) {
l2Norms[index] = (float) Math.sqrt(l2Norms[index]);
}
List<Float> l2NormList = new ArrayList<>();
for (int index = 0; index < numOfSubqueries; index++) {
l2NormList.add(l2Norms[index]);
}
return l2NormList;
}

private float normalizeSingleScore(final float score, final float l2Norm) {
return l2Norm == 0 ? MIN_SCORE : score / l2Norm;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ public class ScoreNormalizationFactory {

private final Map<String, ScoreNormalizationTechnique> scoreNormalizationMethodsMap = Map.of(
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
new MinMaxScoreNormalizationTechnique()
new MinMaxScoreNormalizationTechnique(),
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
new L2ScoreNormalizationTechnique()
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ protected String uploadModel(String requestBody) throws Exception {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> uploadResJson = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
XContentType.JSON.xContent(),
EntityUtils.toString(uploadResponse.getEntity()),
false
);
Expand Down Expand Up @@ -136,7 +136,7 @@ protected void loadModel(String modelId) throws Exception {
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> uploadResJson = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
XContentType.JSON.xContent(),
EntityUtils.toString(uploadResponse.getEntity()),
false
);
Expand Down Expand Up @@ -185,7 +185,7 @@ protected float[] runInference(String modelId, String queryText) {
);

Map<String, Object> inferenceResJson = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
XContentType.JSON.xContent(),
EntityUtils.toString(inferenceResponse.getEntity()),
false
);
Expand Down Expand Up @@ -215,7 +215,7 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
XContentType.JSON.xContent(),
EntityUtils.toString(response.getEntity()),
false
);
Expand All @@ -239,7 +239,7 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
Map<String, Object> node = XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
XContentType.JSON.xContent(),
EntityUtils.toString(pipelineCreateResponse.getEntity()),
false
);
Expand Down Expand Up @@ -329,7 +329,7 @@ protected Map<String, Object> search(

String responseBody = EntityUtils.toString(response.getEntity());

return XContentHelper.convertToMap(XContentFactory.xContent(XContentType.JSON), responseBody, false);
return XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
}

/**
Expand Down Expand Up @@ -445,11 +445,7 @@ protected Map<String, Object> getTaskQueryResponse(String taskId) throws Excepti
toHttpEntity(""),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
return XContentHelper.convertToMap(
XContentFactory.xContent(XContentType.JSON),
EntityUtils.toString(taskQueryResponse.getEntity()),
false
);
return XContentHelper.convertToMap(XContentType.JSON.xContent(), EntityUtils.toString(taskQueryResponse.getEntity()), false);
}

protected boolean checkComplete(Map<String, Object> node) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
import org.opensearch.neuralsearch.processor.normalization.L2ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;

Expand Down Expand Up @@ -93,12 +94,7 @@ protected boolean preserveClusterUponCompletion() {
* "technique": "min-max"
* },
* "combination": {
* "technique": "sum",
* "parameters": {
* "weights": [
* 0.4, 0.7
* ]
* }
* "technique": "arithmetic_mean"
* }
* }
* }
Expand Down Expand Up @@ -251,6 +247,29 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf
assertQueryResults(searchResponseAsMap, 4, true);
}

/**
* Using search pipelines with result processor configs like below:
* {
* "description": "Post processor for hybrid search",
* "phase_results_processors": [
* {
* "normalization-processor": {
* "normalization": {
* "technique": "min-max"
* },
* "combination": {
* "technique": "arithmetic_mean",
* "parameters": {
* "weights": [
* 0.4, 0.7
* ]
* }
* }
* }
* }
* ]
* }
*/
@SneakyThrows
public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME);
Expand Down Expand Up @@ -337,6 +356,74 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001);
}

/**
* Using search pipelines with config for l2 norm:
* {
* "description": "Post processor for hybrid search",
* "phase_results_processors": [
* {
* "normalization-processor": {
* "normalization": {
* "technique": "l2"
* },
* "combination": {
* "technique": "arithmetic_mean"
* }
* }
* }
* ]
* }
*/
@SneakyThrows
public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME);
createSearchPipeline(
SEARCH_PIPELINE,
L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f }))
);

NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId.get(), 5, null, null);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(neuralQueryBuilder);
hybridQueryBuilder.add(termQueryBuilder);

Map<String, Object> searchResponseAsMap = search(
TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);
int totalExpectedDocQty = 5;
assertNotNull(searchResponseAsMap);
Map<String, Object> total = getTotalHits(searchResponseAsMap);
assertNotNull(total.get("value"));
assertEquals(totalExpectedDocQty, total.get("value"));
assertNotNull(total.get("relation"));
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
assertTrue(getMaxScore(searchResponseAsMap).isPresent());
assertTrue(Range.between(.6f, 1.0f).contains(getMaxScore(searchResponseAsMap).get()));

List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
List<String> ids = new ArrayList<>();
List<Double> scores = new ArrayList<>();
for (Map<String, Object> oneHit : hitsNestedList) {
ids.add((String) oneHit.get("_id"));
scores.add((Double) oneHit.get("_score"));
}
// verify scores order
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
// verify the scores are normalized. for l2 scores max score will not be 1.0 so we're checking on a range
assertTrue(Range.between(.6f, 1.0f).contains((float) scores.stream().map(Double::floatValue).max(Double::compare).get()));

// verify that all ids are unique
assertEquals(Set.copyOf(ids).size(), ids.size());
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) {
prepareKnnIndex(
Expand Down
Loading

0 comments on commit 6ad641a

Please sign in to comment.