Skip to content

Commit

Permalink
Make WANDScorer compute scores on the fly. (apache#14021)
Browse files Browse the repository at this point in the history
Currently, `WANDSCorer` considers that a hit is a match if the sum of maximum
scores across clauses is more than or equal to the minimum competitive score.
We can do better by computing scores of leading clauses on the fly. This helps
because scores are often lower than the score upper bound, so using actual
scores instead of score upper bounds can help skip advancing more clauses.

For reference, we are already doing the same trick in our conjunction (bulk)
scorers and in `MaxScoreBulkScorer` (bulk scorer for top-level disjunctions).
  • Loading branch information
jpountz authored and benchaplin committed Dec 31, 2024
1 parent 57bf762 commit d5ed4ff
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 18 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ Optimizations

* GITHUB#13989: Faster checksum computation. (Jean-François Boeuf)

* GITHUB#14021: WANDScorer now computes scores on the fly, which helps prevent
advancing "tail" clauses in many cases. (Adrien Grand)

Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
Expand Down
74 changes: 56 additions & 18 deletions lucene/core/src/java/org/apache/lucene/search/WANDScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.apache.lucene.util.MathUtil;

Expand Down Expand Up @@ -129,7 +130,7 @@ private static long scaleMinScore(float minScore, int scalingFactor) {
// some descriptions of WAND (Weak AND).
DisiWrapper lead;
int doc; // current doc ID of the leads
long leadMaxScore; // sum of the max scores of scorers in 'lead'
double leadScore; // score of the leads

// priority queue of scorers that are too advanced compared to the current
// doc. Ordered by doc ID.
Expand Down Expand Up @@ -195,7 +196,7 @@ private static long scaleMinScore(float minScore, int scalingFactor) {
}

for (Scorer scorer : scorers) {
addLead(new DisiWrapper(scorer));
addUnpositionedLead(new DisiWrapper(scorer));
}

this.cost =
Expand All @@ -208,7 +209,7 @@ private static long scaleMinScore(float minScore, int scalingFactor) {

// returns a boolean so that it can be called from assert
// the return value is useless: it always returns true
private boolean ensureConsistent() {
private boolean ensureConsistent() throws IOException {
if (scoreMode == ScoreMode.TOP_SCORES) {
long maxScoreSum = 0;
for (int i = 0; i < tailSize; ++i) {
Expand All @@ -217,12 +218,19 @@ private boolean ensureConsistent() {
}
assert maxScoreSum == tailMaxScore : maxScoreSum + " " + tailMaxScore;

maxScoreSum = 0;
List<Float> leadScores = new ArrayList<>();
for (DisiWrapper w = lead; w != null; w = w.next) {
assert w.doc == doc;
maxScoreSum = Math.addExact(maxScoreSum, w.scaledMaxScore);
leadScores.add(w.scorer.score());
}
assert maxScoreSum == leadMaxScore : maxScoreSum + " " + leadMaxScore;
// Make sure to recompute the sum in the same order to get the same floating point rounding
// errors.
Collections.reverse(leadScores);
double recomputedLeadScore = 0;
for (float score : leadScores) {
recomputedLeadScore += score;
}
assert recomputedLeadScore == leadScore;

assert minCompetitiveScore == 0
|| tailMaxScore < minCompetitiveScore
Expand Down Expand Up @@ -285,8 +293,6 @@ public int nextDoc() throws IOException {

@Override
public int advance(int target) throws IOException {
assert ensureConsistent();

// Move 'lead' iterators back to the tail
pushBackLeads(target);

Expand Down Expand Up @@ -319,17 +325,34 @@ public boolean matches() throws IOException {
assert lead == null;
moveToNextCandidate();

while (leadMaxScore < minCompetitiveScore || freq < minShouldMatch) {
if (leadMaxScore + tailMaxScore < minCompetitiveScore
long scaledLeadScore = 0;
if (scoreMode == ScoreMode.TOP_SCORES) {
scaledLeadScore =
scaleMaxScore(
(float) MathUtil.sumUpperBound(leadScore, FLOAT_MANTISSA_BITS), scalingFactor);
}

while (scaledLeadScore < minCompetitiveScore || freq < minShouldMatch) {
assert ensureConsistent();
if (scaledLeadScore + tailMaxScore < minCompetitiveScore
|| freq + tailSize < minShouldMatch) {
return false;
} else {
// a match on doc is still possible, try to
// advance scorers from the tail
DisiWrapper prevLead = lead;
advanceTail();
if (scoreMode == ScoreMode.TOP_SCORES && lead != prevLead) {
assert prevLead == lead.next;
scaledLeadScore =
scaleMaxScore(
(float) MathUtil.sumUpperBound(leadScore, FLOAT_MANTISSA_BITS),
scalingFactor);
}
}
}

assert ensureConsistent();
return true;
}

Expand All @@ -342,10 +365,20 @@ public float matchCost() {
}

/** Add a disi to the linked list of leads. */
private void addLead(DisiWrapper lead) {
private void addLead(DisiWrapper lead) throws IOException {
lead.next = this.lead;
this.lead = lead;
freq += 1;
if (scoreMode == ScoreMode.TOP_SCORES) {
leadScore += lead.scorer.score();
}
}

/** Add a disi to the linked list of leads. */
private void addUnpositionedLead(DisiWrapper lead) {
assert lead.doc == -1;
lead.next = this.lead;
this.lead = lead;
leadMaxScore += lead.scaledMaxScore;
freq += 1;
}

Expand All @@ -359,7 +392,6 @@ private void pushBackLeads(int target) throws IOException {
}
}
lead = null;
leadMaxScore = 0;
}

/** Make sure all disis in 'head' are on or after 'target'. */
Expand Down Expand Up @@ -488,8 +520,10 @@ private void moveToNextCandidate() throws IOException {
lead = head.pop();
assert doc == lead.doc;
lead.next = null;
leadMaxScore = lead.scaledMaxScore;
freq = 1;
if (scoreMode == ScoreMode.TOP_SCORES) {
leadScore = lead.scorer.score();
}
while (head.size() > 0 && head.top().doc == doc) {
addLead(head.pop());
}
Expand All @@ -514,11 +548,15 @@ private void advanceAllTail() throws IOException {
public float score() throws IOException {
// we need to know about all matches
advanceAllTail();
double score = 0;
for (DisiWrapper s = lead; s != null; s = s.next) {
score += s.scorer.score();

double leadScore = this.leadScore;
if (scoreMode != ScoreMode.TOP_SCORES) {
// With TOP_SCORES, the score was already computed on the fly.
for (DisiWrapper s = lead; s != null; s = s.next) {
leadScore += s.scorer.score();
}
}
return (float) score;
return (float) leadScore;
}

@Override
Expand Down

0 comments on commit d5ed4ff

Please sign in to comment.