diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index a27d7c70e3f1..01e963f4c486 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -110,6 +110,9 @@ Optimizations
 * GITHUB#14021: WANDScorer now computes scores on the fly, which helps prevent advancing "tail" clauses in
   many cases. (Adrien Grand)
+* GITHUB#14014: Filtered disjunctions now get executed via `MaxScoreBulkScorer`.
+  (Adrien Grand)
+
 Bug Fixes
 ---------------------
 * GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java
index b50b0530a2d1..f80597d38e6d 100644
--- a/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java
+++ b/lucene/core/src/java/org/apache/lucene/search/BooleanQuery.java
@@ -624,6 +624,31 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
       }
     }
+    // Inline SHOULD clauses from the only MUST clause
+    // Applies when this query has no SHOULD clause and its single MUST clause is a BooleanQuery
+    // that has as many clauses as it has SHOULD clauses.
+    {
+      if (clauseSets.get(Occur.SHOULD).isEmpty()
+          && clauseSets.get(Occur.MUST).size() == 1
+          && clauseSets.get(Occur.MUST).iterator().next() instanceof BooleanQuery inner
+          && inner.clauses.size() == inner.clauseSets.get(Occur.SHOULD).size()) {
+        BooleanQuery.Builder rewritten = new BooleanQuery.Builder();
+        // Carry over every clause of this query except the MUST clause that is being inlined.
+        for (BooleanClause clause : clauses) {
+          if (clause.occur() != Occur.MUST) {
+            rewritten.add(clause);
+          }
+        }
+        for (BooleanClause innerClause : inner.clauses()) {
+          rewritten.add(innerClause);
+        }
+        // The MUST occurrence required the inner query to match, so at least one inlined clause
+        // has to match; a larger minimum configured on the inner query is preserved.
+        rewritten.setMinimumNumberShouldMatch(Math.max(1, inner.getMinimumNumberShouldMatch()));
+        return rewritten.build();
+      }
+    }
+
     return super.rewrite(indexSearcher);
   }
diff --git a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java
index 515d0a6bba1d..7732445e8cd4 100644
--- a/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java
+++ 
b/lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java
@@ -183,7 +183,8 @@ public BulkScorer bulkScorer() throws IOException {
   BulkScorer booleanScorer() throws IOException {
     final int numOptionalClauses = subs.get(Occur.SHOULD).size();
-    final int numRequiredClauses = subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size();
+    final int numMustClauses = subs.get(Occur.MUST).size();
+    final int numRequiredClauses = numMustClauses + subs.get(Occur.FILTER).size();
     BulkScorer positiveScorer;
     if (numRequiredClauses == 0) {
@@ -209,6 +210,8 @@ BulkScorer booleanScorer() throws IOException {
       }
       positiveScorer = optionalBulkScorer();
+    } else if (numMustClauses == 0 && numOptionalClauses > 1 && minShouldMatch >= 1) {
+      positiveScorer = filteredOptionalBulkScorer();
     } else if (numRequiredClauses > 0 && numOptionalClauses == 0 && minShouldMatch == 0) {
       positiveScorer = requiredBulkScorer();
     } else {
@@ -286,7 +289,7 @@ BulkScorer optionalBulkScorer() throws IOException {
         optionalScorers.add(ss.get(Long.MAX_VALUE));
       }
-      return new MaxScoreBulkScorer(maxDoc, optionalScorers);
+      return new MaxScoreBulkScorer(maxDoc, optionalScorers, null);
     }
     List optional = new ArrayList();
@@ -297,6 +300,40 @@ BulkScorer optionalBulkScorer() throws IOException {
     return new BooleanScorer(optional, Math.max(1, minShouldMatch), scoreMode.needsScores());
   }
+  /**
+   * Returns a {@link MaxScoreBulkScorer} over the SHOULD clauses that is restricted to documents
+   * matching the FILTER clauses, or {@code null} when this query shape does not apply: MUST
+   * clauses are present, there is no FILTER clause, the score mode is not {@link
+   * ScoreMode#TOP_SCORES}, there are fewer than two SHOULD clauses, or {@code minShouldMatch} is
+   * greater than 1.
+   */
+  BulkScorer filteredOptionalBulkScorer() throws IOException {
+    if (subs.get(Occur.MUST).isEmpty() == false
+        || subs.get(Occur.FILTER).isEmpty()
+        || scoreMode != ScoreMode.TOP_SCORES
+        || subs.get(Occur.SHOULD).size() <= 1
+        || minShouldMatch > 1) {
+      return null;
+    }
+    long cost = cost();
+    List<Scorer> optionalScorers = new ArrayList<>();
+    for (ScorerSupplier ss : subs.get(Occur.SHOULD)) {
+      optionalScorers.add(ss.get(cost));
+    }
+    List<Scorer> filters = new ArrayList<>();
+    for (ScorerSupplier ss : subs.get(Occur.FILTER)) {
+      filters.add(ss.get(cost));
+    }
+    // A single filter is used directly; several filters are combined into one conjunction scorer.
+    Scorer filterScorer;
+    if (filters.size() == 1) {
+      filterScorer = filters.get(0);
+    } else {
+      filterScorer = new ConjunctionScorer(filters, Collections.emptySet());
+    }
+    return new MaxScoreBulkScorer(maxDoc, optionalScorers, filterScorer);
+  }
+
   // Return a BulkScorer for the required clauses only
   private BulkScorer requiredBulkScorer() throws IOException {
     if (subs.get(Occur.MUST).size() + subs.get(Occur.FILTER).size() == 0) {
diff --git a/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java b/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java
index 56857bc67cc1..663662904321 100644
--- a/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java
+++ b/lucene/core/src/java/org/apache/lucene/search/MaxScoreBulkScorer.java
@@ -46,12 +46,14 @@ final class MaxScoreBulkScorer extends BulkScorer {
   float minCompetitiveScore;
   private final Score scorable = new Score();
   final double[] maxScoreSums;
+  private final DisiWrapper filter;
   private final long[] windowMatches = new long[FixedBitSet.bits2words(INNER_WINDOW_SIZE)];
   private final double[] windowScores = new double[INNER_WINDOW_SIZE];
-  MaxScoreBulkScorer(int maxDoc, List scorers) throws IOException {
+  MaxScoreBulkScorer(int maxDoc, List scorers, Scorer filter) throws IOException {
     this.maxDoc = maxDoc;
+    this.filter = filter == null ? null : new DisiWrapper(filter);
     allScorers = new DisiWrapper[scorers.size()];
     scratch = new DisiWrapper[allScorers.length];
     int i = 0;
@@ -123,7 +125,7 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr
       }
       while (top.doc < outerWindowMax) {
-        scoreInnerWindow(collector, acceptDocs, outerWindowMax);
+        scoreInnerWindow(collector, acceptDocs, outerWindowMax, filter);
         top = essentialQueue.top();
         if (minCompetitiveScore >= nextMinCompetitiveScore) {
           // The minimum competitive score increased substantially, so we can now partition scorers
@@ -139,9 +141,11 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr
     return nextCandidate(max);
   }
-  private void scoreInnerWindow(LeafCollector collector, Bits acceptDocs, int max)
-      throws IOException {
-    if (allScorers.length - firstRequiredScorer >= 2) {
+  private void scoreInnerWindow(
+      LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException {
+    if (filter != null) {
+      scoreInnerWindowWithFilter(collector, acceptDocs, max, filter);
+    } else if (allScorers.length - firstRequiredScorer >= 2) {
       scoreInnerWindowAsConjunction(collector, acceptDocs, max);
     } else {
       DisiWrapper top = essentialQueue.top();
@@ -158,6 +162,62 @@ private void scoreInnerWindow(LeafCollector collector, Bits acceptDocs, int max)
     }
   }
+  /**
+   * Scores one inner window of documents matched by both the essential clauses and {@code
+   * filter}, alternately advancing the filter and the essential clauses until they agree.
+   */
+  private void scoreInnerWindowWithFilter(
+      LeafCollector collector, Bits acceptDocs, int max, DisiWrapper filter) throws IOException {
+
+    // TODO: Sometimes load the filter into a bitset and use the more optimized execution paths with
+    // this bitset as `acceptDocs`
+
+    DisiWrapper top = essentialQueue.top();
+    assert top.doc < max;
+    // Catch up every essential clause that is behind the filter, calling updateTop() after each
+    // advance so that `top` really is the least advanced clause when the window starts.
+    while (top.doc < filter.doc) {
+      top.doc = top.approximation.advance(filter.doc);
+      top = essentialQueue.updateTop();
+    }
+
+    // Only score an inner window, after that we'll check if the min competitive score has increased
+    // enough for a more favorable partitioning to be used.
+    int innerWindowMin = top.doc;
+    int innerWindowMax = (int) Math.min(max, (long) innerWindowMin + INNER_WINDOW_SIZE);
+
+    while (top.doc < innerWindowMax) {
+      assert filter.doc <= top.doc; // invariant
+      if (filter.doc < top.doc) {
+        filter.doc = filter.approximation.advance(top.doc);
+      }
+
+      if (filter.doc != top.doc) {
+        do {
+          top.doc = top.iterator.advance(filter.doc);
+          top = essentialQueue.updateTop();
+        } while (top.doc < filter.doc);
+      } else {
+        int doc = top.doc;
+        boolean match =
+            (acceptDocs == null || acceptDocs.get(doc))
+                && (filter.twoPhaseView == null || filter.twoPhaseView.matches());
+        double score = 0;
+        do {
+          if (match) {
+            score += top.scorer.score();
+          }
+          top.doc = top.iterator.nextDoc();
+          top = essentialQueue.updateTop();
+        } while (top.doc == doc);
+
+        if (match) {
+          scoreNonEssentialClauses(collector, doc, score, firstEssentialScorer);
+        }
+      }
+    }
+  }
+
   private void scoreInnerWindowSingleEssentialClause(
       LeafCollector collector, Bits acceptDocs, int upTo) throws IOException {
     DisiWrapper top = essentialQueue.top();
@@ -284,8 +344,10 @@ private int computeOuterWindowMax(int windowMin) throws IOException {
     int windowMax = DocIdSetIterator.NO_MORE_DOCS;
     for (int i = firstWindowLead; i < allScorers.length; ++i) {
       final DisiWrapper scorer = allScorers[i];
-      final int upTo = scorer.scorer.advanceShallow(Math.max(scorer.doc, windowMin));
-      windowMax = (int) Math.min(windowMax, upTo + 1L); // upTo is inclusive
+      if (filter == null || scorer.cost >= filter.cost) {
+        final int upTo = scorer.scorer.advanceShallow(Math.max(scorer.doc, windowMin));
+        windowMax = (int) Math.min(windowMax, upTo + 1L); // upTo is inclusive
+      }
     }
     if (allScorers.length - firstWindowLead > 1) {
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java b/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java
index f36c7539c7dc..b876fb48963f 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java
+++ 
b/lucene/core/src/test/org/apache/lucene/search/TestBooleanRewrites.java
@@ -792,6 +792,57 @@ public void testFlattenInnerConjunctions() throws IOException {
     assertEquals(expectedRewritten, searcher.rewrite(query));
   }
+  /**
+   * Checks that a MUST clause holding a pure disjunction is flattened into the parent query as
+   * SHOULD clauses with a minimum-should-match constraint, next to the parent's FILTER clause.
+   */
+  public void testFlattenDisjunctionInMustClause() throws IOException {
+    IndexSearcher searcher = newSearcher(new MultiReader());
+
+    // Inner disjunction without minimumNumberShouldMatch: the rewrite is expected to require 1.
+    Query inner =
+        new BooleanQuery.Builder()
+            .add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
+            .add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
+            .build();
+    Query query =
+        new BooleanQuery.Builder()
+            .add(inner, Occur.MUST)
+            .add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
+            .build();
+    Query expectedRewritten =
+        new BooleanQuery.Builder()
+            .add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
+            .add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
+            .add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
+            .setMinimumNumberShouldMatch(1)
+            .build();
+    assertEquals(expectedRewritten, searcher.rewrite(query));
+
+    // Inner disjunction with minimumNumberShouldMatch=2: the rewrite is expected to keep 2.
+    inner =
+        new BooleanQuery.Builder()
+            .add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
+            .add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
+            .add(new TermQuery(new Term("foo", "foo")), Occur.SHOULD)
+            .setMinimumNumberShouldMatch(2)
+            .build();
+    query =
+        new BooleanQuery.Builder()
+            .add(inner, Occur.MUST)
+            .add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
+            .build();
+    expectedRewritten =
+        new BooleanQuery.Builder()
+            .add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD)
+            .add(new TermQuery(new Term("foo", "quux")), Occur.SHOULD)
+            .add(new TermQuery(new Term("foo", "foo")), Occur.SHOULD)
+            .add(new TermQuery(new Term("foo", "baz")), Occur.FILTER)
+            .setMinimumNumberShouldMatch(2)
+            .build();
+    assertEquals(expectedRewritten, searcher.rewrite(query));
+  }
+
   public void testDiscardShouldClauses() throws IOException {
     IndexSearcher searcher = newSearcher(new MultiReader());
diff --git 
a/lucene/core/src/test/org/apache/lucene/search/TestMaxScoreBulkScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestMaxScoreBulkScorer.java
index 6973cc0025a4..2f9fcc2a0fd8 100644
--- a/lucene/core/src/test/org/apache/lucene/search/TestMaxScoreBulkScorer.java
+++ b/lucene/core/src/test/org/apache/lucene/search/TestMaxScoreBulkScorer.java
@@ -85,7 +85,8 @@ public void testBasicsWithTwoDisjunctionClauses() throws Exception {
                 .scorer(context);
         BulkScorer scorer =
-            new MaxScoreBulkScorer(context.reader().maxDoc(), Arrays.asList(scorer1, scorer2));
+            new MaxScoreBulkScorer(
+                context.reader().maxDoc(), Arrays.asList(scorer1, scorer2), null);
         scorer.score(
             new LeafCollector() {
@@ -134,6 +135,148 @@ public void collect(int doc) throws IOException {
     }
   }
+  /**
+   * Runs a two-clause disjunction (boosted "A", plain "C") restricted by a filter on "B" through
+   * {@link MaxScoreBulkScorer} and checks the exact docs and scores that get collected.
+   */
+  public void testFilteredDisjunction() throws Exception {
+    try (Directory dir = newDirectory()) {
+      writeDocuments(dir);
+
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        IndexSearcher searcher = newSearcher(reader);
+
+        Query clause1 =
+            new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2);
+        Query clause2 = new ConstantScoreQuery(new TermQuery(new Term("foo", "C")));
+        Query filter = new TermQuery(new Term("foo", "B"));
+        LeafReaderContext context = searcher.getIndexReader().leaves().get(0);
+        Scorer scorer1 =
+            searcher
+                .createWeight(searcher.rewrite(clause1), ScoreMode.TOP_SCORES, 1f)
+                .scorer(context);
+        Scorer scorer2 =
+            searcher
+                .createWeight(searcher.rewrite(clause2), ScoreMode.TOP_SCORES, 1f)
+                .scorer(context);
+        Scorer filterScorer =
+            searcher
+                .createWeight(searcher.rewrite(filter), ScoreMode.TOP_SCORES, 1f)
+                .scorer(context);
+
+        BulkScorer scorer =
+            new MaxScoreBulkScorer(
+                context.reader().maxDoc(), Arrays.asList(scorer1, scorer2), filterScorer);
+
+        scorer.score(
+            new LeafCollector() {
+
+              private int i;
+              private Scorable scorer;
+
+              @Override
+              public void setScorer(Scorable scorer) throws IOException {
+                this.scorer = scorer;
+              }
+
+              @Override
+              public void collect(int doc) throws IOException {
+                switch (i++) {
+                  case 0:
+                    assertEquals(0, doc);
+                    assertEquals(2, scorer.score(), 0);
+                    break;
+                  case 1:
+                    assertEquals(12288, doc);
+                    assertEquals(2 + 1, scorer.score(), 0);
+                    break;
+                  case 2:
+                    assertEquals(20480, doc);
+                    assertEquals(1, scorer.score(), 0);
+                    break;
+                  default:
+                    fail();
+                    break;
+                }
+              }
+            },
+            null,
+            0,
+            DocIdSetIterator.NO_MORE_DOCS);
+      }
+    }
+  }
+
+  /**
+   * Same setup as {@link #testFilteredDisjunction()}, but the collector raises the minimum
+   * competitive score just above each hit it receives, so no third hit is expected.
+   */
+  public void testFilteredDisjunctionWithSkipping() throws Exception {
+    try (Directory dir = newDirectory()) {
+      writeDocuments(dir);
+
+      try (IndexReader reader = DirectoryReader.open(dir)) {
+        IndexSearcher searcher = newSearcher(reader);
+
+        Query clause1 =
+            new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2);
+        Query clause2 = new ConstantScoreQuery(new TermQuery(new Term("foo", "C")));
+        Query filter = new TermQuery(new Term("foo", "B"));
+        LeafReaderContext context = searcher.getIndexReader().leaves().get(0);
+        Scorer scorer1 =
+            searcher
+                .createWeight(searcher.rewrite(clause1), ScoreMode.TOP_SCORES, 1f)
+                .scorer(context);
+        Scorer scorer2 =
+            searcher
+                .createWeight(searcher.rewrite(clause2), ScoreMode.TOP_SCORES, 1f)
+                .scorer(context);
+        Scorer filterScorer =
+            searcher
+                .createWeight(searcher.rewrite(filter), ScoreMode.TOP_SCORES, 1f)
+                .scorer(context);
+
+        BulkScorer scorer =
+            new MaxScoreBulkScorer(
+                context.reader().maxDoc(), Arrays.asList(scorer1, scorer2), filterScorer);
+
+        scorer.score(
+            new LeafCollector() {
+
+              private int i;
+              private Scorable scorer;
+
+              @Override
+              public void setScorer(Scorable scorer) throws IOException {
+                this.scorer = scorer;
+              }
+
+              @Override
+              public void collect(int doc) throws IOException {
+                switch (i++) {
+                  case 0:
+                    assertEquals(0, doc);
+                    assertEquals(2, scorer.score(), 0);
+                    scorer.setMinCompetitiveScore(Math.nextUp(2));
+                    break;
+                  case 1:
+                    assertEquals(12288, doc);
+                    assertEquals(2 + 1, scorer.score(), 0);
+                    scorer.setMinCompetitiveScore(Math.nextUp(2 + 1));
+                    break;
+                  default:
+                    fail("Unexpected hit #" + i + ": doc=" + doc);
+                    break;
+                }
+              }
+            },
+            null,
+            0,
+            DocIdSetIterator.NO_MORE_DOCS);
+      }
+    }
+  }
+
   public void testBasicsWithTwoDisjunctionClausesAndSkipping() throws Exception {
     try (Directory dir = newDirectory()) {
       writeDocuments(dir);
@@ -155,7 +298,8 @@ public void testBasicsWithTwoDisjunctionClausesAndSkipping() throws Exception {
                 .scorer(context);
         BulkScorer scorer =
-            new MaxScoreBulkScorer(context.reader().maxDoc(), Arrays.asList(scorer1, scorer2));
+            new MaxScoreBulkScorer(
+                context.reader().maxDoc(), Arrays.asList(scorer1, scorer2), null);
         scorer.score(
             new LeafCollector() {
@@ -227,7 +371,7 @@ public void testBasicsWithThreeDisjunctionClauses() throws Exception {
         BulkScorer scorer =
             new MaxScoreBulkScorer(
-                context.reader().maxDoc(), Arrays.asList(scorer1, scorer2, scorer3));
+                context.reader().maxDoc(), Arrays.asList(scorer1, scorer2, scorer3), null);
         scorer.score(
             new LeafCollector() {
@@ -304,7 +448,7 @@ public void testBasicsWithThreeDisjunctionClausesAndSkipping() throws Exception
         BulkScorer scorer =
             new MaxScoreBulkScorer(
-                context.reader().maxDoc(), Arrays.asList(scorer1, scorer2, scorer3));
+                context.reader().maxDoc(), Arrays.asList(scorer1, scorer2, scorer3), null);
         scorer.score(
             new LeafCollector() {
@@ -505,7 +649,8 @@ public void testPartition() throws IOException {
     fox.cost = 900;
     fox.maxScore = 1.1f;
-    MaxScoreBulkScorer scorer = new MaxScoreBulkScorer(10_000, Arrays.asList(the, quick, fox));
+    MaxScoreBulkScorer scorer =
+        new MaxScoreBulkScorer(10_000, Arrays.asList(the, quick, fox), null);
     the.docID = 4;
     the.maxScoreUpTo = 130;
     quick.docID = 4;