diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 45dab17417d3..a27d7c70e3f1 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -107,6 +107,9 @@ Optimizations * GITHUB#13989: Faster checksum computation. (Jean-François Boeuf) +* GITHUB#14021: WANDScorer now computes scores on the fly, which helps prevent + advancing "tail" clauses in many cases. (Adrien Grand) + Bug Fixes --------------------- * GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended diff --git a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java index ca494c2d8bf4..59441d21539e 100644 --- a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import org.apache.lucene.util.MathUtil; @@ -129,7 +130,7 @@ private static long scaleMinScore(float minScore, int scalingFactor) { // some descriptions of WAND (Weak AND). DisiWrapper lead; int doc; // current doc ID of the leads - long leadMaxScore; // sum of the max scores of scorers in 'lead' + double leadScore; // score of the leads // priority queue of scorers that are too advanced compared to the current // doc. Ordered by doc ID. 
@@ -195,7 +196,7 @@ private static long scaleMinScore(float minScore, int scalingFactor) { } for (Scorer scorer : scorers) { - addLead(new DisiWrapper(scorer)); + addUnpositionedLead(new DisiWrapper(scorer)); } this.cost = @@ -208,7 +209,7 @@ private static long scaleMinScore(float minScore, int scalingFactor) { // returns a boolean so that it can be called from assert // the return value is useless: it always returns true - private boolean ensureConsistent() { + private boolean ensureConsistent() throws IOException { if (scoreMode == ScoreMode.TOP_SCORES) { long maxScoreSum = 0; for (int i = 0; i < tailSize; ++i) { @@ -217,12 +218,19 @@ private boolean ensureConsistent() { } assert maxScoreSum == tailMaxScore : maxScoreSum + " " + tailMaxScore; - maxScoreSum = 0; + List<Float> leadScores = new ArrayList<>(); for (DisiWrapper w = lead; w != null; w = w.next) { assert w.doc == doc; - maxScoreSum = Math.addExact(maxScoreSum, w.scaledMaxScore); + leadScores.add(w.scorer.score()); } - assert maxScoreSum == leadMaxScore : maxScoreSum + " " + leadMaxScore; + // Make sure to recompute the sum in the same order to get the same floating point rounding + // errors. 
+ Collections.reverse(leadScores); + double recomputedLeadScore = 0; + for (float score : leadScores) { + recomputedLeadScore += score; + } + assert recomputedLeadScore == leadScore; assert minCompetitiveScore == 0 || tailMaxScore < minCompetitiveScore @@ -285,8 +293,6 @@ public int nextDoc() throws IOException { @Override public int advance(int target) throws IOException { - assert ensureConsistent(); - // Move 'lead' iterators back to the tail pushBackLeads(target); @@ -319,17 +325,34 @@ public boolean matches() throws IOException { assert lead == null; moveToNextCandidate(); - while (leadMaxScore < minCompetitiveScore || freq < minShouldMatch) { - if (leadMaxScore + tailMaxScore < minCompetitiveScore + long scaledLeadScore = 0; + if (scoreMode == ScoreMode.TOP_SCORES) { + scaledLeadScore = + scaleMaxScore( + (float) MathUtil.sumUpperBound(leadScore, FLOAT_MANTISSA_BITS), scalingFactor); + } + + while (scaledLeadScore < minCompetitiveScore || freq < minShouldMatch) { + assert ensureConsistent(); + if (scaledLeadScore + tailMaxScore < minCompetitiveScore || freq + tailSize < minShouldMatch) { return false; } else { // a match on doc is still possible, try to // advance scorers from the tail + DisiWrapper prevLead = lead; advanceTail(); + if (scoreMode == ScoreMode.TOP_SCORES && lead != prevLead) { + assert prevLead == lead.next; + scaledLeadScore = + scaleMaxScore( + (float) MathUtil.sumUpperBound(leadScore, FLOAT_MANTISSA_BITS), + scalingFactor); + } } } + assert ensureConsistent(); return true; } @@ -342,10 +365,20 @@ public float matchCost() { } /** Add a disi to the linked list of leads. */ - private void addLead(DisiWrapper lead) { + private void addLead(DisiWrapper lead) throws IOException { + lead.next = this.lead; + this.lead = lead; + freq += 1; + if (scoreMode == ScoreMode.TOP_SCORES) { + leadScore += lead.scorer.score(); + } + } + + /** Add a disi that is not positioned yet (doc == -1) to the leads, without scoring it. 
*/ + private void addUnpositionedLead(DisiWrapper lead) { + assert lead.doc == -1; lead.next = this.lead; this.lead = lead; - leadMaxScore += lead.scaledMaxScore; freq += 1; } @@ -359,7 +392,6 @@ private void pushBackLeads(int target) throws IOException { } } lead = null; - leadMaxScore = 0; } /** Make sure all disis in 'head' are on or after 'target'. */ @@ -488,8 +520,10 @@ private void moveToNextCandidate() throws IOException { lead = head.pop(); assert doc == lead.doc; lead.next = null; - leadMaxScore = lead.scaledMaxScore; freq = 1; + if (scoreMode == ScoreMode.TOP_SCORES) { + leadScore = lead.scorer.score(); + } while (head.size() > 0 && head.top().doc == doc) { addLead(head.pop()); } @@ -514,11 +548,15 @@ private void advanceAllTail() throws IOException { public float score() throws IOException { // we need to know about all matches advanceAllTail(); - double score = 0; - for (DisiWrapper s = lead; s != null; s = s.next) { - score += s.scorer.score(); + + double leadScore = this.leadScore; + if (scoreMode != ScoreMode.TOP_SCORES) { + // With TOP_SCORES, the score was already computed on the fly. + for (DisiWrapper s = lead; s != null; s = s.next) { + leadScore += s.scorer.score(); + } } - return (float) score; + return (float) leadScore; } @Override