From c80ba2f8b5ea6965fb01e5d426bbab167d0915cd Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Fri, 28 Jul 2023 16:36:44 -0700
Subject: [PATCH] Address review comments

Signed-off-by: Martin Gaievski
---
 ...HarmonicMeanScoreCombinationTechnique.java | 104 ++++++++++++++++++
 .../L2ScoreNormalizationTechnique.java        |  36 +++---
 2 files changed, 126 insertions(+), 14 deletions(-)
 create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java
new file mode 100644
index 000000000..3fff2db2b
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.combination;
+
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Abstracts combination of scores based on harmonic mean method
+ */
+public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique {
+
+    public static final String TECHNIQUE_NAME = "harmonic_mean";
+    public static final String PARAM_NAME_WEIGHTS = "weights";
+    private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
+    private static final Float ZERO_SCORE = 0.0f;
+    private final List<Float> weights;
+
+    public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params) {
+        validateParams(params);
+        weights = getWeights(params);
+    }
+
+    private List<Float> getWeights(final Map<String, Object> params) {
+        if (Objects.isNull(params) || params.isEmpty()) {
+            return List.of();
+        }
+        // get weights, we don't need to check for instance as it's done during validation
+        return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
+            .map(Double::floatValue)
+            .collect(Collectors.toUnmodifiableList());
+    }
+
+    /**
+     * Weighted harmonic mean method for combining scores.
+     * score = (weight1 + weight2 + ... + weightN)/(weight1/score1 + weight2/score2 + ... + weightN/scoreN)
+     *
+     * Zero (0.0) scores are excluded from number of scores N
+     */
+    @Override
+    public float combine(final float[] scores) {
+        float sumOfWeights = 0.0f;
+        float sumOfHarmonics = 0.0f;
+        for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
+            float score = scores[indexOfSubQuery];
+            if (score > 0.0f) {
+                float weight = getWeightForSubQuery(indexOfSubQuery);
+                // accumulate the weighted reciprocal, zero scores are skipped to avoid division by zero
+                sumOfWeights += weight;
+                sumOfHarmonics += weight / score;
+            }
+        }
+        if (sumOfHarmonics == 0.0f) {
+            return ZERO_SCORE;
+        }
+        return sumOfWeights / sumOfHarmonics;
+    }
+
+    private void validateParams(final Map<String, Object> params) {
+        if (Objects.isNull(params) || params.isEmpty()) {
+            return;
+        }
+        // check if only supported params are passed
+        Optional<String> optionalNotSupportedParam = params.keySet()
+            .stream()
+            .filter(paramName -> !SUPPORTED_PARAMS.contains(paramName))
+            .findFirst();
+        if (optionalNotSupportedParam.isPresent()) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "provided parameter for combination technique is not supported. supported parameters are [%s]",
+                    SUPPORTED_PARAMS.stream().collect(Collectors.joining(","))
+                )
+            );
+        }
+
+        // check param types
+        if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
+            if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
+                throw new IllegalArgumentException(
+                    String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
+                );
+            }
+        }
+    }
+
+    /**
+     * Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
+     * @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
+     * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
+     */
+    private float getWeightForSubQuery(int indexOfSubQuery) {
+        return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java
index 0e78841e5..0007a3ef3 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java
@@ -5,9 +5,9 @@
 
 package org.opensearch.neuralsearch.processor.normalization;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
-import java.util.stream.IntStream;
 
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
@@ -30,15 +30,8 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechniqu
      */
     @Override
     public void normalize(final List<CompoundTopDocs> queryTopDocs) {
-        int numOfSubqueries = queryTopDocs.stream()
-            .filter(Objects::nonNull)
-            .filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0)
-            .findAny()
-            .get()
-            .getCompoundTopDocs()
-            .size();
         // get l2 norms for each sub-query
-        float[] normsPerSubquery = getL2Norm(queryTopDocs, numOfSubqueries);
+        List<Float> normsPerSubquery = getL2Norm(queryTopDocs);
 
         // do normalization using actual score and l2 norm
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
@@ -49,29 +42,44 @@ public void normalize(final List<CompoundTopDocs> queryTopDocs) {
             for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                 TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
                 for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
-                    scoreDoc.score = normalizeSingleScore(scoreDoc.score, normsPerSubquery[j]);
+                    scoreDoc.score = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j));
                 }
             }
         }
     }
 
-    private float[] getL2Norm(final List<CompoundTopDocs> queryTopDocs, final int numOfSubqueries) {
+    private List<Float> getL2Norm(final List<CompoundTopDocs> queryTopDocs) {
+        // find any non-empty compound top docs: it's either empty, if the shard has no results for any of the sub-queries,
+        // or it has results for all the sub-queries. In the edge case of a shard having results for only one sub-query,
+        // there will be TopDocs with zero total hits for the rest of the sub-queries
+        int numOfSubqueries = queryTopDocs.stream()
+            .filter(Objects::nonNull)
+            .filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0)
+            .findAny()
+            .get()
+            .getCompoundTopDocs()
+            .size();
         float[] l2Norms = new float[numOfSubqueries];
         for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
             if (Objects.isNull(compoundQueryTopDocs)) {
                 continue;
             }
             List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs();
-            IntStream.range(0, topDocsPerSubQuery.size()).forEach(index -> {
+            int bound = topDocsPerSubQuery.size();
+            for (int index = 0; index < bound; index++) {
                 for (ScoreDoc scoreDocs : topDocsPerSubQuery.get(index).scoreDocs) {
                     l2Norms[index] += scoreDocs.score * scoreDocs.score;
                 }
-            });
+            }
         }
         for (int index = 0; index < l2Norms.length; index++) {
             l2Norms[index] = (float) Math.sqrt(l2Norms[index]);
         }
-        return l2Norms;
+        List<Float> l2NormList = new ArrayList<>();
+        for (int index = 0; index < numOfSubqueries; index++) {
+            l2NormList.add(l2Norms[index]);
+        }
+        return l2NormList;
     }
 
     private float normalizeSingleScore(final float score, final float l2Norm) {
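
For reference, the combined score behaves like a weighted harmonic mean: it is dominated by the smallest sub-query score, so a document must rank reasonably well in every sub-query to combine well. Below is a minimal standalone sketch of the same math, detached from the processor classes; the class name, sample scores, and weights are illustrative only, not part of the patch.

import java.util.List;

// Sketch of the weighted harmonic mean combination:
// score = (w1 + w2 + ... + wN) / (w1/s1 + w2/s2 + ... + wN/sN)
public class HarmonicMeanCombinationSketch {

    static float combine(float[] scores, List<Float> weights) {
        float sumOfWeights = 0.0f;
        float sumOfHarmonics = 0.0f;
        for (int i = 0; i < scores.length; i++) {
            // non-positive scores are excluded so the reciprocal never divides by zero
            if (scores[i] <= 0.0f) {
                continue;
            }
            // missing weights default to 1.0, mirroring getWeightForSubQuery
            float weight = i < weights.size() ? weights.get(i) : 1.0f;
            sumOfWeights += weight;
            sumOfHarmonics += weight / scores[i];
        }
        return sumOfHarmonics == 0.0f ? 0.0f : sumOfWeights / sumOfHarmonics;
    }

    public static void main(String[] args) {
        // equal weights: harmonic mean of 0.5 and 1.0 = 2 / (1/0.5 + 1/1.0) = 0.6667
        System.out.println(combine(new float[] { 0.5f, 1.0f }, List.of()));
        // the zero score is skipped, so only the first sub-query contributes: 0.8
        System.out.println(combine(new float[] { 0.8f, 0.0f }, List.of()));
    }
}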
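
The L2 hunks only refactor how the norms are carried around (List<Float> instead of float[], a plain loop instead of IntStream); the underlying normalization is unchanged. Each sub-query's norm is the square root of the sum of its squared scores across shards, and every score is divided by that norm. A compilable sketch follows; the zero-norm guard is an assumption, since the body of normalizeSingleScore is cut off in the diff above.

// Sketch of per-sub-query L2 score normalization; names and sample values are illustrative.
public class L2NormalizationSketch {

    static float[] normalize(float[] scores) {
        // norm = sqrt(s1^2 + s2^2 + ... + sN^2)
        float sumOfSquares = 0.0f;
        for (float score : scores) {
            sumOfSquares += score * score;
        }
        float l2Norm = (float) Math.sqrt(sumOfSquares);
        float[] normalized = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            // assumed guard: a sub-query with no hits has a zero norm, normalize to zero
            normalized[i] = l2Norm == 0.0f ? 0.0f : scores[i] / l2Norm;
        }
        return normalized;
    }

    public static void main(String[] args) {
        // norm = sqrt(3^2 + 4^2) = 5, so the scores normalize to 0.6 and 0.8
        for (float score : normalize(new float[] { 3.0f, 4.0f })) {
            System.out.println(score);
        }
    }
}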