Skip to content

Commit

Permalink
Adjust SortField comparators to use new Pruning API (#101983)
Browse files Browse the repository at this point in the history
Introduced in apache/lucene#12405

We should account for the changes in our overrides and API. Now, to indicate that no skipping can occur, we utilize `Pruning.NONE`.
  • Loading branch information
benwtrent authored Nov 9, 2023
1 parent 180ef28 commit 3f9ab8a
Show file tree
Hide file tree
Showing 15 changed files with 37 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldDocs;
import org.elasticsearch.search.DocValueFormat;
Expand All @@ -35,7 +36,7 @@ class BottomSortValuesCollector {
this.reverseMuls = new int[sortFields.length];
this.sortFields = sortFields;
for (int i = 0; i < sortFields.length; i++) {
comparators[i] = sortFields[i].getComparator(1, false);
comparators[i] = sortFields[i].getComparator(1, Pruning.NONE);
reverseMuls[i] = sortFields[i].getReverse() ? -1 : 1;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.comparators.TermOrdValComparator;
Expand Down Expand Up @@ -68,13 +69,13 @@ protected SortedBinaryDocValues getValues(LeafReaderContext context) throws IOEx
protected void setScorer(Scorable scorer) {}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) {
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName());

final boolean sortMissingLast = sortMissingLast(missingValue) ^ reversed;
final BytesRef missingBytes = (BytesRef) missingObject(missingValue, reversed);
if (indexFieldData instanceof IndexOrdinalsFieldData) {
return new TermOrdValComparator(numHits, null, sortMissingLast, reversed, false) {
return new TermOrdValComparator(numHits, null, sortMissingLast, reversed, Pruning.NONE) {

@Override
protected SortedDocValues getSortedDocValues(LeafReaderContext context, String field) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.comparators.DoubleComparator;
Expand Down Expand Up @@ -72,13 +73,13 @@ private NumericDoubleValues getNumericDocValues(LeafReaderContext context, doubl
protected void setScorer(Scorable scorer) {}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) {
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName());

final double dMissingValue = (Double) missingObject(missingValue, reversed);
// NOTE: it's important to pass null as a missing value in the constructor so that
// the comparator doesn't check docsWithField since we replace missing values in select()
return new DoubleComparator(numHits, null, null, reversed, false) {
return new DoubleComparator(numHits, null, null, reversed, Pruning.NONE) {
@Override
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
return new DoubleLeafComparator(context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.comparators.FloatComparator;
Expand Down Expand Up @@ -65,13 +66,13 @@ private NumericDoubleValues getNumericDocValues(LeafReaderContext context, float
}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) {
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName());

final float fMissingValue = (Float) missingObject(missingValue, reversed);
// NOTE: it's important to pass null as a missing value in the constructor so that
// the comparator doesn't check docsWithField since we replace missing values in select()
return new FloatComparator(numHits, null, null, reversed, false) {
return new FloatComparator(numHits, null, null, reversed, Pruning.NONE) {
@Override
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
return new FloatLeafComparator(context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.comparators.LongComparator;
import org.apache.lucene.util.BitSet;
Expand Down Expand Up @@ -94,13 +95,13 @@ private NumericDocValues getNumericDocValues(LeafReaderContext context, long mis
}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) {
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
assert indexFieldData == null || fieldname.equals(indexFieldData.getFieldName());

final long lMissingValue = (Long) missingObject(missingValue, reversed);
// NOTE: it's important to pass null as a missing value in the constructor so that
// the comparator doesn't check docsWithField since we replace missing values in select()
return new LongComparator(numHits, null, null, reversed, false) {
return new LongComparator(numHits, null, null, reversed, Pruning.NONE) {
@Override
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
return new LongLeafComparator(context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
Expand Down Expand Up @@ -169,7 +170,7 @@ private SinglePassGroupingCollector(
for (int i = 0; i < sortFields.length; i++) {
final SortField sortField = sortFields[i];
// use topNGroups + 1 so we have a spare slot to use for comparing (tracked by this.spareSlot):
comparators[i] = sortField.getComparator(topNGroups + 1, false);
comparators[i] = sortField.getComparator(topNGroups + 1, Pruning.NONE);
reversed[i] = sortField.getReverse() ? -1 : 1;
}
if (after != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
Expand Down Expand Up @@ -121,7 +122,7 @@ private static class MergeSortQueue extends PriorityQueue<ShardRef> {
reverseMul = new int[sortFields.length];
for (int compIDX = 0; compIDX < sortFields.length; compIDX++) {
final SortField sortField = sortFields[compIDX];
comparators[compIDX] = sortField.getComparator(1, false);
comparators[compIDX] = sortField.getComparator(1, Pruning.NONE);
reverseMul[compIDX] = sortField.getReverse() ? -1 : 1;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
Expand Down Expand Up @@ -52,7 +53,7 @@ public SearchAfterSortedDocQuery(Sort sort, FieldDoc after) {
this.reverseMuls = new int[numFields];
for (int i = 0; i < numFields; i++) {
SortField sortField = sort.getSort()[i];
FieldComparator<?> fieldComparator = sortField.getComparator(1, false);
FieldComparator<?> fieldComparator = sortField.getComparator(1, Pruning.NONE);
@SuppressWarnings("unchecked")
FieldComparator<Object> comparator = (FieldComparator<Object>) fieldComparator;
comparator.setTopValue(after.fields[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Sort;
Expand Down Expand Up @@ -359,8 +360,8 @@ public int hashCode() {
}

@Override
public FieldComparator<?> getComparator(int numHits, boolean enableSkipping) {
return new LongComparator(1, delegate.getField(), (Long) missingValue, delegate.getReverse(), false) {
public FieldComparator<?> getComparator(int numHits, Pruning enableSkipping) {
return new LongComparator(1, delegate.getField(), (Long) missingValue, delegate.getReverse(), Pruning.NONE) {
@Override
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
return new LongLeafComparator(context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.comparators.DoubleComparator;
import org.apache.lucene.util.BitSet;
Expand Down Expand Up @@ -663,8 +664,8 @@ private NumericDoubleValues getNumericDoubleValues(LeafReaderContext context) th
}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) {
return new DoubleComparator(numHits, null, null, reversed, false) {
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
return new DoubleComparator(numHits, null, null, reversed, Pruning.NONE) {
@Override
public LeafFieldComparator getLeafComparator(LeafReaderContext context) throws IOException {
return new DoubleLeafComparator(context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.comparators.DocComparator;

Expand All @@ -34,8 +35,8 @@ int getShardRequestIndex() {
}

@Override
public FieldComparator<?> getComparator(int numHits, boolean enableSkipping) {
final DocComparator delegate = new DocComparator(numHits, getReverse(), false);
public FieldComparator<?> getComparator(int numHits, Pruning enableSkipping) {
final DocComparator delegate = new DocComparator(numHits, getReverse(), Pruning.NONE);

return new FieldComparator<Long>() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
Expand Down Expand Up @@ -234,7 +235,7 @@ private Object[] newDateNanoArray(String... values) {
private TopFieldDocs createTopDocs(SortField sortField, int totalHits, Object[] values) {
FieldDoc[] fieldDocs = new FieldDoc[values.length];
@SuppressWarnings("unchecked")
FieldComparator<Object> cmp = (FieldComparator<Object>) sortField.getComparator(1, false);
FieldComparator<Object> cmp = (FieldComparator<Object>) sortField.getComparator(1, Pruning.NONE);
for (int i = 0; i < values.length; i++) {
fieldDocs[i] = new FieldDoc(i, Float.NaN, new Object[] { values[i] });
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
Expand Down Expand Up @@ -367,7 +368,7 @@ private Comparator<ScoreDoc> sortFieldsComparator(SortField[] sortFields) {
FieldComparator[] comparators = new FieldComparator[sortFields.length];
for (int i = 0; i < sortFields.length; i++) {
// Values passed to getComparator shouldn't matter
comparators[i] = sortFields[i].getComparator(0, false);
comparators[i] = sortFields[i].getComparator(0, Pruning.NONE);
}
return (lhs, rhs) -> {
FieldDoc l = (FieldDoc) lhs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.PrefixQuery;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryCachingPolicy;
import org.apache.lucene.search.ScoreDoc;
Expand Down Expand Up @@ -720,7 +721,7 @@ public void testIndexSortScrollOptimization() throws Exception {
@SuppressWarnings("unchecked")
FieldComparator<Object> comparator = (FieldComparator<Object>) searchSortAndFormat.sort.getSort()[i].getComparator(
1,
i == 0
i == 0 ? Pruning.GREATER_THAN : Pruning.NONE
);
int cmp = comparator.compareValues(firstDoc.fields[i], lastDoc.fields[i]);
if (cmp == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.document.LatLonDocValuesField;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.SortedNumericSortField;
Expand Down Expand Up @@ -216,7 +217,7 @@ public SortField.Type reducedType() {
}

@Override
public FieldComparator<?> newComparator(String fieldname, int numHits, boolean enableSkipping, boolean reversed) {
public FieldComparator<?> newComparator(String fieldname, int numHits, Pruning enableSkipping, boolean reversed) {
return null;
}

Expand Down

0 comments on commit 3f9ab8a

Please sign in to comment.