Avoid double term construction in DfsPhase (#38716)
DfsPhase captures terms used for scoring a query in order to build global term statistics across
multiple shards for more accurate scoring. It currently does this by building the query's `Weight`
and calling `extractTerms` on it to collect terms, and then calling `IndexSearcher.termStatistics()`
for each collected term. This duplicates work, however, as the various `Weight` implementations 
will already have collected these statistics at construction time.
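
For reference, the flow this commit deletes looks like this (condensed from the `DfsPhase` diff below; cancellation checks and the field-statistics loop are elided):

```java
// Build the Weight once, solely to enumerate the terms it scores...
ObjectHashSet<Term> termsSet = new ObjectHashSet<>();
context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1f)
    .extractTerms(new DelegateSet(termsSet));

// ...then look up each term's statistics a second time, repeating work that
// the Weight constructor has already done.
Term[] terms = termsSet.toArray(Term.class);
TermStatistics[] termStatistics = new TermStatistics[terms.length];
IndexReaderContext indexReaderContext = context.searcher().getTopReaderContext();
for (int i = 0; i < terms.length; i++) {
    TermStates termContext = TermStates.build(indexReaderContext, terms[i], true);
    termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
}
```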

This commit replaces that roundabout way of collecting stats with a delegating
IndexSearcher that collects the term contexts and statistics when `IndexSearcher.termStatistics()`
(or `collectionStatistics()`) is called from the `Weight`.
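
The new collection step, condensed from the same diff: a delegating searcher whose statistics callbacks record whatever the `Weight`s request, so constructing each `Weight` is itself the collection pass (only the term-statistics override is shown here; `collectionStatistics()` is intercepted the same way for field statistics):

```java
Map<Term, TermStatistics> stats = new HashMap<>();
IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
    @Override
    public TermStatistics termStatistics(Term term, TermStates states) throws IOException {
        TermStatistics ts = super.termStatistics(term, states);
        if (ts != null) {
            stats.put(term, ts); // captured as a side effect of Weight construction
        }
        return ts;
    }
};
// Building the Weight triggers the termStatistics() calls above; the Weight
// itself is then discarded.
searcher.createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1);
```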

It also fixes a bug when using rescorers, where a `QueryRescorer` would calculate distributed term
statistics but ignore field statistics. `Rescorer.extractTerms` has been removed and replaced with
a new `getQueries()` method on `RescoreContext` that returns any queries used by the rescore
implementation. The delegating IndexSearcher then collects term contexts and statistics for each of
these queries in the same way as described above.
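
For third-party `Rescorer` implementations the migration is mechanical: drop the `extractTerms` override and return the rescore queries from the context instead. A hypothetical custom context is sketched below (class and field names are invented for illustration; `QueryRescorer`'s real version is in the diff further down):

```java
import org.apache.lucene.search.Query;

import java.util.Collections;
import java.util.List;

public class MyRescoreContext extends RescoreContext {
    private final Query rescoreQuery;

    public MyRescoreContext(int windowSize, Rescorer rescorer, Query rescoreQuery) {
        // RescoreContext(int windowSize, Rescorer rescorer), per the class in the diff below
        super(windowSize, rescorer);
        this.rescoreQuery = rescoreQuery;
    }

    @Override
    public List<Query> getQueries() {
        // Every query returned here is rewritten and weighted against the
        // stats-collecting searcher, so its terms contribute both term and
        // field statistics to the DFS result.
        return Collections.singletonList(rescoreQuery);
    }
}
```

Implementations that don't rescore with queries, like the example plugin in the first changed file below, simply inherit the default `Collections.emptyList()` and need no override.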
romseygeek committed Feb 15, 2019
1 parent 27cf7e2 commit 176013e
Showing 5 changed files with 59 additions and 115 deletions.
plugins/examples/rescore/src/main/java/org/elasticsearch/example/rescore/ExampleRescoreBuilder.java
@@ -20,7 +20,6 @@
package org.elasticsearch.example.rescore;

import org.apache.lucene.index.LeafReaderContext;
- import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
@@ -46,7 +45,6 @@
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;
- import java.util.Set;

import static java.util.Collections.singletonList;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
@@ -224,9 +222,5 @@ public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreCon
return Explanation.match(context.factor, "test", singletonList(sourceExplanation));
}

- @Override
- public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) {
-     // Since we don't use queries there are no terms to extract.
- }
}
}
125 changes: 37 additions & 88 deletions server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java
@@ -19,14 +19,12 @@

package org.elasticsearch.search.dfs;

- import com.carrotsearch.hppc.ObjectHashSet;
import com.carrotsearch.hppc.ObjectObjectHashMap;
- import com.carrotsearch.hppc.cursors.ObjectCursor;

- import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.CollectionStatistics;
+ import org.apache.lucene.search.IndexSearcher;
+ import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.collect.HppcMaps;
@@ -36,9 +34,8 @@
import org.elasticsearch.tasks.TaskCancelledException;

import java.io.IOException;
- import java.util.AbstractSet;
- import java.util.Collection;
- import java.util.Iterator;
+ import java.util.HashMap;
+ import java.util.Map;

/**
* Dfs phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
@@ -52,101 +49,53 @@ public void preProcess(SearchContext context) {

@Override
public void execute(SearchContext context) {
- final ObjectHashSet<Term> termsSet = new ObjectHashSet<>();
try {
- context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1f)
-     .extractTerms(new DelegateSet(termsSet));
+ ObjectObjectHashMap<String, CollectionStatistics> fieldStatistics = HppcMaps.newNoNullKeysMap();
+ Map<Term, TermStatistics> stats = new HashMap<>();
+ IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
+     @Override
+     public TermStatistics termStatistics(Term term, TermStates states) throws IOException {
+         if (context.isCancelled()) {
+             throw new TaskCancelledException("cancelled");
+         }
+         TermStatistics ts = super.termStatistics(term, states);
+         if (ts != null) {
+             stats.put(term, ts);
+         }
+         return ts;
+     }
+
+     @Override
+     public CollectionStatistics collectionStatistics(String field) throws IOException {
+         if (context.isCancelled()) {
+             throw new TaskCancelledException("cancelled");
+         }
+         CollectionStatistics cs = super.collectionStatistics(field);
+         if (cs != null) {
+             fieldStatistics.put(field, cs);
+         }
+         return cs;
+     }
+ };
+
+ searcher.createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1);
for (RescoreContext rescoreContext : context.rescore()) {
- try {
-     rescoreContext.rescorer().extractTerms(context.searcher(), rescoreContext, new DelegateSet(termsSet));
- } catch (IOException e) {
-     throw new IllegalStateException("Failed to extract terms", e);
+ for (Query query : rescoreContext.getQueries()) {
+     searcher.createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1);
}
}

- Term[] terms = termsSet.toArray(Term.class);
+ Term[] terms = stats.keySet().toArray(new Term[0]);
TermStatistics[] termStatistics = new TermStatistics[terms.length];
- IndexReaderContext indexReaderContext = context.searcher().getTopReaderContext();
for (int i = 0; i < terms.length; i++) {
if(context.isCancelled()) {
throw new TaskCancelledException("cancelled");
}
-     // LUCENE 4 UPGRADE: cache TermStates?
-     TermStates termContext = TermStates.build(indexReaderContext, terms[i], true);
-     termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
- }
-
- ObjectObjectHashMap<String, CollectionStatistics> fieldStatistics = HppcMaps.newNoNullKeysMap();
- for (Term term : terms) {
-     assert term.field() != null : "field is null";
-     if (fieldStatistics.containsKey(term.field()) == false) {
-         final CollectionStatistics collectionStatistics = context.searcher().collectionStatistics(term.field());
-         if (collectionStatistics != null) {
-             fieldStatistics.put(term.field(), collectionStatistics);
-         }
-         if(context.isCancelled()) {
-             throw new TaskCancelledException("cancelled");
-         }
-     }
- }
+ termStatistics[i] = stats.get(terms[i]);
}

context.dfsResult().termsStatistics(terms, termStatistics)
.fieldStatistics(fieldStatistics)
.maxDoc(context.searcher().getIndexReader().maxDoc());
} catch (Exception e) {
throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e);
- } finally {
-     termsSet.clear(); // don't hold on to terms
}
}

- // We need to bridge to JCF world, b/c of Query#extractTerms
- private static class DelegateSet extends AbstractSet<Term> {
-
-     private final ObjectHashSet<Term> delegate;
-
-     private DelegateSet(ObjectHashSet<Term> delegate) {
-         this.delegate = delegate;
-     }
-
-     @Override
-     public boolean add(Term term) {
-         return delegate.add(term);
-     }
-
-     @Override
-     public boolean addAll(Collection<? extends Term> terms) {
-         boolean result = false;
-         for (Term term : terms) {
-             result = delegate.add(term);
-         }
-         return result;
-     }
-
-     @Override
-     public Iterator<Term> iterator() {
-         final Iterator<ObjectCursor<Term>> iterator = delegate.iterator();
-         return new Iterator<Term>() {
-             @Override
-             public boolean hasNext() {
-                 return iterator.hasNext();
-             }
-
-             @Override
-             public Term next() {
-                 return iterator.next().value;
-             }
-
-             @Override
-             public void remove() {
-                 throw new UnsupportedOperationException();
-             }
-         };
-     }
-
-     @Override
-     public int size() {
-         return delegate.size();
-     }
- }

server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java
@@ -19,19 +19,19 @@

package org.elasticsearch.search.rescore;

- import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;

import java.io.IOException;
import java.util.Arrays;
+ import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
- import java.util.Collections;

import static java.util.stream.Collectors.toSet;

public final class QueryRescorer implements Rescorer {
@@ -170,6 +170,11 @@ public void setQuery(Query query) {
this.query = query;
}

+ @Override
+ public List<Query> getQueries() {
+     return Collections.singletonList(query);
+ }
+
public Query query() {
return query;
}
@@ -203,10 +208,4 @@ public void setScoreMode(String scoreMode) {
}
}

- @Override
- public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException {
-     Query query = ((QueryRescoreContext) rescoreContext).query();
-     searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f).extractTerms(termsSet);
- }
-
}
server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java
@@ -19,6 +19,10 @@

package org.elasticsearch.search.rescore;

+ import org.apache.lucene.search.Query;
+
+ import java.util.Collections;
+ import java.util.List;
import java.util.Set;

/**
@@ -29,7 +33,7 @@
public class RescoreContext {
private final int windowSize;
private final Rescorer rescorer;
- private Set<Integer> resroredDocs; //doc Ids for which rescoring was applied
+ private Set<Integer> rescoredDocs; //doc Ids for which rescoring was applied

/**
* Build the context.
@@ -55,10 +59,17 @@ public int getWindowSize() {
}

public void setRescoredDocs(Set<Integer> docIds) {
- resroredDocs = docIds;
+ rescoredDocs = docIds;
}

public boolean isRescored(int docId) {
- return resroredDocs.contains(docId);
+ return rescoredDocs.contains(docId);
}

+ /**
+  * Returns queries associated with the rescorer
+  */
+ public List<Query> getQueries() {
+     return Collections.emptyList();
+ }
}
server/src/main/java/org/elasticsearch/search/rescore/Rescorer.java
@@ -19,14 +19,11 @@

package org.elasticsearch.search.rescore;

- import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocs;
- import org.elasticsearch.action.search.SearchType;

import java.io.IOException;
- import java.util.Set;

/**
* A query rescorer interface used to re-rank the Top-K results of a previously
@@ -61,10 +58,4 @@ public interface Rescorer {
Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext,
Explanation sourceExplanation) throws IOException;

- /**
-  * Extracts all terms needed to execute this {@link Rescorer}. This method
-  * is executed in a distributed frequency collection roundtrip for
-  * {@link SearchType#DFS_QUERY_THEN_FETCH}
-  */
- void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException;
}
