Skip to content

Commit

Permalink
Automatically early terminate search query based on index sorting (#24864)
Browse files Browse the repository at this point in the history

This commit refactors the query phase in order to be able
to automatically detect queries that can be early terminated.
If the index sort matches the query sort, the top docs collection is early terminated
on each segment and the computing of the total number of hits that match the query is delegated to a simple TotalHitCountCollector.
This change also adds a new parameter to the search request called `track_total_hits`.
It indicates if the total number of hits that match the query should be tracked.
If false, queries sorted by the index sort will not try to compute this information
and will limit the collection to the first N documents per segment.
Aggregations are not impacted and will continue to see every document
even when the index sort matches the query sort and `track_total_hits` is false.

Relates #6720
  • Loading branch information
jimczi authored Jun 8, 2017
1 parent 21a57c1 commit 36a5cf8
Show file tree
Hide file tree
Showing 30 changed files with 1,539 additions and 431 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
abstract class CollapsingDocValuesSource<T> extends GroupSelector<T> {
protected final String field;

CollapsingDocValuesSource(String field) throws IOException {
CollapsingDocValuesSource(String field) {
this.field = field;
}

Expand All @@ -58,7 +58,7 @@ static class Numeric extends CollapsingDocValuesSource<Long> {
private long value;
private boolean hasValue;

Numeric(String field) throws IOException {
Numeric(String field) {
super(field);
}

Expand Down Expand Up @@ -148,7 +148,7 @@ static class Keyword extends CollapsingDocValuesSource<BytesRef> {
private SortedDocValues values;
private int ord;

Keyword(String field) throws IOException {
Keyword(String field) {
super(field);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public final class CollapsingTopDocsCollector<T> extends FirstPassGroupingCollec
private final boolean trackMaxScore;

CollapsingTopDocsCollector(GroupSelector<T> groupSelector, String collapseField, Sort sort,
int topN, boolean trackMaxScore) throws IOException {
int topN, boolean trackMaxScore) {
super(groupSelector, sort, topN);
this.collapseField = collapseField;
this.trackMaxScore = trackMaxScore;
Expand All @@ -60,7 +60,7 @@ public final class CollapsingTopDocsCollector<T> extends FirstPassGroupingCollec

/**
* Transform {@link FirstPassGroupingCollector#getTopGroups(int, boolean)} output in
* {@link CollapseTopFieldDocs}. The collapsing needs only one pass so we can create the final top docs at the end
* {@link CollapseTopFieldDocs}. The collapsing needs only one pass so we can get the final top docs at the end
* of the first pass.
*/
public CollapseTopFieldDocs getTopDocs() throws IOException {
Expand Down Expand Up @@ -132,10 +132,9 @@ public void collect(int doc) throws IOException {
* This must be non-null, ie, if you want to groupSort by relevance
* use Sort.RELEVANCE.
* @param topN How many top groups to keep.
* @throws IOException When I/O related errors occur
*/
public static CollapsingTopDocsCollector<?> createNumeric(String collapseField, Sort sort,
int topN, boolean trackMaxScore) throws IOException {
int topN, boolean trackMaxScore) {
return new CollapsingTopDocsCollector<>(new CollapsingDocValuesSource.Numeric(collapseField),
collapseField, sort, topN, trackMaxScore);
}
Expand All @@ -152,12 +151,10 @@ public static CollapsingTopDocsCollector<?> createNumeric(String collapseField,
* document per collapsed key.
* This must be non-null, ie, if you want to groupSort by relevance use Sort.RELEVANCE.
* @param topN How many top groups to keep.
* @throws IOException When I/O related errors occur
*/
public static CollapsingTopDocsCollector<?> createKeyword(String collapseField, Sort sort,
int topN, boolean trackMaxScore) throws IOException {
int topN, boolean trackMaxScore) {
return new CollapsingTopDocsCollector<>(new CollapsingDocValuesSource.Keyword(collapseField),
collapseField, sort, topN, trackMaxScore);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -829,8 +829,7 @@ private enum ElasticsearchExceptionHandle {
org.elasticsearch.transport.SendRequestTransportException::new, 58, UNKNOWN_VERSION_ADDED),
ES_REJECTED_EXECUTION_EXCEPTION(org.elasticsearch.common.util.concurrent.EsRejectedExecutionException.class,
org.elasticsearch.common.util.concurrent.EsRejectedExecutionException::new, 59, UNKNOWN_VERSION_ADDED),
EARLY_TERMINATION_EXCEPTION(org.elasticsearch.common.lucene.Lucene.EarlyTerminationException.class,
org.elasticsearch.common.lucene.Lucene.EarlyTerminationException::new, 60, UNKNOWN_VERSION_ADDED),
// 60 used to be for EarlyTerminationException
// 61 used to be for RoutingValidationException
NOT_SERIALIZABLE_EXCEPTION_WRAPPER(org.elasticsearch.common.io.stream.NotSerializableExceptionWrapper.class,
org.elasticsearch.common.io.stream.NotSerializableExceptionWrapper::new, 62, UNKNOWN_VERSION_ADDED),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,18 @@ private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFr
* @param queryResults a list of non-null query shard results
*/
public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults, boolean isScrollRequest) {
return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(), 0, isScrollRequest);
return reducedQueryPhase(queryResults, isScrollRequest, true);
}

/**
* Reduces the given query results and consumes all aggregations and profile results.
* @param queryResults a list of non-null query shard results
* @param isScrollRequest <code>true</code> if this reduction is for a scroll request
* @param trackTotalHits <code>true</code> if the total number of hits that match the query
*                       should be computed (passed through to {@link TopDocsStats})
*/
public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults, boolean isScrollRequest, boolean trackTotalHits) {
return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHits), 0, isScrollRequest);
}


/**
* Reduces the given query results and consumes all aggregations and profile results.
* @param queryResults a list of non-null query shard results
Expand Down Expand Up @@ -711,6 +720,7 @@ InitialSearchPhase.SearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(S
boolean isScrollRequest = request.scroll() != null;
final boolean hasAggs = source != null && source.aggregations() != null;
final boolean hasTopDocs = source == null || source.size() != 0;
final boolean trackTotalHits = source == null || source.trackTotalHits();

if (isScrollRequest == false && (hasAggs || hasTopDocs)) {
// no incremental reduce if scroll is used - we only hit a single shard or sometimes more...
Expand All @@ -722,18 +732,30 @@ InitialSearchPhase.SearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(S
return new InitialSearchPhase.SearchPhaseResults(numShards) {
@Override
public ReducedQueryPhase reduce() {
return reducedQueryPhase(results.asList(), isScrollRequest);
return reducedQueryPhase(results.asList(), isScrollRequest, trackTotalHits);
}
};
}

static final class TopDocsStats {
final boolean trackTotalHits;
long totalHits;
long fetchHits;
float maxScore = Float.NEGATIVE_INFINITY;

TopDocsStats() {
this(true); // by default the total hit count is tracked
}

TopDocsStats(boolean trackTotalHits) {
    this.trackTotalHits = trackTotalHits;
    if (trackTotalHits) {
        this.totalHits = 0;
    } else {
        // -1 marks the total hit count as unavailable (serialized as "absent")
        this.totalHits = -1;
    }
}

void add(TopDocs topDocs) {
totalHits += topDocs.totalHits;
if (trackTotalHits) {
totalHits += topDocs.totalHits;
}
fetchHits += topDocs.scoreDocs.length;
if (!Float.isNaN(topDocs.getMaxScore())) {
maxScore = Math.max(maxScore, topDocs.getMaxScore());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
import java.util.Collections;
import java.util.Objects;

import static org.elasticsearch.action.ValidateActions.addValidationError;

/**
* A request to execute search against one or more indices (or all). Best created using
* {@link org.elasticsearch.client.Requests#searchRequest(String...)}.
Expand Down Expand Up @@ -102,7 +104,12 @@ public SearchRequest(String[] indices, SearchSourceBuilder source) {

@Override
public ActionRequestValidationException validate() {
return null;
ActionRequestValidationException validationException = null;
if (source != null && source.trackTotalHits() == false && scroll() != null) {
validationException =
addValidationError("disabling [track_total_hits] is not allowed in a scroll context", validationException);
}
return validationException;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,14 +363,21 @@ public SearchRequestBuilder slice(SliceBuilder builder) {
}

/**
* Applies when sorting, and controls if scores will be tracked as well. Defaults to
* <tt>false</tt>.
* Applies when sorting, and controls if scores will be tracked as well. Defaults to <tt>false</tt>.
*/
public SearchRequestBuilder setTrackScores(boolean trackScores) {
sourceBuilder().trackScores(trackScores); // stored on the underlying SearchSourceBuilder
return this;
}

/**
* Indicates if the total hit count for the query should be tracked. Defaults to <tt>true</tt>.
*
* @param trackTotalHits <code>false</code> to skip computing the total number of matching hits
* @return this builder, for method chaining
*/
public SearchRequestBuilder setTrackTotalHits(boolean trackTotalHits) {
sourceBuilder().trackTotalHits(trackTotalHits);
return this;
}

/**
* Adds stored fields to load and return (note, it must be stored) as part of the search request.
* To disable the stored fields entirely (source and metadata fields) use {@code storedField("_none_")}.
Expand Down
79 changes: 0 additions & 79 deletions core/src/main/java/org/elasticsearch/common/lucene/Lucene.java
Original file line number Diff line number Diff line change
Expand Up @@ -246,20 +246,6 @@ protected Object doBody(String segmentFileName) throws IOException {
}.run();
}

/**
 * Wraps <code>delegate</code> with a count based early terminating collector: collection is
 * aborted (via {@link EarlyTerminationException}) once <code>maxCountHits</code> documents
 * have been collected.
 *
 * @param delegate     the collector that receives every collected document
 * @param maxCountHits the number of hits after which collection is terminated
 * @return the wrapping collector
 */
// NOTE: dropped redundant `final` — static methods cannot be overridden in Java.
public static EarlyTerminatingCollector wrapCountBasedEarlyTerminatingCollector(final Collector delegate, int maxCountHits) {
    return new EarlyTerminatingCollector(delegate, maxCountHits);
}

/**
 * Wraps <code>delegate</code> with a {@link TimeLimitingCollector} that aborts collection
 * after <code>timeoutInMillis</code> milliseconds.
 *
 * @param delegate        the collector that receives every collected document
 * @param counter         the counter the time limiter reads the current time from
 * @param timeoutInMillis the timeout in milliseconds
 * @return the wrapping collector
 */
// NOTE: dropped redundant `final` — static methods cannot be overridden in Java.
public static TimeLimitingCollector wrapTimeLimitingCollector(final Collector delegate, final Counter counter, long timeoutInMillis) {
    return new TimeLimitingCollector(delegate, counter, timeoutInMillis);
}

/**
* Check whether there is one or more documents matching the provided query.
*/
Expand Down Expand Up @@ -618,71 +604,6 @@ public static void writeExplanation(StreamOutput out, Explanation explanation) t
}
}

/**
 * Thrown by {@link org.elasticsearch.common.lucene.Lucene.EarlyTerminatingCollector}
 * when it reaches early termination.
 */
public static final class EarlyTerminationException extends ElasticsearchException {

/** Creates the exception with the given message. */
public EarlyTerminationException(String msg) {
super(msg);
}

/** Deserialization constructor: reads the exception state from the given stream. */
public EarlyTerminationException(StreamInput in) throws IOException{
super(in);
}
}

/**
 * A collector that forwards every document to a delegate collector and terminates the
 * collection by throwing {@link org.elasticsearch.common.lucene.Lucene.EarlyTerminationException}
 * as soon as <code>maxCountHits</code> documents have been collected.
 */
public static final class EarlyTerminatingCollector extends SimpleCollector {

    private final int maxCountHits;
    private final Collector delegate;

    // number of documents seen so far across all segments
    private int numCollected = 0;
    // leaf-level view of the delegate for the segment currently being collected
    private LeafCollector currentLeaf;

    EarlyTerminatingCollector(final Collector delegate, int maxCountHits) {
        this.delegate = Objects.requireNonNull(delegate);
        this.maxCountHits = maxCountHits;
    }

    /** Returns the number of documents collected so far. */
    public int count() {
        return numCollected;
    }

    /** Returns <code>true</code> if at least one document has been collected. */
    public boolean exists() {
        return numCollected > 0;
    }

    @Override
    public void setScorer(Scorer scorer) throws IOException {
        currentLeaf.setScorer(scorer);
    }

    @Override
    public void collect(int doc) throws IOException {
        currentLeaf.collect(doc);
        numCollected++;
        if (numCollected >= maxCountHits) {
            throw new EarlyTerminationException("early termination [CountBased]");
        }
    }

    @Override
    public void doSetNextReader(LeafReaderContext atomicReaderContext) throws IOException {
        currentLeaf = delegate.getLeafCollector(atomicReaderContext);
    }

    @Override
    public boolean needsScores() {
        return delegate.needsScores();
    }
}

private Lucene() {

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ private static void parseSearchSource(final SearchSourceBuilder searchSourceBuil
searchSourceBuilder.trackScores(request.paramAsBoolean("track_scores", false));
}

if (request.hasParam("track_total_hits")) {
searchSourceBuilder.trackTotalHits(request.paramAsBoolean("track_total_hits", true));
}

String sSorts = request.param("sort");
if (sSorts != null) {
String[] sorts = Strings.splitStringByCommaToArray(sSorts);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ final class DefaultSearchContext extends SearchContext {
private SortAndFormats sort;
private Float minimumScore;
private boolean trackScores = false; // when sorting, track scores as well...
private boolean trackTotalHits = true;
private FieldDoc searchAfter;
private CollapseContext collapse;
private boolean lowLevelCancellation;
Expand Down Expand Up @@ -548,6 +549,17 @@ public boolean trackScores() {
return this.trackScores;
}

/** Sets whether the total hit count of the query should be computed; returns this context for chaining. */
@Override
public SearchContext trackTotalHits(boolean trackTotalHits) {
this.trackTotalHits = trackTotalHits;
return this;
}

/** Returns whether the total hit count of the query should be computed (defaults to <tt>true</tt>). */
@Override
public boolean trackTotalHits() {
return trackTotalHits;
}

@Override
public SearchContext searchAfter(FieldDoc searchAfter) {
this.searchAfter = searchAfter;
Expand Down
25 changes: 23 additions & 2 deletions core/src/main/java/org/elasticsearch/search/SearchHits.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.search;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Streamable;
Expand Down Expand Up @@ -178,7 +179,17 @@ public static SearchHits readSearchHits(StreamInput in) throws IOException {

@Override
public void readFrom(StreamInput in) throws IOException {
totalHits = in.readVLong();
final boolean hasTotalHits;
if (in.getVersion().onOrAfter(Version.V_6_0_0_alpha3)) {
hasTotalHits = in.readBoolean();
} else {
hasTotalHits = true;
}
if (hasTotalHits) {
totalHits = in.readVLong();
} else {
totalHits = -1;
}
maxScore = in.readFloat();
int size = in.readVInt();
if (size == 0) {
Expand All @@ -193,7 +204,17 @@ public void readFrom(StreamInput in) throws IOException {

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVLong(totalHits);
final boolean hasTotalHits;
if (out.getVersion().onOrAfter(Version.V_6_0_0_alpha3)) {
hasTotalHits = totalHits >= 0;
out.writeBoolean(hasTotalHits);
} else {
assert totalHits >= 0;
hasTotalHits = true;
}
if (hasTotalHits) {
out.writeVLong(totalHits);
}
out.writeFloat(maxScore);
out.writeVInt(hits.length);
if (hits.length > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
}
}
context.trackScores(source.trackScores());
if (source.trackTotalHits() == false && context.scrollContext() != null) {
throw new SearchContextException(context, "disabling [track_total_hits] is not allowed in a scroll context");
}
context.trackTotalHits(source.trackTotalHits());
if (source.minScore() != null) {
context.minimumScore(source.minScore());
}
Expand Down
Loading

0 comments on commit 36a5cf8

Please sign in to comment.