diff --git a/server/src/main/java/org/elasticsearch/index/search/ESToParentBlockJoinQuery.java b/server/src/main/java/org/elasticsearch/index/search/ESToParentBlockJoinQuery.java index 30cd6e52f0dbd..ea9c8793d88c3 100644 --- a/server/src/main/java/org/elasticsearch/index/search/ESToParentBlockJoinQuery.java +++ b/server/src/main/java/org/elasticsearch/index/search/ESToParentBlockJoinQuery.java @@ -37,14 +37,16 @@ public final class ESToParentBlockJoinQuery extends Query { private final ToParentBlockJoinQuery query; private final String path; + private final ScoreMode scoreMode; public ESToParentBlockJoinQuery(Query childQuery, BitSetProducer parentsFilter, ScoreMode scoreMode, String path) { - this(new ToParentBlockJoinQuery(childQuery, parentsFilter, scoreMode), path); + this(new ToParentBlockJoinQuery(childQuery, parentsFilter, scoreMode), path, scoreMode); } - private ESToParentBlockJoinQuery(ToParentBlockJoinQuery query, String path) { + private ESToParentBlockJoinQuery(ToParentBlockJoinQuery query, String path, ScoreMode scoreMode) { this.query = query; this.path = path; + this.scoreMode = scoreMode; } /** Return the child query. */ @@ -57,6 +59,11 @@ public String getPath() { return path; } + /** Return the score mode for the matching children. **/ + public ScoreMode getScoreMode() { + return scoreMode; + } + @Override public Query rewrite(IndexReader reader) throws IOException { Query innerRewrite = query.rewrite(reader); @@ -68,7 +75,7 @@ public Query rewrite(IndexReader reader) throws IOException { // to a MatchNoDocsQuery. In that case it would be fine to lose information // about the nested path. if (innerRewrite instanceof ToParentBlockJoinQuery) { - return new ESToParentBlockJoinQuery((ToParentBlockJoinQuery) innerRewrite, path); + return new ESToParentBlockJoinQuery((ToParentBlockJoinQuery) innerRewrite, path, scoreMode); } else { return innerRewrite; } diff --git a/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java b/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java index 1e2cd7541f944..5ae6cc739c362 100644 --- a/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java +++ b/server/src/main/java/org/elasticsearch/search/query/TopDocsCollectorContext.java @@ -27,16 +27,21 @@ import org.apache.lucene.index.PointValues; import org.apache.lucene.index.Term; import org.apache.lucene.index.Terms; +import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BoostQuery; import org.apache.lucene.search.Collector; import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.FilterCollector; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MultiCollector; import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocsCollector; @@ -47,11 +52,15 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.grouping.CollapseTopFieldDocs; import org.apache.lucene.search.grouping.CollapsingTopDocsCollector; +import org.apache.lucene.search.spans.SpanQuery; import org.elasticsearch.action.search.MaxScoreCollector; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery; +import org.elasticsearch.common.lucene.search.function.ScriptScoreQuery; import org.elasticsearch.common.util.CachedSupplier; +import org.elasticsearch.index.search.ESToParentBlockJoinQuery; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.collapse.CollapseContext; import org.elasticsearch.search.internal.ScrollContext; @@ -264,7 +273,29 @@ private SimpleTopDocsCollectorContext(IndexReader reader, } else { maxScoreSupplier = () -> Float.NaN; } - this.collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector); + + final Collector collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector); + if (sortAndFormats == null || + SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) { + if (hasInfMaxScore(query)) { + // disable max score optimization since we have a mandatory clause + // that doesn't track the maximum score + this.collector = new FilterCollector(collector) { + @Override + public ScoreMode scoreMode() { + if (in.scoreMode() == ScoreMode.TOP_SCORES) { + return ScoreMode.COMPLETE; + } + return in.scoreMode(); + } + }; + } else { + this.collector = collector; + } + } else { + this.collector = collector; + } + } @Override @@ -437,4 +468,45 @@ boolean shouldRescore() { }; } } + + /** + * Return true if the provided query contains a mandatory clauses (MUST) + * that doesn't track the maximum scores per block + */ + static boolean hasInfMaxScore(Query query) { + MaxScoreQueryVisitor visitor = new MaxScoreQueryVisitor(); + query.visit(visitor); + return visitor.hasInfMaxScore; + } + + private static class MaxScoreQueryVisitor extends QueryVisitor { + private boolean hasInfMaxScore; + + @Override + public void visitLeaf(Query query) { + checkMaxScoreInfo(query); + } + + @Override + public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) { + if (occur != BooleanClause.Occur.MUST) { + // boolean queries can skip documents even if they have some should + // clauses that don't track maximum scores + return QueryVisitor.EMPTY_VISITOR; + } + checkMaxScoreInfo(parent); + return this; + } + + void checkMaxScoreInfo(Query query) { + if (query instanceof FunctionScoreQuery + || query instanceof ScriptScoreQuery + || query instanceof SpanQuery) { + hasInfMaxScore = true; + } else if (query instanceof ESToParentBlockJoinQuery) { + ESToParentBlockJoinQuery q = (ESToParentBlockJoinQuery) query; + hasInfMaxScore |= (q.getScoreMode() != org.apache.lucene.search.join.ScoreMode.None); + } + } + } } diff --git a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java index e65b4aa377a40..a5061c35d7fcf 100644 --- a/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.query; +import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.LatLonDocValuesField; @@ -26,6 +27,7 @@ import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.StringField; +import org.apache.lucene.document.TextField; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriterConfig; @@ -52,12 +54,19 @@ import org.apache.lucene.search.SortField; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TotalHitCountCollector; +import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.Weight; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.search.spans.SpanNearQuery; +import org.apache.lucene.search.spans.SpanTermQuery; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; import org.elasticsearch.action.search.SearchTask; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.query.ParsedQuery; +import org.elasticsearch.index.search.ESToParentBlockJoinQuery; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.shard.IndexShardTestCase; import org.elasticsearch.search.DocValueFormat; @@ -559,6 +568,108 @@ public void testIndexSortScrollOptimization() throws Exception { dir.close(); } + + public void testDisableTopScoreCollection() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(new StandardAnalyzer()); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + for (int i = 0; i < 10; i++) { + doc.clear(); + if (i % 2 == 0) { + doc.add(new TextField("title", "foo bar", Store.NO)); + } else { + doc.add(new TextField("title", "foo", Store.NO)); + } + w.addDocument(doc); + } + w.close(); + + IndexReader reader = DirectoryReader.open(dir); + IndexSearcher contextSearcher = new IndexSearcher(reader); + TestSearchContext context = new TestSearchContext(null, indexShard); + context.setTask(new SearchTask(123L, "", "", "", null, Collections.emptyMap())); + Query q = new SpanNearQuery.Builder("title", true) + .addClause(new SpanTermQuery(new Term("title", "foo"))) + .addClause(new SpanTermQuery(new Term("title", "bar"))) + .build(); + + context.parsedQuery(new ParsedQuery(q)); + context.setSize(10); + TopDocsCollectorContext topDocsContext = + TopDocsCollectorContext.createTopDocsCollectorContext(context, reader, false); + assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.COMPLETE); + QueryPhase.execute(context, contextSearcher, checkCancelled -> {}); + assertEquals(5, context.queryResult().topDocs().topDocs.totalHits.value); + assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.EQUAL_TO); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(5)); + + + context.sort(new SortAndFormats(new Sort(new SortField("other", SortField.Type.INT)), + new DocValueFormat[] { DocValueFormat.RAW })); + topDocsContext = + TopDocsCollectorContext.createTopDocsCollectorContext(context, reader, false); + assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES); + QueryPhase.execute(context, contextSearcher, checkCancelled -> {}); + assertEquals(5, context.queryResult().topDocs().topDocs.totalHits.value); + assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(5)); + assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.EQUAL_TO); + + reader.close(); + dir.close(); + } + + public void testMaxScoreQueryVisitor() { + BitSetProducer producer = context -> new FixedBitSet(1); + Query query = new ESToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"); + assertTrue(TopDocsCollectorContext.hasInfMaxScore(query)); + + query = new ESToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.None, "nested"); + assertFalse(TopDocsCollectorContext.hasInfMaxScore(query)); + + + for (Occur occur : Occur.values()) { + query = new BooleanQuery.Builder() + .add(new ESToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"), occur) + .build(); + if (occur == Occur.MUST) { + assertTrue(TopDocsCollectorContext.hasInfMaxScore(query)); + } else { + assertFalse(TopDocsCollectorContext.hasInfMaxScore(query)); + } + + query = new BooleanQuery.Builder() + .add(new BooleanQuery.Builder() + .add(new ESToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"), occur) + .build(), occur) + .build(); + if (occur == Occur.MUST) { + assertTrue(TopDocsCollectorContext.hasInfMaxScore(query)); + } else { + assertFalse(TopDocsCollectorContext.hasInfMaxScore(query)); + } + + query = new BooleanQuery.Builder() + .add(new BooleanQuery.Builder() + .add(new ESToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"), occur) + .build(), Occur.FILTER) + .build(); + assertFalse(TopDocsCollectorContext.hasInfMaxScore(query)); + + query = new BooleanQuery.Builder() + .add(new BooleanQuery.Builder() + .add(new SpanTermQuery(new Term("field", "foo")), occur) + .add(new ESToParentBlockJoinQuery(new MatchAllDocsQuery(), producer, ScoreMode.Avg, "nested"), occur) + .build(), occur) + .build(); + if (occur == Occur.MUST) { + assertTrue(TopDocsCollectorContext.hasInfMaxScore(query)); + } else { + assertFalse(TopDocsCollectorContext.hasInfMaxScore(query)); + } + } + } + private static IndexSearcher getAssertingEarlyTerminationSearcher(IndexReader reader, int size) { return new IndexSearcher(reader) { @Override