Skip to content

Commit

Permalink
Run filtered disjunctions with MaxScoreBulkScorer. (apache#14014)
Browse files Browse the repository at this point in the history
Running filtered disjunctions with a specialized bulk scorer seems to yield a
good speedup. For what it's worth, I also tried to implement a MAXSCORE-based
scorer to see if it had to do with the `BulkScorer` specialization or the
algorithm, but it didn't help.

To work properly, I had to add a rewrite rule to inline disjunctions in a MUST
clause.

As a next step, it would be interesting to see if we can further optimize this
by loading the filter into a bitset and applying it like live docs.
  • Loading branch information
jpountz authored Nov 27, 2024
1 parent d9aa525 commit 98c59a7
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 14 deletions.
3 changes: 3 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ Optimizations
* GITHUB#14021: WANDScorer now computes scores on the fly, which helps prevent
advancing "tail" clauses in many cases. (Adrien Grand)

* GITHUB#14014: Filtered disjunctions now get executed via `MaxScoreBulkScorer`.
(Adrien Grand)

Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
Expand Down
20 changes: 20 additions & 0 deletions lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,26 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
}
}

// Inline SHOULD clauses from the only MUST clause
{
if (clauseSets.get(Occur.SHOULD).isEmpty()
&& clauseSets.get(Occur.MUST).size() == 1
&& clauseSets.get(Occur.MUST).iterator().next() instanceof BooleanQuery inner
&& inner.clauses.size() == inner.clauseSets.get(Occur.SHOULD).size()) {
BooleanQuery.Builder rewritten = new BooleanQuery.Builder();
for (BooleanClause clause : clauses) {
if (clause.occur() != Occur.MUST) {
rewritten.add(clause);
}
}
for (BooleanClause innerClause : inner.clauses()) {
rewritten.add(innerClause);
}
rewritten.setMinimumNumberShouldMatch(Math.max(1, inner.getMinimumNumberShouldMatch()));
return rewritten.build();
}
}

return super.rewrite(indexSearcher);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ public BulkScorer bulkScorer() throws IOException {

BulkScorer booleanScorer() throws IOException {
final int numOptionalClauses = subs.get(Occur.SHOULD).size();
final int numRequiredClauses = subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size();
final int numMustClauses = subs.get(Occur.MUST).size();
final int numRequiredClauses = numMustClauses + subs.get(Occur.FILTER).size();

BulkScorer positiveScorer;
if (numRequiredClauses == 0) {
Expand All @@ -209,6 +210,8 @@ BulkScorer booleanScorer() throws IOException {
}

positiveScorer = optionalBulkScorer();
} else if (numMustClauses == 0 && numOptionalClauses > 1 && minShouldMatch >= 1) {
positiveScorer = filteredOptionalBulkScorer();
} else if (numRequiredClauses > 0 && numOptionalClauses == 0 && minShouldMatch == 0) {
positiveScorer = requiredBulkScorer();
} else {
Expand Down Expand Up @@ -286,7 +289,7 @@ BulkScorer optionalBulkScorer() throws IOException {
optionalScorers.add(ss.get(Long.MAX_VALUE));
}

return new MaxScoreBulkScorer(maxDoc, optionalScorers);
return new MaxScoreBulkScorer(maxDoc, optionalScorers, null);
}

List<Scorer> optional = new ArrayList<Scorer>();
Expand All @@ -297,6 +300,32 @@ BulkScorer optionalBulkScorer() throws IOException {
return new BooleanScorer(optional, Math.max(1, minShouldMatch), scoreMode.needsScores());
}

BulkScorer filteredOptionalBulkScorer() throws IOException {
if (subs.get(Occur.MUST).isEmpty() == false
|| subs.get(Occur.FILTER).isEmpty()
|| scoreMode != ScoreMode.TOP_SCORES
|| subs.get(Occur.SHOULD).size() <= 1
|| minShouldMatch > 1) {
return null;
}
long cost = cost();
List<Scorer> optionalScorers = new ArrayList<>();
for (ScorerSupplier ss : subs.get(Occur.SHOULD)) {
optionalScorers.add(ss.get(cost));
}
List<Scorer> filters = new ArrayList<>();
for (ScorerSupplier ss : subs.get(Occur.FILTER)) {
filters.add(ss.get(cost));
}
Scorer filterScorer;
if (filters.size() == 1) {
filterScorer = filters.iterator().next();
} else {
filterScorer = new ConjunctionScorer(filters, Collections.emptySet());
}
return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer);
}

// Return a BulkScorer for the required clauses only
private BulkScorer requiredBulkScorer() throws IOException {
if (subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ final class MaxScoreBulkScorer extends BulkScorer {
float minCompetitiveScore;
private final Score scorable = new Score();
final double[] maxScoreSums;
private final DisiWrapper filter;

private final long[] windowMatches = new long[FixedBitSet.bits2words(INNER_WINDOW_SIZE)];
private final double[] windowScores = new double[INNER_WINDOW_SIZE];

MaxScoreBulkScorer(int maxDoc, List<Scorer> scorers) throws IOException {
MaxScoreBulkScorer(int maxDoc, List<Scorer> scorers, Scorer filter) throws IOException {
this.maxDoc = maxDoc;
this.filter = filter == null ? null : new DisiWrapper(filter);
allScorers = new DisiWrapper[scorers.size()];
scratch = new DisiWrapper[allScorers.length];
int i = 0;
Expand Down Expand Up @@ -123,7 +125,7 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr
}

while (top.doc < outerWindowMax) {
scoreInnerWindow(collector, acceptDocs, outerWindowMax);
scoreInnerWindow(collector, acceptDocs, outerWindowMax, filter);
top = essentialQueue.top();
if (minCompetitiveScore >= nextMinCompetitiveScore) {
// The minimum competitive score increased substantially, so we can now partition scorers
Expand All @@ -139,9 +141,11 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr
return nextCandidate(max);
}

private void scoreInnerWindow(LeafCollector collector, Bits acceptDocs, int max)
throws IOException {
if (allScorers.length - firstRequiredScorer >= 2) {
private void scoreInnerWindow(
LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException {
if (filter != null) {
scoreInnerWindowWithFilter(collector, acceptDocs, max, filter);
} else if (allScorers.length - firstRequiredScorer >= 2) {
scoreInnerWindowAsConjunction(collector, acceptDocs, max);
} else {
DisiWrapper top = essentialQueue.top();
Expand All @@ -158,6 +162,55 @@ private void scoreInnerWindow(LeafCollector collector, Bits acceptDocs, int max)
}
}

private void scoreInnerWindowWithFilter(
LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException {

// TODO: Sometimes load the filter into a bitset and use the more optimized execution paths with
// this bitset as `acceptDocs`

DisiWrapper top = essentialQueue.top();
assert top.doc < max;
if (top.doc < filter.doc) {
top.doc = top.approximation.advance(filter.doc);
}

// Only score an inner window, after that we'll check if the min competitive score has increased
// enough for a more favorable partitioning to be used.
int innerWindowMin = top.doc;
int innerWindowMax = (int) Math.min(max, (long) innerWindowMin + INNER_WINDOW_SIZE);

while (top.doc < innerWindowMax) {
assert filter.doc <= top.doc; // invariant
if (filter.doc < top.doc) {
filter.doc = filter.approximation.advance(top.doc);
}

if (filter.doc != top.doc) {
do {
top.doc = top.iterator.advance(filter.doc);
top = essentialQueue.updateTop();
} while (top.doc < filter.doc);
} else {
int doc = top.doc;
boolean match =
(acceptDocs == null || acceptDocs.get(doc))
&& (filter.twoPhaseView == null || filter.twoPhaseView.matches());
double score = 0;
do {
if (match) {
score += top.scorer.score();
}
top.doc = top.iterator.nextDoc();
top = essentialQueue.updateTop();
} while (top.doc == doc);

if (match) {
scoreNonEssentialClauses(collector, doc, score, firstEssentialScorer);
}
}
}
}

private void scoreInnerWindowSingleEssentialClause(
LeafCollector collector, Bits acceptDocs, int upTo) throws IOException {
DisiWrapper top = essentialQueue.top();
Expand Down Expand Up @@ -284,8 +337,10 @@ private int computeOuterWindowMax(int windowMin) throws IOException {
int windowMax = DocIdSetIterator.NO_MORE_DOCS;
for (int i = firstWindowLead; i < allScorers.length; ++i) {
final DisiWrapper scorer = allScorers[i];
final int upTo = scorer.scorer.advanceShallow(Math.max(scorer.doc, windowMin));
windowMax = (int) Math.min(windowMax, upTo + 1L); // upTo is inclusive
if (filter == null || scorer.cost >= filter.cost) {
final int upTo = scorer.scorer.advanceShallow(Math.max(scorer.doc, windowMin));
windowMax = (int) Math.min(windowMax, upTo + 1L); // upTo is inclusive
}
}

if (allScorers.length - firstWindowLead > 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,51 @@ public void testFlattenInnerConjunctions() throws IOException {
assertEquals(expectedRewritten, searcher.rewrite(query));
}

public void testFlattenDisjunctionInMustClause() throws IOException {
IndexSearcher searcher = newSearcher(new MultiReader());

Query inner =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
.build();
Query query =
new BooleanQuery.Builder()
.add(inner, Occur.MUST)
.add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
.build();
Query expectedRewritten =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
.setMinimumNumberShouldMatch(1)
.build();
assertEquals(expectedRewritten, searcher.rewrite(query));

inner =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "foo")), Occur.SHOULD)
.setMinimumNumberShouldMatch(2)
.build();
query =
new BooleanQuery.Builder()
.add(inner, Occur.MUST)
.add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
.build();
expectedRewritten =
new BooleanQuery.Builder()
.add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "foo")), Occur.SHOULD)
.add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
.setMinimumNumberShouldMatch(2)
.build();
assertEquals(expectedRewritten, searcher.rewrite(query));
}

public void testDiscardShouldClauses() throws IOException {
IndexSearcher searcher = newSearcher(new MultiReader());

Expand Down
Loading

0 comments on commit 98c59a7

Please sign in to comment.