Skip to content

Commit

Permalink
Misc cleanups and Speedups around sort queries
Browse files Browse the repository at this point in the history
We can get a few small allocation savings from saivng intermediary lists,
some minor speedups and code-shorting by using records in a few more spots and
a couple more small improvements to save some code + cycles.
  • Loading branch information
original-brownbear committed Sep 19, 2024
1 parent 90e343c commit 0c1418b
Show file tree
Hide file tree
Showing 25 changed files with 138 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public TopDocsAndMaxScore topDocs(SearchHit hit) throws IOException {
TopDocsCollector<?> topDocsCollector;
MaxScoreCollector maxScoreCollector = null;
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, Integer.MAX_VALUE);
topDocsCollector = TopFieldCollector.create(sort().sort(), topN, Integer.MAX_VALUE);
if (trackScores()) {
maxScoreCollector = new MaxScoreCollector();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,8 @@ private static boolean shouldSortShards(MinAndMax<?>[] minAndMaxes) {
Class<?> clazz = null;
for (MinAndMax<?> minAndMax : minAndMaxes) {
if (clazz == null) {
clazz = minAndMax == null ? null : minAndMax.getMin().getClass();
} else if (minAndMax != null && clazz != minAndMax.getMin().getClass()) {
clazz = minAndMax == null ? null : minAndMax.minValue().getClass();
} else if (minAndMax != null && clazz != minAndMax.maxValue().getClass()) {
// we don't support sort values that mix different types (e.g.: long/double, numeric/keyword).
// TODO: we could fail the request because there is a high probability
// that the merging of topdocs will fail later for the same reason ?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ public TopDocsAndMaxScore topDocs(SearchHit hit) throws IOException {
TopDocsCollector<?> topDocsCollector;
MaxScoreCollector maxScoreCollector = null;
if (sort() != null) {
topDocsCollector = TopFieldCollector.create(sort().sort, topN, Integer.MAX_VALUE);
topDocsCollector = TopFieldCollector.create(sort().sort(), topN, Integer.MAX_VALUE);
if (trackScores()) {
maxScoreCollector = new MaxScoreCollector();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public final class CanMatchShardResponse extends SearchPhaseResult {
public CanMatchShardResponse(StreamInput in) throws IOException {
super(in);
this.canMatch = in.readBoolean();
estimatedMinAndMax = in.readOptionalWriteable(MinAndMax::new);
estimatedMinAndMax = in.readOptionalWriteable(MinAndMax::readFrom);
}

public CanMatchShardResponse(boolean canMatch, MinAndMax<?> estimatedMinAndMax) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private static class Collectors {
public ScoreMode scoreMode() {
SortAndFormats sort = subSearchContext.sort();
if (sort != null) {
return sort.sort.needsScores() || subSearchContext.trackScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
return sort.sort().needsScores() || subSearchContext.trackScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
} else {
// sort by score
return ScoreMode.COMPLETE;
Expand Down Expand Up @@ -141,7 +141,7 @@ public void collect(int docId, long bucket) throws IOException {
// TODO: can we pass trackTotalHits=subSearchContext.trackTotalHits(){
// Note that this would require to catch CollectionTerminatedException
collectors = new Collectors(
TopFieldCollector.create(sort.sort, topN, Integer.MAX_VALUE),
TopFieldCollector.create(sort.sort(), topN, Integer.MAX_VALUE),
subSearchContext.trackScores() ? new MaxScoreCollector() : null
);
}
Expand Down Expand Up @@ -187,7 +187,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
}
final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore);
subSearchContext.queryResult()
.topDocs(topDocsAndMaxScore, subSearchContext.sort() == null ? null : subSearchContext.sort().formats);
.topDocs(topDocsAndMaxScore, subSearchContext.sort() == null ? null : subSearchContext.sort().formats());
int[] docIdsToLoad = new int[topDocs.scoreDocs.length];
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
docIdsToLoad[i] = topDocs.scoreDocs[i].doc;
Expand All @@ -203,7 +203,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) throws IOE
searchHitFields.shard(subSearchContext.shardTarget());
searchHitFields.score(scoreDoc.score);
if (scoreDoc instanceof FieldDoc fieldDoc) {
searchHitFields.sortValues(fieldDoc.fields, subSearchContext.sort().formats);
searchHitFields.sortValues(fieldDoc.fields, subSearchContext.sort().formats());
}
}
return new InternalTopHits(
Expand Down Expand Up @@ -233,7 +233,7 @@ public SearchExecutionContext getSearchExecutionContext() {
public InternalTopHits buildEmptyAggregation() {
TopDocs topDocs;
if (subSearchContext.sort() != null) {
topDocs = new TopFieldDocs(Lucene.TOTAL_HITS_EQUAL_TO_ZERO, new FieldDoc[0], subSearchContext.sort().sort.getSort());
topDocs = new TopFieldDocs(Lucene.TOTAL_HITS_EQUAL_TO_ZERO, new FieldDoc[0], subSearchContext.sort().sort().getSort());
} else {
topDocs = Lucene.EMPTY_TOP_DOCS;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void setChildInnerHits(Map<String, InnerHitSubContext> childInnerHits) {

protected Weight getInnerHitQueryWeight() throws IOException {
if (innerHitQueryWeight == null) {
final boolean needsScores = size() != 0 && (sort() == null || sort().sort.needsScores());
final boolean needsScores = size() != 0 && (sort() == null || sort().sort().needsScores());
innerHitQueryWeight = context.searcher()
.createWeight(context.searcher().rewrite(query()), needsScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES, 1f);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ private void hitExecute(Map<String, InnerHitsContext.InnerHitSubContext> innerHi
if (results == null) {
hit.setInnerHits(results = new HashMap<>());
}
innerHitsContext.queryResult().topDocs(topDoc, innerHitsContext.sort() == null ? null : innerHitsContext.sort().formats);
innerHitsContext.queryResult().topDocs(topDoc, innerHitsContext.sort() == null ? null : innerHitsContext.sort().formats());
int[] docIdsToLoad = new int[topDoc.topDocs.scoreDocs.length];
for (int j = 0; j < topDoc.topDocs.scoreDocs.length; j++) {
docIdsToLoad[j] = topDoc.topDocs.scoreDocs[j].doc;
Expand All @@ -101,7 +101,7 @@ private void hitExecute(Map<String, InnerHitsContext.InnerHitSubContext> innerHi
SearchHit searchHitFields = internalHits[j];
searchHitFields.score(scoreDoc.score);
if (scoreDoc instanceof FieldDoc fieldDoc) {
searchHitFields.sortValues(fieldDoc.fields, innerHitsContext.sort().formats);
searchHitFields.sortValues(fieldDoc.fields, innerHitsContext.sort().formats());
}
}
var h = fetchResult.hits();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ static void addCollectorsAndSearch(SearchContext searchContext) throws QueryPhas
// skip to the desired doc
if (after != null) {
query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.MUST)
.add(new SearchAfterSortedDocQuery(searchContext.sort().sort, (FieldDoc) after), BooleanClause.Occur.FILTER)
.add(
new SearchAfterSortedDocQuery(searchContext.sort().sort(), (FieldDoc) after),
BooleanClause.Occur.FILTER
)
.build();
}
}
Expand Down Expand Up @@ -236,10 +239,10 @@ static void addCollectorsAndSearch(SearchContext searchContext) throws QueryPhas
* with <code>sortAndFormats</code>.
**/
private static boolean canEarlyTerminate(IndexReader reader, SortAndFormats sortAndFormats) {
if (sortAndFormats == null || sortAndFormats.sort == null) {
if (sortAndFormats == null || sortAndFormats.sort() == null) {
return false;
}
final Sort sort = sortAndFormats.sort;
final Sort sort = sortAndFormats.sort();
for (LeafReaderContext ctx : reader.leaves()) {
Sort indexSort = ctx.reader().getMetaData().getSort();
if (indexSort == null || Lucene.canEarlyTerminate(sort, indexSort) == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ protected TopDocsAndMaxScore reduceTopDocsCollectors(Collection<Collector> colle
);
final TopDocs topDocs;
if (sortAndFormats != null) {
topDocs = new TopFieldDocs(totalHits, Lucene.EMPTY_SCORE_DOCS, sortAndFormats.sort.getSort());
topDocs = new TopFieldDocs(totalHits, Lucene.EMPTY_SCORE_DOCS, sortAndFormats.sort().getSort());
} else {
topDocs = new TopDocs(totalHits, Lucene.EMPTY_SCORE_DOCS);
}
Expand Down Expand Up @@ -393,7 +393,7 @@ private static class WithHits extends QueryPhaseCollectorManager {
this.trackMaxScore = trackMaxScore;

final int hitCountThreshold;
if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) && hasInfMaxScore(query)) {
if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort().getSort()[0])) && hasInfMaxScore(query)) {
// disable max score optimization since we have a mandatory clause
// that doesn't track the maximum score
hitCountThreshold = Integer.MAX_VALUE;
Expand All @@ -417,7 +417,12 @@ private static class WithHits extends QueryPhaseCollectorManager {
if (sortAndFormats == null) {
this.topDocsManager = new TopScoreDocCollectorManager(numHits, searchAfter, hitCountThreshold);
} else {
this.topDocsManager = new TopFieldCollectorManager(sortAndFormats.sort, numHits, (FieldDoc) searchAfter, hitCountThreshold);
this.topDocsManager = new TopFieldCollectorManager(
sortAndFormats.sort(),
numHits,
(FieldDoc) searchAfter,
hitCountThreshold
);
}
}

Expand Down Expand Up @@ -465,7 +470,7 @@ protected TopDocsAndMaxScore reduceTopDocsCollectors(Collection<Collector> colle

@Override
protected final DocValueFormat[] getSortValueFormats() {
return sortAndFormats == null ? null : sortAndFormats.formats;
return sortAndFormats == null ? null : sortAndFormats.formats();
}
}

Expand Down Expand Up @@ -547,7 +552,7 @@ private static QueryPhaseCollectorManager forCollapsing(
) {
assert numHits > 0;
assert collapseContext != null;
Sort sort = sortAndFormats == null ? Sort.RELEVANCE : sortAndFormats.sort;
Sort sort = sortAndFormats == null ? Sort.RELEVANCE : sortAndFormats.sort();
final SinglePassGroupingCollector<?> topDocsCollector = collapseContext.createTopDocs(sort, numHits, after);
MaxScoreCollector maxScoreCollector = trackMaxScore ? new MaxScoreCollector() : null;
return new QueryPhaseCollectorManager(postFilterWeight, terminateAfterChecker, aggsCollectorManager, minScore, profile) {
Expand All @@ -570,7 +575,7 @@ protected TopDocsAndMaxScore reduceTopDocsCollectors(Collection<Collector> colle

@Override
protected DocValueFormat[] getSortValueFormats() {
return sortAndFormats == null ? new DocValueFormat[] { DocValueFormat.RAW } : sortAndFormats.formats;
return sortAndFormats == null ? new DocValueFormat[] { DocValueFormat.RAW } : sortAndFormats.formats();
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ public Object[] getSortValues() {
}

public static FieldDoc buildFieldDoc(SortAndFormats sort, Object[] values, @Nullable String collapseField) {
if (sort == null || sort.sort.getSort() == null || sort.sort.getSort().length == 0) {
if (sort == null || sort.sort().getSort() == null || sort.sort().getSort().length == 0) {
throw new IllegalArgumentException("Sort must contain at least one field.");
}

SortField[] sortFields = sort.sort.getSort();
SortField[] sortFields = sort.sort().getSort();
if (sortFields.length != values.length) {
throw new IllegalArgumentException(
SEARCH_AFTER.getPreferredName() + " has " + values.length + " value(s) but sort has " + sort.sort.getSort().length + "."
SEARCH_AFTER.getPreferredName() + " has " + values.length + " value(s) but sort has " + sort.sort().getSort().length + "."
);
}

Expand All @@ -113,7 +113,7 @@ public static FieldDoc buildFieldDoc(SortAndFormats sort, Object[] values, @Null
Object[] fieldValues = new Object[sortFields.length];
for (int i = 0; i < sortFields.length; i++) {
SortField sortField = sortFields[i];
DocValueFormat format = sort.formats[i];
DocValueFormat format = sort.formats()[i];
if (values[i] != null) {
fieldValues[i] = convertValueFromSortField(values[i], sortField, format);
} else {
Expand Down Expand Up @@ -278,8 +278,7 @@ public boolean equals(Object other) {
if ((other instanceof SearchAfterBuilder) == false) {
return false;
}
boolean value = Arrays.equals(sortValues, ((SearchAfterBuilder) other).sortValues);
return value;
return Arrays.equals(sortValues, ((SearchAfterBuilder) other).sortValues);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ public static FieldSortBuilder getPrimaryFieldSortOrNull(SearchSourceBuilder sou
*/
public static MinAndMax<?> getMinMaxOrNull(SearchExecutionContext context, FieldSortBuilder sortBuilder) throws IOException {
SortAndFormats sort = SortBuilder.buildSort(Collections.singletonList(sortBuilder), context).get();
SortField sortField = sort.sort.getSort()[0];
SortField sortField = sort.sort().getSort()[0];
if (sortField.getField() == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ public GeoDistanceSortBuilder points(GeoPoint... points) {
* Returns the points to create the range distance facets from.
*/
public GeoPoint[] points() {
return this.points.toArray(new GeoPoint[this.points.size()]);
return this.points.toArray(new GeoPoint[0]);
}

/**
Expand Down Expand Up @@ -503,7 +503,7 @@ public static GeoDistanceSortBuilder fromXContent(XContentParser parser, String
}
}

GeoDistanceSortBuilder result = new GeoDistanceSortBuilder(fieldName, geoPoints.toArray(new GeoPoint[geoPoints.size()]));
GeoDistanceSortBuilder result = new GeoDistanceSortBuilder(fieldName, geoPoints.toArray(new GeoPoint[0]));
result.geoDistance(geoDistance);
result.unit(unit);
result.order(order);
Expand Down Expand Up @@ -568,7 +568,7 @@ public BucketedSort buildBucketedSort(SearchExecutionContext context, BigArrays
private GeoPoint[] localPoints() {
// validation was not available prior to 2.x, so to support bwc percolation queries we only ignore_malformed
// on 2.x created indexes
GeoPoint[] localPoints = points.toArray(new GeoPoint[points.size()]);
GeoPoint[] localPoints = points.toArray(new GeoPoint[0]);
if (GeoValidationMethod.isIgnoreMalformed(validation) == false) {
for (GeoPoint point : localPoints) {
if (GeoUtils.isValidLatitude(point.lat()) == false) {
Expand Down
27 changes: 5 additions & 22 deletions server/src/main/java/org/elasticsearch/search/sort/MinAndMax.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,16 @@
/**
* A class that encapsulates a minimum and a maximum, that are of the same type and {@link Comparable}.
*/
public class MinAndMax<T extends Comparable<? super T>> implements Writeable {
private final T minValue;
private final T maxValue;
public record MinAndMax<T extends Comparable<? super T>>(T minValue, T maxValue) implements Writeable {

public MinAndMax(T minValue, T maxValue) {
this.minValue = Objects.requireNonNull(minValue);
this.maxValue = Objects.requireNonNull(maxValue);
}

@SuppressWarnings("unchecked")
public MinAndMax(StreamInput in) throws IOException {
this.minValue = (T) Lucene.readSortValue(in);
this.maxValue = (T) Lucene.readSortValue(in);
public static <T extends Comparable<? super T>> MinAndMax<T> readFrom(StreamInput in) throws IOException {
return new MinAndMax<>((T) Lucene.readSortValue(in), (T) Lucene.readSortValue(in));
}

@Override
Expand All @@ -42,34 +39,20 @@ public void writeTo(StreamOutput out) throws IOException {
Lucene.writeSortValue(out, maxValue);
}

/**
* Return the minimum value.
*/
public T getMin() {
return minValue;
}

/**
* Return the maximum value.
*/
public T getMax() {
return maxValue;
}

@SuppressWarnings({ "unchecked", "rawtypes" })
private static final Comparator<MinAndMax> ASC_COMPARATOR = (left, right) -> {
if (left == null) {
return right == null ? 0 : -1; // nulls last
}
return right == null ? 1 : left.getMin().compareTo(right.getMin());
return right == null ? 1 : left.minValue.compareTo(right.minValue);
};

@SuppressWarnings({ "unchecked", "rawtypes" })
private static final Comparator<MinAndMax> DESC_COMPARATOR = (left, right) -> {
if (left == null) {
return right == null ? 0 : 1; // nulls first
}
return right == null ? -1 : right.getMax().compareTo(left.getMax());
return right == null ? -1 : right.maxValue.compareTo(left.maxValue);
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Objects;

/**
* A sort builder allowing to sort by score.
Expand Down Expand Up @@ -151,12 +150,12 @@ public boolean equals(Object object) {
return false;
}
ScoreSortBuilder other = (ScoreSortBuilder) object;
return Objects.equals(order, other.order);
return order.equals(other.order);
}

@Override
public int hashCode() {
return Objects.hash(this.order);
return this.order.hashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,12 @@
import org.apache.lucene.search.Sort;
import org.elasticsearch.search.DocValueFormat;

public final class SortAndFormats {
public record SortAndFormats(Sort sort, DocValueFormat[] formats) {

public final Sort sort;
public final DocValueFormat[] formats;

public SortAndFormats(Sort sort, DocValueFormat[] formats) {
public SortAndFormats {
if (sort.getSort().length != formats.length) {
throw new IllegalArgumentException("Number of sort field mismatch: " + sort.getSort().length + " != " + formats.length);
}
this.sort = sort;
this.formats = formats;
}

}
Loading

0 comments on commit 0c1418b

Please sign in to comment.