From 939893227c66e700761257e0712358fa6110c0ec Mon Sep 17 00:00:00 2001 From: "kewei.11" Date: Sun, 10 Nov 2024 18:21:57 +0800 Subject: [PATCH] Pruning of estimating the point value count since BooleanScorerSupplier --- .../document/LatLonPointDistanceQuery.java | 30 +++++- .../lucene/document/RangeFieldQuery.java | 31 ++++++- .../document/XYPointInGeometryQuery.java | 31 ++++++- .../org/apache/lucene/index/PointValues.java | 20 +++- .../lucene/search/BooleanScorerSupplier.java | 92 ++++++++++++++----- .../apache/lucene/search/PointInSetQuery.java | 83 ++++++++++++++--- .../apache/lucene/search/PointRangeQuery.java | 31 +++++-- .../apache/lucene/search/ScorerSupplier.java | 4 + .../org/apache/lucene/search/ScorerUtil.java | 33 +++++++ .../sandbox/search/MultiRangeQuery.java | 34 +++++-- 10 files changed, 325 insertions(+), 64 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java b/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java index 7f5f8cf6290c..e50c3c6c3aa4 100644 --- a/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/LatLonPointDistanceQuery.java @@ -20,6 +20,8 @@ import static org.apache.lucene.geo.GeoEncodingUtils.decodeLongitude; import static org.apache.lucene.geo.GeoEncodingUtils.encodeLatitude; import static org.apache.lucene.geo.GeoEncodingUtils.encodeLongitude; +import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO; +import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; import java.io.IOException; import org.apache.lucene.geo.GeoEncodingUtils; @@ -40,6 +42,7 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.DocIdSetBuilder; @@ -139,7 +142,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti return new ScorerSupplier() { - long cost = -1; + TotalHits estimatedCount; @Override public Scorer get(long leadCost) throws IOException { @@ -162,11 +165,28 @@ && cost() > reader.maxDoc() / 2) { @Override public long cost() { - if (cost == -1) { - cost = values.estimateDocCount(visitor); + if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) { + estimatedCount = + new TotalHits(values.estimateDocCount(visitor, Long.MAX_VALUE), EQUAL_TO); + assert estimatedCount.value() >= 0; } - assert cost >= 0; - return cost; + return estimatedCount.value(); + } + + @Override + public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) { + if (estimatedCount == null + || (estimatedCount.value() < upperBound + && estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) { + long cost = values.estimateDocCount(visitor, Long.MAX_VALUE); + if (cost < upperBound) { + estimatedCount = new TotalHits(cost, EQUAL_TO); + } else if (estimatedCount == null || cost > estimatedCount.value()) { + estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO); + } + assert estimatedCount.value() >= 0; + } + return estimatedCount; } }; } diff --git a/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java b/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java index f5747c0f8bde..36b060a46cab 100644 --- a/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/RangeFieldQuery.java @@ -16,6 +16,9 @@ */ package org.apache.lucene.document; +import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO; +import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + import java.io.IOException; import java.util.Arrays; import java.util.Objects; @@ -34,6 +37,7 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.ArrayUtil.ByteArrayComparator; @@ -477,7 +481,7 @@ public long cost() { final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); final IntersectVisitor visitor = getIntersectVisitor(result); - long cost = -1; + TotalHits estimatedCount = null; @Override public Scorer get(long leadCost) throws IOException { @@ -488,12 +492,29 @@ public Scorer get(long leadCost) throws IOException { @Override public long cost() { - if (cost == -1) { + if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) { + estimatedCount = + new TotalHits(values.estimateDocCount(visitor, Long.MAX_VALUE), EQUAL_TO); + assert estimatedCount.value() >= 0; + } + return estimatedCount.value(); + } + + @Override + public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) { + if (estimatedCount == null + || (estimatedCount.value() < upperBound + && estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) { // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; + long cost = values.estimateDocCount(visitor, upperBound); + if (cost < upperBound) { + estimatedCount = new TotalHits(cost, EQUAL_TO); + } else if (estimatedCount == null || cost > estimatedCount.value()) { + estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO); + } + assert estimatedCount.value() >= 0; } - return cost; + return estimatedCount; } }; } diff --git a/lucene/core/src/java/org/apache/lucene/document/XYPointInGeometryQuery.java b/lucene/core/src/java/org/apache/lucene/document/XYPointInGeometryQuery.java index 47b6abb46c22..0f855c7cd722 100644 --- a/lucene/core/src/java/org/apache/lucene/document/XYPointInGeometryQuery.java +++ b/lucene/core/src/java/org/apache/lucene/document/XYPointInGeometryQuery.java @@ -16,6 +16,9 @@ */ package org.apache.lucene.document; +import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO; +import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + import java.io.IOException; import java.util.Arrays; import org.apache.lucene.geo.Component2D; @@ -36,6 +39,7 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.util.DocIdSetBuilder; @@ -144,7 +148,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti return new ScorerSupplier() { - long cost = -1; + TotalHits estimatedCount; DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); final IntersectVisitor visitor = getIntersectVisitor(result, tree); @@ -156,12 +160,29 @@ public Scorer get(long leadCost) throws IOException { @Override public long cost() { - if (cost == -1) { + if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) { // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; + estimatedCount = + new TotalHits(values.estimateDocCount(visitor, Long.MAX_VALUE), EQUAL_TO); + assert estimatedCount.value() >= 0; + } + return estimatedCount.value(); + } + + @Override + public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) { + if (estimatedCount == null + || (estimatedCount.value() < upperBound + && estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) { + long cost = values.estimateDocCount(visitor, upperBound); + if (cost < upperBound) { + estimatedCount = new TotalHits(cost, EQUAL_TO); + } else if (estimatedCount == null || cost > estimatedCount.value()) { + estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO); + } + assert estimatedCount.value() >= 0; } - return cost; + return estimatedCount; } }; } diff --git a/lucene/core/src/java/org/apache/lucene/index/PointValues.java b/lucene/core/src/java/org/apache/lucene/index/PointValues.java index aebcaa90e2a0..ce46b8a70487 100644 --- a/lucene/core/src/java/org/apache/lucene/index/PointValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/PointValues.java @@ -385,9 +385,17 @@ private void intersect(IntersectVisitor visitor, PointTree pointTree) throws IOE * IntersectVisitor}. This should run many times faster than {@link #intersect(IntersectVisitor)}. */ public final long estimatePointCount(IntersectVisitor visitor) { + return estimatePointCount(visitor, Long.MAX_VALUE); + } + + /** + * Estimate the number of points within the given {@link IntersectVisitor} and a maximum of + * {upperBound} + */ + public final long estimatePointCount(IntersectVisitor visitor, long upperBound) { try { final PointTree pointTree = getPointTree(); - final long count = estimatePointCount(visitor, pointTree, Long.MAX_VALUE); + final long count = estimatePointCount(visitor, pointTree, upperBound); assert pointTree.moveToParent() == false; return count; } catch (IOException ioe) { @@ -449,7 +457,15 @@ private static long estimatePointCount( * @see DocIdSetIterator#cost */ public final long estimateDocCount(IntersectVisitor visitor) { - long estimatedPointCount = estimatePointCount(visitor); + return estimateDocCount(visitor, Long.MAX_VALUE); + } + + /** + * Estimate the number of documents that would be matched by {@link #intersect} with the given + * {upperBound} + */ + public final long estimateDocCount(IntersectVisitor visitor, long upperBound) { + long estimatedPointCount = estimatePointCount(visitor, upperBound); int docCount = getDocCount(); double size = size(); if (estimatedPointCount >= size) { diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java index a8169ad227f1..b9c34cf0e00e 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java @@ -16,6 +16,9 @@ */ package org.apache.lucene.search; +import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO; +import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -24,7 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.OptionalLong; +import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.Weight.DefaultBulkScorer; @@ -35,7 +38,7 @@ final class BooleanScorerSupplier extends ScorerSupplier { private final ScoreMode scoreMode; private final int minShouldMatch; private final int maxDoc; - private long cost = -1; + private TotalHits estimatedCount = null; private boolean topLevelScoringClause; BooleanScorerSupplier( @@ -69,21 +72,40 @@ final class BooleanScorerSupplier extends ScorerSupplier { this.maxDoc = maxDoc; } - private long computeCost() { - OptionalLong minRequiredCost = + private TotalHits computeCost(long upperBound) { + + TotalHits minRequiredCost = null; + TotalHits totalHits = null; + for (ScorerSupplier scorerSupplier : Stream.concat(subs.get(Occur.MUST).stream(), subs.get(Occur.FILTER).stream()) - .mapToLong(ScorerSupplier::cost) - .min(); - if (minRequiredCost.isPresent() && minShouldMatch == 0) { - return minRequiredCost.getAsLong(); + .collect(Collectors.toList())) { + totalHits = scorerSupplier.isEstimatedPointCountGreaterThanOrEqualTo(upperBound); + if (totalHits.relation() == EQUAL_TO && totalHits.value() < upperBound) { + upperBound = totalHits.value(); + minRequiredCost = totalHits; + } else if (minRequiredCost == null) { + minRequiredCost = totalHits; + } + } + + if (minRequiredCost != null && minShouldMatch == 0) { + return minRequiredCost; } else { final Collection optionalScorers = subs.get(Occur.SHOULD); - final long shouldCost = + final TotalHits shouldCost = ScorerUtil.costWithMinShouldMatch( - optionalScorers.stream().mapToLong(ScorerSupplier::cost), - optionalScorers.size(), - minShouldMatch); - return Math.min(minRequiredCost.orElse(Long.MAX_VALUE), shouldCost); + optionalScorers, optionalScorers.size(), minShouldMatch, upperBound); + + if (shouldCost.relation() == EQUAL_TO) { + return shouldCost; + } else if (minRequiredCost != null && minRequiredCost.relation() == EQUAL_TO) { + return minRequiredCost; + } else if (minRequiredCost != null) { + // or we should return small one? it doesn't matter + return (shouldCost.value() > minRequiredCost.value() ? shouldCost : minRequiredCost); + } else { + return shouldCost; + } } } @@ -103,10 +125,22 @@ public void setTopLevelScoringClause() throws IOException { @Override public long cost() { - if (cost == -1) { - cost = computeCost(); + if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) { + estimatedCount = computeCost(Long.MAX_VALUE); + assert estimatedCount.value() >= 0; } - return cost; + return estimatedCount.value(); + } + + @Override + public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) { + if (estimatedCount == null + || (estimatedCount.value() < upperBound + && estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) { + estimatedCount = computeCost(upperBound); + assert estimatedCount.value() >= 0; + } + return estimatedCount; } @Override @@ -126,7 +160,10 @@ public Scorer get(long leadCost) throws IOException { private Scorer getInternal(long leadCost) throws IOException { // three cases: conjunction, disjunction, or mix - leadCost = Math.min(leadCost, cost()); + estimatedCount = isEstimatedPointCountGreaterThanOrEqualTo(leadCost); + if (estimatedCount.relation() == EQUAL_TO && estimatedCount.value() < leadCost) { + leadCost = estimatedCount.value(); + } // pure conjunction if (subs.get(Occur.SHOULD).isEmpty()) { @@ -202,10 +239,11 @@ BulkScorer booleanScorer() throws IOException { // there will be no matches in the end) so we should only use // BooleanScorer if matches are very dense costThreshold = maxDoc / 3; - } - if (cost() < costThreshold) { - return null; + TotalHits estimatedCount = isEstimatedPointCountGreaterThanOrEqualTo(costThreshold); + if (estimatedCount.relation() == EQUAL_TO) { + return null; + } } positiveScorer = optionalBulkScorer(); @@ -315,10 +353,16 @@ private BulkScorer requiredBulkScorer() throws IOException { return scorer; } - long leadCost = - subs.get(Occur.MUST).stream().mapToLong(ScorerSupplier::cost).min().orElse(Long.MAX_VALUE); - leadCost = - subs.get(Occur.FILTER).stream().mapToLong(ScorerSupplier::cost).min().orElse(leadCost); + long leadCost = Long.MAX_VALUE; + TotalHits estimatedCount; + for (ScorerSupplier scorerSupplier : + Stream.concat(subs.get(Occur.MUST).stream(), subs.get(Occur.FILTER).stream()) + .collect(Collectors.toList())) { + estimatedCount = scorerSupplier.isEstimatedPointCountGreaterThanOrEqualTo(leadCost); + if (estimatedCount.relation() == EQUAL_TO && estimatedCount.value() < leadCost) { + leadCost = estimatedCount.value(); + } + } List requiredNoScoring = new ArrayList<>(); for (ScorerSupplier ss : subs.get(Occur.FILTER)) { diff --git a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java index f0e0cfd6bdb8..04462db2bf99 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java @@ -16,6 +16,9 @@ */ package org.apache.lucene.search; +import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO; +import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + import java.io.IOException; import java.io.UncheckedIOException; import java.util.AbstractCollection; @@ -176,7 +179,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti // We optimize this common case, effectively doing a merge sort of the indexed values vs // the queried set: return new ScorerSupplier() { - long cost = -1; // calculate lazily, only once + TotalHits estimatedCount = null; // calculate lazily, only once @Override public Scorer get(long leadCost) throws IOException { @@ -189,15 +192,42 @@ public Scorer get(long leadCost) throws IOException { @Override public long cost() { try { - if (cost == -1) { + if (estimatedCount == null + || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) { // Computing the cost may be expensive, so only do it if necessary DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); - cost = + estimatedCount = + new TotalHits( + values.estimateDocCount( + new MergePointVisitor(sortedPackedPoints.iterator(), result), + Long.MAX_VALUE), + EQUAL_TO); + assert estimatedCount.value() >= 0; + } + return estimatedCount.value(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) { + try { + if (estimatedCount == null + || (estimatedCount.value() < upperBound + && estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) { + DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); + long cost = values.estimateDocCount( - new MergePointVisitor(sortedPackedPoints.iterator(), result)); - assert cost >= 0; + new MergePointVisitor(sortedPackedPoints.iterator(), result), upperBound); + if (cost < upperBound) { + estimatedCount = new TotalHits(cost, EQUAL_TO); + } else if (estimatedCount == null || cost > estimatedCount.value()) { + estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO); + } + assert estimatedCount.value() >= 0; } - return cost; + return estimatedCount; } catch (IOException e) { throw new UncheckedIOException(e); } @@ -211,7 +241,7 @@ public long cost() { // index, which is probably tricky! return new ScorerSupplier() { - long cost = -1; // calculate lazily, only once + TotalHits estimatedCount = null; // calculate lazily, only once @Override public Scorer get(long leadCost) throws IOException { @@ -228,18 +258,49 @@ public Scorer get(long leadCost) throws IOException { @Override public long cost() { try { - if (cost == -1) { + if (estimatedCount == null + || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) { + DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); + SinglePointVisitor visitor = new SinglePointVisitor(result); + TermIterator iterator = sortedPackedPoints.iterator(); + long cost = 0; + for (BytesRef point = iterator.next(); point != null; point = iterator.next()) { + visitor.setPoint(point); + cost += values.estimateDocCount(visitor, Long.MAX_VALUE); + } + assert cost >= 0; + estimatedCount = new TotalHits(cost, EQUAL_TO); + } + return estimatedCount.value(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) { + try { + if (estimatedCount == null + || (estimatedCount.value() < upperBound + && estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) { DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); SinglePointVisitor visitor = new SinglePointVisitor(result); TermIterator iterator = sortedPackedPoints.iterator(); - cost = 0; + long cost = 0; for (BytesRef point = iterator.next(); point != null; point = iterator.next()) { visitor.setPoint(point); - cost += values.estimateDocCount(visitor); + cost += values.estimateDocCount(visitor, upperBound); + if (cost >= upperBound) { + estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO); + break; + } + } + if (cost < upperBound) { + estimatedCount = new TotalHits(cost, EQUAL_TO); } assert cost >= 0; } - return cost; + return estimatedCount; } catch (IOException e) { throw new UncheckedIOException(e); } diff --git a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java index 1b6d6869c19e..a32007249245 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java @@ -16,6 +16,9 @@ */ package org.apache.lucene.search; +import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO; +import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + import java.io.IOException; import java.util.Arrays; import java.util.Objects; @@ -360,7 +363,7 @@ public long cost() { final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); final IntersectVisitor visitor = getIntersectVisitor(result); - long cost = -1; + TotalHits estimatedCount = null; @Override public Scorer get(long leadCost) throws IOException { @@ -385,12 +388,28 @@ && cost() > reader.maxDoc() / 2) { @Override public long cost() { - if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; + if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) { + estimatedCount = + new TotalHits(values.estimateDocCount(visitor, Long.MAX_VALUE), EQUAL_TO); + assert estimatedCount.value() >= 0; + } + return estimatedCount.value(); + } + + @Override + public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) { + if (estimatedCount == null + || (estimatedCount.value() < upperBound + && estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) { + long cost = values.estimateDocCount(visitor, upperBound); + if (cost < upperBound) { + estimatedCount = new TotalHits(cost, EQUAL_TO); + } else if (estimatedCount == null || cost > estimatedCount.value()) { + estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO); + } + assert estimatedCount.value() >= 0; } - return cost; + return estimatedCount; } }; } diff --git a/lucene/core/src/java/org/apache/lucene/search/ScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/search/ScorerSupplier.java index edfaad873fe0..d770e4cdadeb 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/search/ScorerSupplier.java @@ -53,6 +53,10 @@ public BulkScorer bulkScorer() throws IOException { */ public abstract long cost(); + public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) { + return new TotalHits(cost(), TotalHits.Relation.EQUAL_TO); + } + /** * Inform this {@link ScorerSupplier} that its returned scorers produce scores that get passed to * the collector, as opposed to partial scores that then need to get combined (e.g. summed up). diff --git a/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java b/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java index 50c960719cf9..a11396c28d5d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java +++ b/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java @@ -16,6 +16,10 @@ */ package org.apache.lucene.search; +import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO; +import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + +import java.util.Collection; import java.util.stream.LongStream; import java.util.stream.StreamSupport; import org.apache.lucene.util.PriorityQueue; @@ -46,4 +50,33 @@ protected boolean lessThan(Long a, Long b) { costs.forEach(pq::insertWithOverflow); return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum(); } + + static TotalHits costWithMinShouldMatch( + Collection collection, int numScorers, int minShouldMatch, long upperBound) { + int queueSize = Math.min(numScorers - minShouldMatch + 1, collection.size()); + final PriorityQueue pq = + new PriorityQueue(queueSize) { + @Override + protected boolean lessThan(Long a, Long b) { + return a > b; + } + }; + // Keep track of the last eliminated value that was added to the priority queue. + long leastTopNScoreBound = upperBound; + for (ScorerSupplier supplier : collection) { + TotalHits totalHits = supplier.isEstimatedPointCountGreaterThanOrEqualTo(leastTopNScoreBound); + if (totalHits.relation() == EQUAL_TO) { + Long oldCost = pq.insertWithOverflow(totalHits.value()); + if (oldCost != null && leastTopNScoreBound > oldCost) { + leastTopNScoreBound = oldCost; + } + } + } + long cost = StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum(); + if (pq.size() < queueSize || cost > upperBound) { + return new TotalHits(Math.max(cost, upperBound), GREATER_THAN_OR_EQUAL_TO); + } else { + return new TotalHits(cost, EQUAL_TO); + } + } } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java index 66c490b791a1..1f85cd0e4411 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/MultiRangeQuery.java @@ -17,6 +17,9 @@ package org.apache.lucene.sandbox.search; +import static org.apache.lucene.search.TotalHits.Relation.EQUAL_TO; +import static org.apache.lucene.search.TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -36,6 +39,7 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.DocIdSetBuilder; @@ -352,7 +356,7 @@ public long cost() { final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field); final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, range); - long cost = -1; + TotalHits estimatedCount = null; @Override public Scorer get(long leadCost) throws IOException { @@ -363,12 +367,30 @@ public Scorer get(long leadCost) throws IOException { @Override public long cost() { - if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor) * rangeClauses.size(); - assert cost >= 0; + if (estimatedCount == null || estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO) { + estimatedCount = + new TotalHits( + values.estimateDocCount(visitor, Long.MAX_VALUE) * rangeClauses.size(), + EQUAL_TO); + assert estimatedCount.value() >= 0; + } + return estimatedCount.value(); + } + + @Override + public TotalHits isEstimatedPointCountGreaterThanOrEqualTo(long upperBound) { + if (estimatedCount == null + || (estimatedCount.value() < upperBound + && estimatedCount.relation() == GREATER_THAN_OR_EQUAL_TO)) { + long cost = values.estimateDocCount(visitor, upperBound); + if (cost < upperBound) { + estimatedCount = new TotalHits(cost, EQUAL_TO); + } else if (estimatedCount == null || cost > estimatedCount.value()) { + estimatedCount = new TotalHits(cost, GREATER_THAN_OR_EQUAL_TO); + } + assert estimatedCount.value() >= 0; } - return cost; + return estimatedCount; } }; }