Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up advancing on the disjunction iterator. #14052

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOSupplier;
import org.apache.lucene.util.RamUsageEstimator;

/**
Expand Down Expand Up @@ -151,7 +150,8 @@ protected abstract WeightOrDocIdSetIterator rewriteInner(
int fieldDocCount,
Terms terms,
TermsEnum termsEnum,
List<TermAndState> collectedTerms)
List<TermAndState> collectedTerms,
long leadCost)
throws IOException;

private WeightOrDocIdSetIterator rewriteAsBooleanQuery(
Expand Down Expand Up @@ -247,21 +247,22 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
cost = estimateCost(terms, q.getTermsCount());
}

IOSupplier<WeightOrDocIdSetIterator> weightOrIteratorSupplier =
() -> {
IOLongFunction<WeightOrDocIdSetIterator> weightOrIteratorSupplier =
leadCost -> {
if (collectResult) {
return rewriteAsBooleanQuery(context, collectedTerms);
} else {
// Too many terms to rewrite as a simple bq.
// Invoke rewriteInner logic to handle rewriting:
return rewriteInner(context, fieldDocCount, terms, termsEnum, collectedTerms);
return rewriteInner(
context, fieldDocCount, terms, termsEnum, collectedTerms, leadCost);
}
};

return new ScorerSupplier() {
@Override
public Scorer get(long leadCost) throws IOException {
WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.get();
WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.apply(leadCost);
final Scorer scorer;
if (weightOrIterator == null) {
scorer = null;
Expand All @@ -281,7 +282,8 @@ public Scorer get(long leadCost) throws IOException {

@Override
public BulkScorer bulkScorer() throws IOException {
WeightOrDocIdSetIterator weightOrIterator = weightOrIteratorSupplier.get();
WeightOrDocIdSetIterator weightOrIterator =
weightOrIteratorSupplier.apply(Long.MAX_VALUE);
final BulkScorer bulkScorer;
if (weightOrIterator == null) {
bulkScorer = null;
Expand Down Expand Up @@ -311,6 +313,10 @@ public long cost() {
};
}

private static interface IOLongFunction<T> {
T apply(long arg) throws IOException;
}

private static long estimateCost(Terms terms, long queryTermsCount) throws IOException {
// Estimate the cost. If the MTQ can provide its term count, we can do a better job
// estimating.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ BulkScorer booleanScorer() throws IOException {
Scorer prohibitedScorer =
prohibited.size() == 1
? prohibited.get(0)
: new DisjunctionSumScorer(prohibited, ScoreMode.COMPLETE_NO_SCORES);
: new DisjunctionSumScorer(
prohibited, ScoreMode.COMPLETE_NO_SCORES, positiveScorerCost);
return new ReqExclBulkScorer(positiveScorer, prohibitedScorer);
}
}
Expand Down Expand Up @@ -509,7 +510,7 @@ private Scorer opt(
if ((scoreMode == ScoreMode.TOP_SCORES && topLevelScoringClause) || minShouldMatch > 1) {
return new WANDScorer(optionalScorers, minShouldMatch, scoreMode, leadCost);
} else {
return new DisjunctionSumScorer(optionalScorers, scoreMode);
return new DisjunctionSumScorer(optionalScorers, scoreMode, leadCost);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,86 @@
package org.apache.lucene.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;

/**
* A {@link DocIdSetIterator} which is a disjunction of the approximations of the provided
* iterators.
*
* @lucene.internal
*/
public class DisjunctionDISIApproximation extends DocIdSetIterator {
public final class DisjunctionDISIApproximation extends DocIdSetIterator {

final DisiPriorityQueue subIterators;
final long cost;
/**
 * Static factory equivalent to calling the {@code (Collection, long)} constructor.
 *
 * @param subIterators the wrapped iterators to take the disjunction of
 * @param leadCost the lead cost handed through to the constructor unchanged
 * @return a new approximation over {@code subIterators}
 */
public static DisjunctionDISIApproximation of(Collection<DisiWrapper> subIterators, long leadCost) {
  return new DisjunctionDISIApproximation(subIterators, leadCost);
}

// Heap of iterators that lead iteration.
private final DisiPriorityQueue leadIterators;
// List of iterators that will likely advance on every call to nextDoc() / advance()
private final DisiWrapper[] otherIterators;
private final long cost;
private DisiWrapper leadTop;
private int minOtherDoc;

public DisjunctionDISIApproximation(Collection<DisiWrapper> subIterators, long leadCost) {
// Using a heap to store disjunctive clauses is great for exhaustive evaluation, when a single
// clause needs to move through the heap on every iteration on average. However, when
// intersecting with a selective filter, it is possible that all clauses need advancing, which
// makes the reordering cost scale in O(N * log(N)) per advance() call when checking clauses
// linearly would scale in O(N).
// To protect against this reordering overhead, we try to have 1.5 clauses or less that advance
// on every advance() call by only putting clauses into the heap as long as Σ min(1, cost /
// leadCost) <= 1.5, or Σ min(leadCost, cost) <= 1.5 * leadCost. Other clauses are checked
// linearly.

List<DisiWrapper> wrappers = new ArrayList<>(subIterators);
// Sort by descending cost.
wrappers.sort(Comparator.<DisiWrapper>comparingLong(w -> w.cost).reversed());

leadIterators = new DisiPriorityQueue(subIterators.size());

long reorderThreshold = leadCost + (leadCost >> 1);
if (reorderThreshold < 0) { // overflow
reorderThreshold = Long.MAX_VALUE;
}
long reorderCost = 0;
while (wrappers.isEmpty() == false) {
DisiWrapper last = wrappers.getLast();
long inc = Math.min(last.cost, leadCost);
if (reorderCost + inc < 0 || reorderCost + inc > reorderThreshold) {
break;
}
leadIterators.add(wrappers.removeLast());
reorderCost += inc;
}

// Make leadIterators not empty. This helps save conditionals in the implementation which are
// rarely tested.
if (leadIterators.size() == 0) {
leadIterators.add(wrappers.removeLast());
}

otherIterators = wrappers.toArray(DisiWrapper[]::new);

public DisjunctionDISIApproximation(DisiPriorityQueue subIterators) {
this.subIterators = subIterators;
long cost = 0;
for (DisiWrapper w : subIterators) {
for (DisiWrapper w : leadIterators) {
cost += w.cost;
}
for (DisiWrapper w : otherIterators) {
cost += w.cost;
}
this.cost = cost;
minOtherDoc = Integer.MAX_VALUE;
for (DisiWrapper w : otherIterators) {
minOtherDoc = Math.min(minOtherDoc, w.doc);
}
leadTop = leadIterators.top();
}

@Override
Expand All @@ -45,29 +106,62 @@ public long cost() {

@Override
public int docID() {
return subIterators.top().doc;
return Math.min(minOtherDoc, leadTop.doc);
}

@Override
public int nextDoc() throws IOException {
DisiWrapper top = subIterators.top();
final int doc = top.doc;
do {
top.doc = top.approximation.nextDoc();
top = subIterators.updateTop();
} while (top.doc == doc);

return top.doc;
if (leadTop.doc < minOtherDoc) {
int curDoc = leadTop.doc;
do {
leadTop.doc = leadTop.approximation.nextDoc();
leadTop = leadIterators.updateTop();
} while (leadTop.doc == curDoc);
return Math.min(leadTop.doc, minOtherDoc);
} else {
return advance(minOtherDoc + 1);
}
}

@Override
public int advance(int target) throws IOException {
DisiWrapper top = subIterators.top();
do {
top.doc = top.approximation.advance(target);
top = subIterators.updateTop();
} while (top.doc < target);
while (leadTop.doc < target) {
leadTop.doc = leadTop.approximation.advance(target);
leadTop = leadIterators.updateTop();
}

return top.doc;
minOtherDoc = Integer.MAX_VALUE;
for (DisiWrapper w : otherIterators) {
if (w.doc < target) {
w.doc = w.approximation.advance(target);
}
minOtherDoc = Math.min(minOtherDoc, w.doc);
}

return Math.min(leadTop.doc, minOtherDoc);
}

/**
 * Returns the head of the linked list of iterators that are positioned on the current doc.
 *
 * <p>When the top of the lead heap is strictly before the smallest doc among the linearly checked
 * iterators, the heap's own top list is returned as is. Otherwise the list is assembled by {@link
 * #computeTopList()}.
 */
public DisiWrapper topList() {
  return leadTop.doc < minOtherDoc ? leadIterators.topList() : computeTopList();
}

/**
 * Builds the list of iterators positioned on {@code minOtherDoc}. Must only be called when the
 * top of the lead heap is not before {@code minOtherDoc}.
 */
private DisiWrapper computeTopList() {
  assert leadTop.doc >= minOtherDoc;
  // The heap contributes entries only when its top sits on the same doc as the linearly
  // checked iterators; otherwise the list starts out empty.
  DisiWrapper head = leadTop.doc == minOtherDoc ? leadIterators.topList() : null;
  for (int i = 0; i < otherIterators.length; i++) {
    DisiWrapper candidate = otherIterators[i];
    if (candidate.doc != minOtherDoc) {
      continue;
    }
    // Prepend the matching iterator to the list.
    candidate.next = head;
    head = candidate;
  }
  return head;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public Scorer get(long leadCost) throws IOException {
for (ScorerSupplier ss : scorerSuppliers) {
scorers.add(ss.get(leadCost));
}
return new DisjunctionMaxScorer(tieBreakerMultiplier, scorers, scoreMode);
return new DisjunctionMaxScorer(tieBreakerMultiplier, scorers, scoreMode, leadCost);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
* as they are summed into the result.
* @param subScorers The sub scorers this Scorer should iterate on
*/
DisjunctionMaxScorer(float tieBreakerMultiplier, List<Scorer> subScorers, ScoreMode scoreMode)
DisjunctionMaxScorer(
float tieBreakerMultiplier, List<Scorer> subScorers, ScoreMode scoreMode, long leadCost)
throws IOException {
super(subScorers, scoreMode);
super(subScorers, scoreMode, leadCost);
this.subScorers = subScorers;
this.tieBreakerMultiplier = tieBreakerMultiplier;
if (tieBreakerMultiplier < 0 || tieBreakerMultiplier > 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,34 @@
/** Base class for Scorers that score disjunctions. */
abstract class DisjunctionScorer extends Scorer {

private final int numClauses;
private final boolean needsScores;

private final DisiPriorityQueue subScorers;
private final DocIdSetIterator approximation;
private final DisjunctionDISIApproximation approximation;
private final TwoPhase twoPhase;

protected DisjunctionScorer(List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
protected DisjunctionScorer(List<Scorer> subScorers, ScoreMode scoreMode, long leadCost)
throws IOException {
if (subScorers.size() <= 1) {
throw new IllegalArgumentException("There must be at least 2 subScorers");
}
this.subScorers = new DisiPriorityQueue(subScorers.size());
for (Scorer scorer : subScorers) {
final DisiWrapper w = new DisiWrapper(scorer, false);
this.subScorers.add(w);
}
this.numClauses = subScorers.size();
this.needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
this.approximation = new DisjunctionDISIApproximation(this.subScorers);

boolean hasApproximation = false;
float sumMatchCost = 0;
long sumApproxCost = 0;
// Compute matchCost as the average over the matchCost of the subScorers.
// This is weighted by the cost, which is an expected number of matching documents.
for (DisiWrapper w : this.subScorers) {
List<DisiWrapper> wrappers = new ArrayList<>();
for (Scorer scorer : subScorers) {
DisiWrapper w = new DisiWrapper(scorer, false);
long costWeight = (w.cost <= 1) ? 1 : w.cost;
sumApproxCost += costWeight;
if (w.twoPhaseView != null) {
hasApproximation = true;
sumMatchCost += w.matchCost * costWeight;
}
wrappers.add(w);
}
this.approximation = new DisjunctionDISIApproximation(wrappers, leadCost);

if (hasApproximation == false) { // no sub scorer supports approximations
twoPhase = null;
Expand Down Expand Up @@ -91,7 +88,7 @@ private TwoPhase(DocIdSetIterator approximation, float matchCost) {
super(approximation);
this.matchCost = matchCost;
unverifiedMatches =
new PriorityQueue<DisiWrapper>(DisjunctionScorer.this.subScorers.size()) {
new PriorityQueue<DisiWrapper>(numClauses) {
@Override
protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
return a.matchCost < b.matchCost;
Expand All @@ -116,7 +113,7 @@ public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

for (DisiWrapper w = subScorers.topList(); w != null; ) {
for (DisiWrapper w = DisjunctionScorer.this.approximation.topList(); w != null; ) {
DisiWrapper next = w.next;

if (w.twoPhaseView == null) {
Expand Down Expand Up @@ -160,12 +157,12 @@ public float matchCost() {

@Override
public final int docID() {
return subScorers.top().doc;
return approximation.docID();
}

DisiWrapper getSubMatches() throws IOException {
if (twoPhase == null) {
return subScorers.topList();
return approximation.topList();
} else {
return twoPhase.getSubMatches();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
*
* @param subScorers List of at least two subscorers.
*/
DisjunctionSumScorer(List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
super(subScorers, scoreMode);
DisjunctionSumScorer(List<Scorer> subScorers, ScoreMode scoreMode, long leadCost)
throws IOException {
super(subScorers, scoreMode, leadCost);
this.scorers = subScorers;
}

Expand Down
Loading
Loading