Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski committed Jul 28, 2023
1 parent 65a0d5f commit c80ba2f
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.combination;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Abstracts combination of scores based on harmonic mean method
 */
public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

    // NOTE(review): original copy-pasted "arithmetic_mean" here and an arithmetic-mean
    // combine() body; both are corrected to match the class name.
    public static final String TECHNIQUE_NAME = "harmonic_mean";
    public static final String PARAM_NAME_WEIGHTS = "weights";
    private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
    private static final Float ZERO_SCORE = 0.0f;
    private final List<Float> weights;

    /**
     * Creates the technique from pipeline/processor parameters.
     * @param params optional map of parameters; only "weights" (a list of numbers) is supported
     * @throws IllegalArgumentException if an unsupported parameter is passed or "weights" is not a List
     */
    public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params) {
        validateParams(params);
        weights = getWeights(params);
    }

    private List<Float> getWeights(final Map<String, Object> params) {
        if (Objects.isNull(params) || params.isEmpty()) {
            return List.of();
        }
        // get weights, we don't need to check for instance as it's done during validation
        return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
            .map(Double::floatValue)
            .collect(Collectors.toUnmodifiableList());
    }

    /**
     * Weighted harmonic mean method for combining scores.
     * score = (weight1 + weight2 + ... + weightN)/(weight1/score1 + weight2/score2 + ... + weightN/scoreN)
     *
     * Non-positive (&lt;= 0.0) scores are excluded: they carry no rank signal and a zero
     * score would make the harmonic term weight/score undefined.
     *
     * @param scores raw scores for each sub-query, indexed by sub-query position
     * @return combined harmonic-mean score, or 0.0 if no sub-query produced a positive score
     */
    @Override
    public float combine(final float[] scores) {
        float sumOfWeights = 0.0f;
        float sumOfHarmonics = 0.0f;
        for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
            float score = scores[indexOfSubQuery];
            if (score <= 0.0f) {
                continue;
            }
            float weight = getWeightForSubQuery(indexOfSubQuery);
            sumOfHarmonics += weight / score;
            sumOfWeights += weight;
        }
        if (sumOfHarmonics <= 0.0f) {
            return ZERO_SCORE;
        }
        return sumOfWeights / sumOfHarmonics;
    }

    private void validateParams(final Map<String, Object> params) {
        if (Objects.isNull(params) || params.isEmpty()) {
            return;
        }
        // check if only supported params are passed
        Optional<String> optionalNotSupportedParam = params.keySet()
            .stream()
            .filter(paramName -> !SUPPORTED_PARAMS.contains(paramName))
            .findFirst();
        if (optionalNotSupportedParam.isPresent()) {
            throw new IllegalArgumentException(
                String.format(
                    Locale.ROOT,
                    "provided parameter for combination technique is not supported. supported parameters are [%s]",
                    SUPPORTED_PARAMS.stream().collect(Collectors.joining(","))
                )
            );
        }

        // check param types
        if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
            if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
                );
            }
        }
    }

    /**
     * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
     * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
     * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
     */
    private float getWeightForSubQuery(int indexOfSubQuery) {
        return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

package org.opensearch.neuralsearch.processor.normalization;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
Expand All @@ -30,15 +30,8 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechniqu
*/
@Override
public void normalize(final List<CompoundTopDocs> queryTopDocs) {
int numOfSubqueries = queryTopDocs.stream()
.filter(Objects::nonNull)
.filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0)
.findAny()
.get()
.getCompoundTopDocs()
.size();
// get l2 norms for each sub-query
float[] normsPerSubquery = getL2Norm(queryTopDocs, numOfSubqueries);
List<Float> normsPerSubquery = getL2Norm(queryTopDocs);

// do normalization using actual score and l2 norm
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
Expand All @@ -49,29 +42,44 @@ public void normalize(final List<CompoundTopDocs> queryTopDocs) {
for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
scoreDoc.score = normalizeSingleScore(scoreDoc.score, normsPerSubquery[j]);
scoreDoc.score = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
}
}
}
}

private float[] getL2Norm(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
// find any non-empty compound top docs, it's either empty if shard does not have any results for all of sub-queries,
// or it has results for all the sub-queries. In edge case of shard having results only for one sub-query, there will be TopDocs for
// rest of sub-queries with zero total hits
int numOfSubqueries = queryTopDocs.stream()
.filter(Objects::nonNull)
.filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0)
.findAny()
.get()
.getCompoundTopDocs()
.size();
float[] l2Norms = new float[numOfSubqueries];
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
IntStream.range(0, topDocsPerSubQuery.size()).forEach(index -> {
int bound = topDocsPerSubQuery.size();
for (int index = 0; index < bound; index++) {
for (ScoreDoc scoreDocs : topDocsPerSubQuery.get(index).scoreDocs) {
l2Norms[index] += scoreDocs.score * scoreDocs.score;
}
});
}
}
for (int index = 0; index < l2Norms.length; index++) {
l2Norms[index] = (float) Math.sqrt(l2Norms[index]);
}
return l2Norms;
List<Float> l2NormList = new ArrayList<>();
for (int index = 0; index < numOfSubqueries; index++) {
l2NormList.add(l2Norms[index]);
}
return l2NormList;
}

private float normalizeSingleScore(final float score, final float l2Norm) {
Expand Down

0 comments on commit c80ba2f

Please sign in to comment.