From 06dad494aeaf20bc0473df7cc7f0652625766dfe Mon Sep 17 00:00:00 2001
From: Armin Braun
Date: Sun, 22 Sep 2024 15:01:27 +0200
Subject: [PATCH] Speedup Query Phase Merging

Reduce contention and context switching in query-phase merging by
avoiding re-spinning the merge task repeatedly, moving work that does
not need synchronization out of the synchronized blocks, and merging
repeated loops over the same query result arrays.

---
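Note on the approach (editorial commentary, not part of the change):
the heart of this patch is the drain loop added to tryExecuteNext().
Instead of submitting a fresh AbstractRunnable to the executor after
every partial merge -- paying a task hand-off and extra lock
acquisitions per merge -- the single running task keeps polling the
merge queue until it is empty. A minimal standalone sketch of that
pattern, with deliberately simplified, hypothetical names
(DrainingWorker, process) that are not the actual Elasticsearch
classes:

    import java.util.ArrayDeque;
    import java.util.Queue;
    import java.util.concurrent.Executor;

    // One worker drains the queue in a loop; submitters only hold the
    // lock long enough to enqueue and to check whether a worker is
    // already active.
    final class DrainingWorker<T> {
        private final Queue<T> queue = new ArrayDeque<>();
        private final Executor executor;
        private boolean running; // guarded by 'this'

        DrainingWorker(Executor executor) {
            this.executor = executor;
        }

        void submit(T item) {
            synchronized (this) {
                queue.add(item);
                if (running) {
                    return; // the active worker will pick this item up
                }
                running = true;
            }
            executor.execute(() -> {
                while (true) {
                    final T next;
                    synchronized (this) {
                        next = queue.poll();
                        if (next == null) {
                            running = false; // queue drained, give up the worker role
                            return;
                        }
                    }
                    process(next); // the heavy work runs outside the lock
                }
            });
        }

        private void process(T item) {
            // a partial merge would happen here
        }
    }

As in the patch itself, the lock is only held to manipulate the queue
and the running flag; the merge work happens outside of it.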
 .../search/QueryPhaseResultConsumer.java | 395 +++++++++---------
 .../action/search/SearchPhaseController.java | 45 +-
 2 files changed, 218 insertions(+), 222 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java
index 89411ac302b10..6c654d9235ec2 100644
--- a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java
+++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java
@@ -19,7 +19,6 @@
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.core.Releasable;
-import org.elasticsearch.core.Releasables;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.SearchShardTarget;
@@ -121,26 +120,50 @@ public void consumeResult(SearchPhaseResult result, Runnable next) {
     public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
         if (pendingMerges.hasPendingMerges()) {
             throw new AssertionError("partial reduce in-flight");
-        } else if (pendingMerges.hasFailure()) {
-            throw pendingMerges.getFailure();
+        }
+        Exception failure = pendingMerges.failure.get();
+        if (failure != null) {
+            throw failure;
         }

         // ensure consistent ordering
         pendingMerges.sortBuffer();
-        final TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats();
-        final List<TopDocs> topDocsList = pendingMerges.consumeTopDocs();
+        final TopDocsStats topDocsStats = pendingMerges.topDocsStats;
+        final int resultSize = pendingMerges.buffer.size() + (pendingMerges.mergeResult == null ? 0 : 1);
+        final List<TopDocs> topDocsList = hasTopDocs ? new ArrayList<>(resultSize) : null;
+        final List<DelayableWriteable<InternalAggregations>> aggsList = hasAggs ? new ArrayList<>(resultSize) : null;
+        synchronized (pendingMerges) {
+            if (pendingMerges.mergeResult != null) {
+                if (topDocsList != null) {
+                    topDocsList.add(pendingMerges.mergeResult.reducedTopDocs);
+                }
+                if (aggsList != null) {
+                    aggsList.add(DelayableWriteable.referencing(pendingMerges.mergeResult.reducedAggs));
+                }
+            }
+            for (QuerySearchResult result : pendingMerges.buffer) {
+                topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
+                if (topDocsList != null) {
+                    TopDocsAndMaxScore topDocs = result.consumeTopDocs();
+                    setShardIndex(topDocs.topDocs, result.getShardIndex());
+                    topDocsList.add(topDocs.topDocs);
+                }
+                if (aggsList != null) {
+                    aggsList.add(result.getAggs());
+                }
+            }
+        }
         SearchPhaseController.ReducedQueryPhase reducePhase;
         long breakerSize = pendingMerges.circuitBreakerBytes;
         try {
-            final List<DelayableWriteable<InternalAggregations>> aggsList = pendingMerges.getAggs();
-            if (hasAggs) {
+            if (aggsList != null) {
                 // Add an estimate of the final reduce size
                 breakerSize = pendingMerges.addEstimateAndMaybeBreak(PendingMerges.estimateRamBytesUsedForReduce(breakerSize));
             }
             reducePhase = SearchPhaseController.reducedQueryPhase(
                 results.asList(),
                 aggsList,
-                topDocsList,
+                topDocsList == null ? Collections.emptyList() : topDocsList,
                 topDocsStats,
                 pendingMerges.numReducePhases,
                 false,
@@ -183,65 +206,59 @@ private MergeResult partialReduce(
         // ensure consistent ordering
         Arrays.sort(toConsume, RESULT_COMPARATOR);

-        for (QuerySearchResult result : toConsume) {
-            topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
-        }
-
+        final List<SearchShard> processedShards = new ArrayList<>(emptyResults);
         final TopDocs newTopDocs;
+        final InternalAggregations newAggs;
+        final List<DelayableWriteable<InternalAggregations>> aggsList;
+        final int resultSetSize = toConsume.length + (lastMerge != null ? 1 : 0);
+        if (hasAggs) {
+            aggsList = new ArrayList<>(resultSetSize);
+            if (lastMerge != null) {
+                aggsList.add(DelayableWriteable.referencing(lastMerge.reducedAggs));
+            }
+        } else {
+            aggsList = null;
+        }
+        List<TopDocs> topDocsList;
         if (hasTopDocs) {
-            List<TopDocs> topDocsList = new ArrayList<>();
+            topDocsList = new ArrayList<>(resultSetSize);
             if (lastMerge != null) {
                 topDocsList.add(lastMerge.reducedTopDocs);
             }
-            for (QuerySearchResult result : toConsume) {
-                TopDocsAndMaxScore topDocs = result.consumeTopDocs();
-                setShardIndex(topDocs.topDocs, result.getShardIndex());
-                topDocsList.add(topDocs.topDocs);
-            }
-            newTopDocs = mergeTopDocs(
-                topDocsList,
-                // we have to merge here in the same way we collect on a shard
-                topNSize,
-                0
-            );
         } else {
-            newTopDocs = null;
+            topDocsList = null;
         }
-
-        final InternalAggregations newAggs;
-        if (hasAggs) {
-            try {
-                final List<DelayableWriteable<InternalAggregations>> aggsList = new ArrayList<>();
-                if (lastMerge != null) {
-                    aggsList.add(DelayableWriteable.referencing(lastMerge.reducedAggs));
-                }
-                for (QuerySearchResult result : toConsume) {
+        try {
+            for (QuerySearchResult result : toConsume) {
+                topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
+                SearchShardTarget target = result.getSearchShardTarget();
+                processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
+                if (aggsList != null) {
                     aggsList.add(result.getAggs());
                 }
-                newAggs = InternalAggregations.topLevelReduceDelayable(aggsList, aggReduceContextBuilder.forPartialReduction());
-            } finally {
-                for (QuerySearchResult result : toConsume) {
-                    result.releaseAggs();
+                if (topDocsList != null) {
+                    TopDocsAndMaxScore topDocs = result.consumeTopDocs();
+                    setShardIndex(topDocs.topDocs, result.getShardIndex());
+                    topDocsList.add(topDocs.topDocs);
                 }
             }
-        } else {
-            newAggs = null;
+            // we have to merge here in the same way we collect on a shard
+            newTopDocs = topDocsList == null ? null : mergeTopDocs(topDocsList, topNSize, 0);
+            newAggs = aggsList == null
+                ? null
+                : InternalAggregations.topLevelReduceDelayable(aggsList, aggReduceContextBuilder.forPartialReduction());
+        } finally {
+            releaseAggs(toConsume);
         }
-        List<SearchShard> processedShards = new ArrayList<>(emptyResults);
         if (lastMerge != null) {
             processedShards.addAll(lastMerge.processedShards);
         }
-        for (QuerySearchResult result : toConsume) {
-            SearchShardTarget target = result.getSearchShardTarget();
-            processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
-        }
         if (progressListener != SearchProgressListener.NOOP) {
             progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
         }
         // we leave the results un-serialized because serializing is slow but we compute the serialized
         // size as an estimate of the memory used by the newly reduced aggregations.
-        long serializedSize = hasAggs ? DelayableWriteable.getSerializedSize(newAggs) : 0;
-        return new MergeResult(processedShards, newTopDocs, newAggs, hasAggs ? serializedSize : 0);
+        return new MergeResult(processedShards, newTopDocs, newAggs, newAggs != null ? DelayableWriteable.getSerializedSize(newAggs) : 0);
     }

     public int getNumReducePhases() {
@@ -274,11 +291,7 @@ private class PendingMerges implements Releasable {

         @Override
         public synchronized void close() {
-            if (hasFailure()) {
-                assert circuitBreakerBytes == 0;
-            } else {
-                assert circuitBreakerBytes >= 0;
-            }
+            assert assertFailureAndBreakerConsistent();

             releaseBuffer();
             circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
@@ -290,8 +303,14 @@ public synchronized void close() {
             }
         }

-        synchronized Exception getFailure() {
-            return failure.get();
+        private boolean assertFailureAndBreakerConsistent() {
+            boolean hasFailure = failure.get() != null;
+            if (hasFailure) {
+                assert circuitBreakerBytes == 0;
+            } else {
+                assert circuitBreakerBytes >= 0;
+            }
+            return true;
         }

         boolean hasFailure() {
@@ -342,56 +361,71 @@ static long estimateRamBytesUsedForReduce(long size) {
         }

         public void consume(QuerySearchResult result, Runnable next) {
-            boolean executeNextImmediately = true;
-            synchronized (this) {
-                if (hasFailure() || result.isNull()) {
-                    result.consumeAll();
-                    if (result.isNull()) {
-                        SearchShardTarget target = result.getSearchShardTarget();
-                        emptyResults.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
-                    }
-                } else {
-                    if (hasAggs) {
-                        long aggsSize = ramBytesUsedQueryResult(result);
-                        try {
-                            addEstimateAndMaybeBreak(aggsSize);
-                        } catch (Exception exc) {
-                            result.releaseAggs();
-                            releaseBuffer();
-                            onMergeFailure(exc);
-                            next.run();
-                            return;
+            if (hasFailure()) {
+                result.consumeAll();
+                next.run();
+            } else if (result.isNull()) {
+                result.consumeAll();
+                SearchShardTarget target = result.getSearchShardTarget();
+                SearchShard searchShard = new SearchShard(target.getClusterAlias(), target.getShardId());
+                synchronized (this) {
+                    emptyResults.add(searchShard);
+                }
+                next.run();
+            } else {
+                final long aggsSize = ramBytesUsedQueryResult(result);
+                boolean executeNextImmediately = true;
+                boolean hasFailure = false;
+                synchronized (this) {
+                    if (hasFailure()) {
+                        hasFailure = true;
+                    } else {
+                        if (hasAggs) {
+                            try {
+                                addEstimateAndMaybeBreak(aggsSize);
+                            } catch (Exception exc) {
+                                releaseBuffer();
+                                onMergeFailure(exc);
+                                hasFailure = true;
+                            }
+                        }
+                        if (hasFailure == false) {
+                            aggsCurrentBufferSize += aggsSize;
+                            // add one if a partial merge is pending
+                            int size = buffer.size() + (hasPartialReduce ? 1 : 0);
+                            if (size >= batchReduceSize) {
+                                hasPartialReduce = true;
+                                executeNextImmediately = false;
+                                QuerySearchResult[] clone = buffer.toArray(QuerySearchResult[]::new);
+                                MergeTask task = new MergeTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), next);
+                                aggsCurrentBufferSize = 0;
+                                buffer.clear();
+                                emptyResults.clear();
+                                queue.add(task);
+                                tryExecuteNext();
+                            }
+                            buffer.add(result);
                         }
-                        aggsCurrentBufferSize += aggsSize;
-                    }
-                    // add one if a partial merge is pending
-                    int size = buffer.size() + (hasPartialReduce ? 1 : 0);
-                    if (size >= batchReduceSize) {
-                        hasPartialReduce = true;
-                        executeNextImmediately = false;
-                        QuerySearchResult[] clone = buffer.toArray(QuerySearchResult[]::new);
-                        MergeTask task = new MergeTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), next);
-                        aggsCurrentBufferSize = 0;
-                        buffer.clear();
-                        emptyResults.clear();
-                        queue.add(task);
-                        tryExecuteNext();
                     }
-                    buffer.add(result);
                 }
-            }
-            if (executeNextImmediately) {
-                next.run();
+                if (hasFailure) {
+                    result.consumeAll();
+                }
+                if (executeNextImmediately) {
+                    next.run();
+                }
             }
         }

         private void releaseBuffer() {
-            buffer.forEach(QuerySearchResult::releaseAggs);
+            for (QuerySearchResult querySearchResult : buffer) {
+                querySearchResult.releaseAggs();
+            }
             buffer.clear();
         }

         private synchronized void onMergeFailure(Exception exc) {
-            if (hasFailure()) {
+            if (failure.compareAndSet(null, exc) == false) {
                 assert circuitBreakerBytes == 0;
                 return;
             }
@@ -401,79 +435,89 @@ private synchronized void onMergeFailure(Exception exc) {
                 circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
                 circuitBreakerBytes = 0;
             }
-            failure.compareAndSet(null, exc);
-            final List<Releasable> toCancels = new ArrayList<>();
-            toCancels.add(() -> onPartialMergeFailure.accept(exc));
+            onPartialMergeFailure.accept(exc);
             final MergeTask task = runningTask.getAndSet(null);
             if (task != null) {
-                toCancels.add(task::cancel);
+                task.cancel();
             }
             MergeTask mergeTask;
             while ((mergeTask = queue.pollFirst()) != null) {
-                toCancels.add(mergeTask::cancel);
+                mergeTask.cancel();
             }
             mergeResult = null;
-            Releasables.close(toCancels);
-        }
-
-        private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedSize) {
-            synchronized (this) {
-                if (hasFailure()) {
-                    return;
-                }
-                runningTask.compareAndSet(task, null);
-                mergeResult = newResult;
-                if (hasAggs) {
-                    // Update the circuit breaker to remove the size of the source aggregations
-                    // and replace the estimation with the serialized size of the newly reduced result.
-                    long newSize = mergeResult.estimatedSize - estimatedSize;
-                    addWithoutBreaking(newSize);
-                    logger.trace(
-                        "aggs partial reduction [{}->{}] max [{}]",
-                        estimatedSize,
-                        mergeResult.estimatedSize,
-                        maxAggsCurrentBufferSize
-                    );
-                }
-                task.consumeListener();
-            }
         }

         private void tryExecuteNext() {
             final MergeTask task;
             synchronized (this) {
-                if (queue.isEmpty() || hasFailure() || runningTask.get() != null) {
+                if (hasFailure() || runningTask.get() != null) {
                     return;
                 }
                 task = queue.poll();
-                runningTask.compareAndSet(null, task);
+                runningTask.set(task);
+            }
+            if (task == null) {
+                return;
             }
             executor.execute(new AbstractRunnable() {
                 @Override
                 protected void doRun() {
-                    final MergeResult thisMergeResult = mergeResult;
-                    long estimatedTotalSize = (thisMergeResult != null ? thisMergeResult.estimatedSize : 0) + task.aggsBufferSize;
-                    final MergeResult newMerge;
-                    final QuerySearchResult[] toConsume = task.consumeBuffer();
-                    if (toConsume == null) {
-                        return;
-                    }
-                    try {
-                        long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize);
-                        addEstimateAndMaybeBreak(estimatedMergeSize);
-                        estimatedTotalSize += estimatedMergeSize;
-                        ++numReducePhases;
-                        newMerge = partialReduce(toConsume, task.emptyResults, topDocsStats, thisMergeResult, numReducePhases);
-                    } catch (Exception t) {
-                        for (QuerySearchResult result : toConsume) {
-                            result.releaseAggs();
+                    MergeTask mergeTask = task;
+                    QuerySearchResult[] toConsume = mergeTask.consumeBuffer();
+                    while (mergeTask != null) {
+                        final MergeResult thisMergeResult = mergeResult;
+                        long estimatedTotalSize = (thisMergeResult != null ? thisMergeResult.estimatedSize : 0) + mergeTask.aggsBufferSize;
+                        final MergeResult newMerge;
+                        try {
+                            long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize);
+                            addEstimateAndMaybeBreak(estimatedMergeSize);
+                            estimatedTotalSize += estimatedMergeSize;
+                            ++numReducePhases;
+                            newMerge = partialReduce(toConsume, mergeTask.emptyResults, topDocsStats, thisMergeResult, numReducePhases);
+                        } catch (Exception t) {
+                            QueryPhaseResultConsumer.releaseAggs(toConsume);
+                            onMergeFailure(t);
+                            return;
+                        }
+                        synchronized (QueryPhaseResultConsumer.this) {
+                            if (hasFailure()) {
+                                return;
+                            }
+                            mergeResult = newMerge;
+                            if (hasAggs) {
+                                // Update the circuit breaker to remove the size of the source aggregations
+                                // and replace the estimation with the serialized size of the newly reduced result.
+                                long newSize = mergeResult.estimatedSize - estimatedTotalSize;
+                                addWithoutBreaking(newSize);
+                                if (logger.isTraceEnabled()) {
+                                    logger.trace(
+                                        "aggs partial reduction [{}->{}] max [{}]",
+                                        estimatedTotalSize,
+                                        mergeResult.estimatedSize,
+                                        maxAggsCurrentBufferSize
+                                    );
+                                }
+                            }
+                        }
+                        Runnable r = mergeTask.consumeListener();
+                        synchronized (QueryPhaseResultConsumer.this) {
+                            while (true) {
+                                mergeTask = queue.poll();
+                                runningTask.set(mergeTask);
+                                if (mergeTask == null) {
+                                    break;
+                                }
+                                toConsume = mergeTask.consumeBuffer();
+                                if (toConsume != null) {
+                                    break;
+                                }
+                            }
+                        }
+                        if (r != null) {
+                            r.run();
                         }
-                        onMergeFailure(t);
-                        return;
                     }
-                    onAfterMerge(task, newMerge, estimatedTotalSize);
-                    tryExecuteNext();
                 }

                 @Override
@@ -483,43 +527,6 @@ public void onFailure(Exception exc) {
             });
         }

-        public synchronized TopDocsStats consumeTopDocsStats() {
-            for (QuerySearchResult result : buffer) {
-                topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
-            }
-            return topDocsStats;
-        }
-
-        public synchronized List<TopDocs> consumeTopDocs() {
-            if (hasTopDocs == false) {
-                return Collections.emptyList();
-            }
-            List<TopDocs> topDocsList = new ArrayList<>();
-            if (mergeResult != null) {
-                topDocsList.add(mergeResult.reducedTopDocs);
-            }
-            for (QuerySearchResult result : buffer) {
-                TopDocsAndMaxScore topDocs = result.consumeTopDocs();
-                setShardIndex(topDocs.topDocs, result.getShardIndex());
-                topDocsList.add(topDocs.topDocs);
-            }
-            return topDocsList;
-        }
-
-        public synchronized List<DelayableWriteable<InternalAggregations>> getAggs() {
-            if (hasAggs == false) {
-                return Collections.emptyList();
-            }
-            List<DelayableWriteable<InternalAggregations>> aggsList = new ArrayList<>();
-            if (mergeResult != null) {
-                aggsList.add(DelayableWriteable.referencing(mergeResult.reducedAggs));
-            }
-            for (QuerySearchResult result : buffer) {
-                aggsList.add(result.getAggs());
-            }
-            return aggsList;
-        }
-
         public synchronized void releaseAggs() {
             if (hasAggs) {
                 for (QuerySearchResult result : buffer) {
@@ -529,6 +536,12 @@ public synchronized void releaseAggs() {
         }
     }

+    private static void releaseAggs(QuerySearchResult... toConsume) {
+        for (QuerySearchResult result : toConsume) {
+            result.releaseAggs();
+        }
+    }
+
     private record MergeResult(
         List<SearchShard> processedShards,
         TopDocs reducedTopDocs,
@@ -555,21 +568,21 @@ public synchronized QuerySearchResult[] consumeBuffer() {
             return toRet;
         }

-        public void consumeListener() {
-            if (next != null) {
-                next.run();
-                next = null;
-            }
+        public synchronized Runnable consumeListener() {
+            Runnable n = next;
+            next = null;
+            return n;
         }

-        public synchronized void cancel() {
+        public void cancel() {
             QuerySearchResult[] buffer = consumeBuffer();
             if (buffer != null) {
-                for (QuerySearchResult result : buffer) {
-                    result.releaseAggs();
-                }
+                releaseAggs(buffer);
+            }
+            Runnable next = consumeListener();
+            if (next != null) {
+                next.run();
             }
-            consumeListener();
         }
     }
 }
diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
index a6acb3ee2a52e..47dde104a28b8 100644
--- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
+++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
@@ -29,6 +29,7 @@
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.lucene.grouping.TopFieldGroups;
 import org.elasticsearch.search.DocValueFormat;
@@ -188,7 +189,7 @@ public static List<DfsKnnResults> mergeKnnResults(SearchRequest request, List<DfsSearchResult> dfsSearchResults) {
-        final Collection<TopDocs> topDocs,
+        final List<TopDocs> topDocs,
         int from,
         int size,
         List<CompletionSuggestion> reducedCompletionSuggestions
@@ -231,22 +232,22 @@ static SortedTopDocs sortDocs(
         return new SortedTopDocs(scoreDocs, isSortedByField, sortFields, groupField, groupValues, numSuggestDocs);
     }

-    static TopDocs mergeTopDocs(Collection<TopDocs> results, int topN, int from) {
+    static TopDocs mergeTopDocs(List<TopDocs> results, int topN, int from) {
         if (results.isEmpty()) {
             return null;
         }
-        final TopDocs topDocs = results.stream().findFirst().get();
+        final TopDocs topDocs = results.getFirst();
         final TopDocs mergedTopDocs;
         final int numShards = results.size();
         if (numShards == 1 && from == 0) { // only one shard and no pagination we can just return the topDocs as we got them.
             return topDocs;
         } else if (topDocs instanceof TopFieldGroups firstTopDocs) {
             final Sort sort = new Sort(firstTopDocs.fields);
-            final TopFieldGroups[] shardTopDocs = results.toArray(new TopFieldGroups[numShards]);
+            final TopFieldGroups[] shardTopDocs = results.toArray(new TopFieldGroups[0]);
             mergedTopDocs = TopFieldGroups.merge(sort, from, topN, shardTopDocs, false);
         } else if (topDocs instanceof TopFieldDocs firstTopDocs) {
             final Sort sort = checkSameSortTypes(results, firstTopDocs.fields);
-            final TopFieldDocs[] shardTopDocs = results.toArray(new TopFieldDocs[numShards]);
+            final TopFieldDocs[] shardTopDocs = results.toArray(new TopFieldDocs[0]);
             mergedTopDocs = TopDocs.merge(sort, from, topN, shardTopDocs);
         } else {
             final TopDocs[] shardTopDocs = results.toArray(new TopDocs[numShards]);
@@ -516,17 +517,7 @@ public AggregationReduceContext forFinalReduction() {
                 topDocs.add(td.topDocs);
             }
         }
-        return reducedQueryPhase(
-            queryResults,
-            Collections.emptyList(),
-            topDocs,
-            topDocsStats,
-            0,
-            true,
-            aggReduceContextBuilder,
-            null,
-            true
-        );
+        return reducedQueryPhase(queryResults, null, topDocs, topDocsStats, 0, true, aggReduceContextBuilder, null, true);
     }

     /**
@@ -540,7 +531,7 @@ static ReducedQueryPhase reducedQueryPhase(
         Collection<? extends SearchPhaseResult> queryResults,
-        List<DelayableWriteable<InternalAggregations>> bufferedAggs,
+        @Nullable List<DelayableWriteable<InternalAggregations>> bufferedAggs,
         List<TopDocs> bufferedTopDocs,
         TopDocsStats topDocsStats,
         int numReducePhases,
@@ -634,7 +625,12 @@ static ReducedQueryPhase reducedQueryPhase(
             reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions));
             reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class);
         }
-        final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, bufferedAggs);
+        final InternalAggregations aggregations = bufferedAggs == null
+            ? null
+            : InternalAggregations.topLevelReduceDelayable(
+                bufferedAggs,
+                performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()
+            );
         final SearchProfileResultsBuilder profileBuilder = profileShardResults.isEmpty()
             ? null
             : new SearchProfileResultsBuilder(profileShardResults);
@@ -673,19 +669,6 @@ static ReducedQueryPhase reducedQueryPhase(
         );
     }

-    private static InternalAggregations reduceAggs(
-        AggregationReduceContext.Builder aggReduceContextBuilder,
-        boolean performFinalReduce,
-        List<DelayableWriteable<InternalAggregations>> toReduce
-    ) {
-        return toReduce.isEmpty()
-            ? null
-            : InternalAggregations.topLevelReduceDelayable(
-                toReduce,
-                performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()
-            );
-    }
-
     /**
      * Checks that query results from all shards have consistent unsigned_long format.
      * Sort queries on a field that has long type in one index, and unsigned_long in another index
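
Note (appended illustration, not part of the patch): setShardIndex() is
applied to every per-shard TopDocs before mergeTopDocs() runs because
Lucene's merge breaks score ties by shard index (and then by doc id),
so the index has to be populated up front. A small self-contained
example against the plain Lucene core API; the doc ids and scores are
made up for illustration:

    import org.apache.lucene.search.ScoreDoc;
    import org.apache.lucene.search.TopDocs;
    import org.apache.lucene.search.TotalHits;

    public class TopDocsMergeExample {
        public static void main(String[] args) {
            // Each ScoreDoc already carries its shard index (third constructor
            // argument), mirroring what setShardIndex() ensures in the patch.
            TopDocs shard0 = new TopDocs(
                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                new ScoreDoc[] { new ScoreDoc(0, 3.0f, 0), new ScoreDoc(5, 1.0f, 0) }
            );
            TopDocs shard1 = new TopDocs(
                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                new ScoreDoc[] { new ScoreDoc(2, 3.0f, 1), new ScoreDoc(7, 0.5f, 1) }
            );
            // from = 0, topN = 3: the same shape as mergeTopDocs(topDocsList, topNSize, 0)
            // in the plain score-sorted branch, which bottoms out in TopDocs.merge.
            TopDocs merged = TopDocs.merge(0, 3, new TopDocs[] { shard0, shard1 });
            for (ScoreDoc doc : merged.scoreDocs) {
                // The two 3.0f hits tie on score; the shard index decides their order.
                System.out.println("doc=" + doc.doc + " score=" + doc.score + " shard=" + doc.shardIndex);
            }
        }
    }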