Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hybrid search RRF duplicate results bug #957

Merged
merged 1 commit into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/unit_test_200gb_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ on:
push:
branches:
- mainline
- releases/*
pull_request:
branches:
- mainline
- releases/*

permissions:
contents: read
Expand Down
53 changes: 44 additions & 9 deletions vespa/src/main/java/ai/marqo/search/HybridSearcher.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.marqo.search;

import com.sun.jdi.InternalException;
import com.yahoo.component.chain.dependencies.Before;
import com.yahoo.component.chain.dependencies.Provides;
import com.yahoo.search.Query;
Expand All @@ -21,6 +22,8 @@
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -40,6 +43,9 @@ public class HybridSearcher extends Searcher {
private static String MARQO_SEARCH_METHOD_TENSOR = "tensor";
private List<String> STANDARD_SEARCH_TYPES = new ArrayList<>();

// Compiled once and stored as a static final so the regex is not rebuilt per call.
// Matches strings of the form "index:<segment with no whitespace or '/'>/<digits>/<rest>"
// and captures <rest> (everything after the numeric segment) as group 1.
private static final Pattern PATTERN = Pattern.compile("^index\\:[^\\s\\/]+\\/\\d+\\/(.+)$");

@Override
public Result search(Query query, Execution execution) {
// All query parameters starting with 'marqo__' are custom for Marqo hybrid search.
Expand Down Expand Up @@ -149,8 +155,10 @@ HitGroup rrf(
HitGroup hitsTensor, HitGroup hitsLexical, Integer k, Double alpha, boolean verbose) {

HashMap<String, Double> rrfScores = new HashMap<>();
HashMap<String, String> docIdsToHitIds = new HashMap<>();
HitGroup result = new HitGroup();
Double reciprocalRank, existingScore, newScore;
String extractedDocId;

logIfVerbose("Beginning RRF process.", verbose);
logIfVerbose("Beginning (empty) result state: ", verbose);
Expand All @@ -174,9 +182,12 @@ HitGroup rrf(
verbose); // TODO: For easier debugging, expose marqo__id
logIfVerbose(hit.toString(), verbose);

extractedDocId = extractDocIdFromHitId(hit.getId().toString());
reciprocalRank = alpha * (1.0 / (rank + k));
rrfScores.put(
hit.getId().toString(), reciprocalRank); // Store hit's score via its URI
// Map hit's score to its shortened doc ID
rrfScores.put(extractedDocId, reciprocalRank);
// Map hit's full URI to its shortened doc ID
docIdsToHitIds.put(extractedDocId, hit.getId().toString());
hit.setField(
"marqo__raw_tensor_score",
hit.getRelevance()
Expand Down Expand Up @@ -208,7 +219,8 @@ HitGroup rrf(
verbose);

// Check if score already exists. If so, add to it.
existingScore = rrfScores.get(hit.getId().toString());
extractedDocId = extractDocIdFromHitId(hit.getId().toString());
existingScore = rrfScores.get(extractedDocId);
if (existingScore == null) {
// If the score doesn't exist, add new hit to result list (with rrf score).
logIfVerbose("No existing score found! Starting at 0.0.", verbose);
Expand All @@ -217,17 +229,21 @@ HitGroup rrf(
hit.getRelevance()
.getScore()); // Encode raw score for Marqo debugging purposes
hit.setRelevance(reciprocalRank); // Update score to be weighted RR (lexical)
rrfScores.put(hit.getId().toString(), reciprocalRank); // Log score in hashmap
// Map hit's score to its shortened doc ID
rrfScores.put(extractedDocId, reciprocalRank);
// Map hit's full URI to its shortened doc ID
docIdsToHitIds.put(extractedDocId, hit.getId().toString());
result.add(hit);

} else {
// If it does, find that hit in the result list and update it, adding new rrf to
// its score.
newScore = existingScore + reciprocalRank;
rrfScores.put(hit.getId().toString(), newScore);
rrfScores.put(extractedDocId, newScore);

// Update existing hit in result list (use map to find the full hit ID)
Hit existingHit = result.get(docIdsToHitIds.get(extractedDocId));

// Update existing hit in result list
Hit existingHit = result.get(hit.getId().toString());
existingHit.setField(
"marqo__raw_lexical_score",
hit.getRelevance()
Expand All @@ -237,7 +253,8 @@ HitGroup rrf(

logIfVerbose(
String.format(
"Existing score found for hit: %s.", hit.getId().toString()),
"Existing score found for hit: %s.",
extractDocIdFromHitId(hit.getId().toString())),
verbose);
logIfVerbose(String.format("Existing score is: %.7f", existingScore), verbose);
logIfVerbose(String.format("New score is: %.7f", newScore), verbose);
Expand Down Expand Up @@ -381,7 +398,9 @@ public void logHitGroup(HitGroup hits, boolean verbose) {
logger.info(
String.format(
"{IDX: %s, HIT ID: %s, RELEVANCE: %.7f}",
idx, hit.getId().toString(), hit.getRelevance().getScore()));
idx,
extractDocIdFromHitId(hit.getId().toString()),
hit.getRelevance().getScore()));
idx++;
}
logger.info("=======================");
Expand Down Expand Up @@ -424,4 +443,20 @@ Tensor extractTensorRankFeature(Query query, String featureName) {
/**
 * Wraps the given text in a {@code query(...)} expression.
 *
 * @param str the text to place between the parentheses
 * @return the string {@code "query(" + str + ")"}
 */
String addQueryWrapper(String str) {
    StringBuilder wrapped = new StringBuilder("query(");
    wrapped.append(str);
    wrapped.append(')');
    return wrapped.toString();
}

/**
 * Extracts the document ID from a full hit ID by applying the precompiled {@code PATTERN}
 * and returning its first capture group.
 *
 * @param fullPath the full hit ID (the hit's URI rendered as a string)
 * @return the text captured by group 1 of {@code PATTERN}, i.e. the document ID
 * @throws InternalException if {@code fullPath} does not match {@code PATTERN}
 */
static String extractDocIdFromHitId(String fullPath) {
    Matcher hitIdMatcher = PATTERN.matcher(fullPath);

    // Guard clause: fail loudly when the hit ID is not in the expected shape.
    // NOTE(review): InternalException is imported from com.sun.jdi (the Java Debug
    // Interface); confirm this is the intended exception type for this failure.
    if (!hitIdMatcher.find()) {
        throw new InternalException(
                "Vespa doc ID could not be extracted from the full hit ID: " + fullPath + ".");
    }

    // Group 1 is the portion after the numeric path segment.
    return hitIdMatcher.group(1);
}
}
Loading
Loading