diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java
index 89411ac302b10..6c654d9235ec2 100644
--- a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java
+++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java
@@ -19,7 +19,6 @@
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.core.Releasable;
-import org.elasticsearch.core.Releasables;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.SearchShardTarget;
@@ -121,26 +120,50 @@ public void consumeResult(SearchPhaseResult result, Runnable next) {

     public SearchPhaseController.ReducedQueryPhase reduce() throws Exception {
         if (pendingMerges.hasPendingMerges()) {
             throw new AssertionError("partial reduce in-flight");
-        } else if (pendingMerges.hasFailure()) {
-            throw pendingMerges.getFailure();
+        }
+        Exception failure = pendingMerges.failure.get();
+        if (failure != null) {
+            throw failure;
         }
         // ensure consistent ordering
         pendingMerges.sortBuffer();
-        final TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats();
-        final List<TopDocs> topDocsList = pendingMerges.consumeTopDocs();
+        final TopDocsStats topDocsStats = pendingMerges.topDocsStats;
+        final int resultSize = pendingMerges.buffer.size() + (pendingMerges.mergeResult == null ? 0 : 1);
+        final List<TopDocs> topDocsList = hasTopDocs ? new ArrayList<>(resultSize) : null;
+        final List<DelayableWriteable<InternalAggregations>> aggsList = hasAggs ? new ArrayList<>(resultSize) : null;
+        synchronized (pendingMerges) {
+            if (pendingMerges.mergeResult != null) {
+                if (topDocsList != null) {
+                    topDocsList.add(pendingMerges.mergeResult.reducedTopDocs);
+                }
+                if (aggsList != null) {
+                    aggsList.add(DelayableWriteable.referencing(pendingMerges.mergeResult.reducedAggs));
+                }
+            }
+            for (QuerySearchResult result : pendingMerges.buffer) {
+                topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
+                if (topDocsList != null) {
+                    TopDocsAndMaxScore topDocs = result.consumeTopDocs();
+                    setShardIndex(topDocs.topDocs, result.getShardIndex());
+                    topDocsList.add(topDocs.topDocs);
+                }
+                if (aggsList != null) {
+                    aggsList.add(result.getAggs());
+                }
+            }
+        }
         SearchPhaseController.ReducedQueryPhase reducePhase;
         long breakerSize = pendingMerges.circuitBreakerBytes;
         try {
-            final List<DelayableWriteable<InternalAggregations>> aggsList = pendingMerges.getAggs();
-            if (hasAggs) {
+            if (aggsList != null) {
                 // Add an estimate of the final reduce size
                 breakerSize = pendingMerges.addEstimateAndMaybeBreak(PendingMerges.estimateRamBytesUsedForReduce(breakerSize));
             }
             reducePhase = SearchPhaseController.reducedQueryPhase(
                 results.asList(),
                 aggsList,
-                topDocsList,
+                topDocsList == null ? Collections.emptyList() : topDocsList,
                 topDocsStats,
                 pendingMerges.numReducePhases,
                 false,
@@ -183,65 +206,59 @@ private MergeResult partialReduce(
         // ensure consistent ordering
         Arrays.sort(toConsume, RESULT_COMPARATOR);

-        for (QuerySearchResult result : toConsume) {
-            topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
-        }
-
+        final List<SearchShard> processedShards = new ArrayList<>(emptyResults);
         final TopDocs newTopDocs;
+        final InternalAggregations newAggs;
+        final List<DelayableWriteable<InternalAggregations>> aggsList;
+        final int resultSetSize = toConsume.length + (lastMerge != null ? 1 : 0);
+        if (hasAggs) {
+            aggsList = new ArrayList<>(resultSetSize);
+            if (lastMerge != null) {
+                aggsList.add(DelayableWriteable.referencing(lastMerge.reducedAggs));
+            }
+        } else {
+            aggsList = null;
+        }
+        List<TopDocs> topDocsList;
         if (hasTopDocs) {
-            List<TopDocs> topDocsList = new ArrayList<>();
+            topDocsList = new ArrayList<>(resultSetSize);
             if (lastMerge != null) {
                 topDocsList.add(lastMerge.reducedTopDocs);
             }
-            for (QuerySearchResult result : toConsume) {
-                TopDocsAndMaxScore topDocs = result.consumeTopDocs();
-                setShardIndex(topDocs.topDocs, result.getShardIndex());
-                topDocsList.add(topDocs.topDocs);
-            }
-            newTopDocs = mergeTopDocs(
-                topDocsList,
-                // we have to merge here in the same way we collect on a shard
-                topNSize,
-                0
-            );
         } else {
-            newTopDocs = null;
+            topDocsList = null;
         }
-
-        final InternalAggregations newAggs;
-        if (hasAggs) {
-            try {
-                final List<DelayableWriteable<InternalAggregations>> aggsList = new ArrayList<>();
-                if (lastMerge != null) {
-                    aggsList.add(DelayableWriteable.referencing(lastMerge.reducedAggs));
-                }
-                for (QuerySearchResult result : toConsume) {
+        try {
+            for (QuerySearchResult result : toConsume) {
+                topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
+                SearchShardTarget target = result.getSearchShardTarget();
+                processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
+                if (aggsList != null) {
                     aggsList.add(result.getAggs());
                 }
-                newAggs = InternalAggregations.topLevelReduceDelayable(aggsList, aggReduceContextBuilder.forPartialReduction());
-            } finally {
-                for (QuerySearchResult result : toConsume) {
-                    result.releaseAggs();
+                if (topDocsList != null) {
+                    TopDocsAndMaxScore topDocs = result.consumeTopDocs();
+                    setShardIndex(topDocs.topDocs, result.getShardIndex());
+                    topDocsList.add(topDocs.topDocs);
                 }
             }
-        } else {
-            newAggs = null;
+            // we have to merge here in the same way we collect on a shard
+            newTopDocs = topDocsList == null ? null : mergeTopDocs(topDocsList, topNSize, 0);
+            newAggs = aggsList == null
+                ? null
+                : InternalAggregations.topLevelReduceDelayable(aggsList, aggReduceContextBuilder.forPartialReduction());
+        } finally {
+            releaseAggs(toConsume);
         }
-        List<SearchShard> processedShards = new ArrayList<>(emptyResults);
         if (lastMerge != null) {
             processedShards.addAll(lastMerge.processedShards);
         }
-        for (QuerySearchResult result : toConsume) {
-            SearchShardTarget target = result.getSearchShardTarget();
-            processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
-        }
         if (progressListener != SearchProgressListener.NOOP) {
             progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
         }
         // we leave the results un-serialized because serializing is slow but we compute the serialized
         // size as an estimate of the memory used by the newly reduced aggregations.
-        long serializedSize = hasAggs ? DelayableWriteable.getSerializedSize(newAggs) : 0;
-        return new MergeResult(processedShards, newTopDocs, newAggs, hasAggs ? serializedSize : 0);
+        return new MergeResult(processedShards, newTopDocs, newAggs, newAggs != null ? DelayableWriteable.getSerializedSize(newAggs) : 0);
     }

     public int getNumReducePhases() {
@@ -274,11 +291,7 @@ private class PendingMerges implements Releasable {

         @Override
         public synchronized void close() {
-            if (hasFailure()) {
-                assert circuitBreakerBytes == 0;
-            } else {
-                assert circuitBreakerBytes >= 0;
-            }
+            assert assertFailureAndBreakerConsistent();
             releaseBuffer();
             circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
@@ -290,8 +303,14 @@ public synchronized void close() {
             }
         }

-        synchronized Exception getFailure() {
-            return failure.get();
+        private boolean assertFailureAndBreakerConsistent() {
+            boolean hasFailure = failure.get() != null;
+            if (hasFailure) {
+                assert circuitBreakerBytes == 0;
+            } else {
+                assert circuitBreakerBytes >= 0;
+            }
+            return true;
         }

         boolean hasFailure() {
@@ -342,56 +361,71 @@ static long estimateRamBytesUsedForReduce(long size) {
         }

         public void consume(QuerySearchResult result, Runnable next) {
-            boolean executeNextImmediately = true;
-            synchronized (this) {
-                if (hasFailure() || result.isNull()) {
-                    result.consumeAll();
-                    if (result.isNull()) {
-                        SearchShardTarget target = result.getSearchShardTarget();
-                        emptyResults.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
-                    }
-                } else {
-                    if (hasAggs) {
-                        long aggsSize = ramBytesUsedQueryResult(result);
-                        try {
-                            addEstimateAndMaybeBreak(aggsSize);
-                        } catch (Exception exc) {
-                            result.releaseAggs();
-                            releaseBuffer();
-                            onMergeFailure(exc);
-                            next.run();
-                            return;
+            if (hasFailure()) {
+                result.consumeAll();
+                next.run();
+            } else if (result.isNull()) {
+                result.consumeAll();
+                SearchShardTarget target = result.getSearchShardTarget();
+                SearchShard searchShard = new SearchShard(target.getClusterAlias(), target.getShardId());
+                synchronized (this) {
+                    emptyResults.add(searchShard);
+                }
+                next.run();
+            } else {
+                final long aggsSize = ramBytesUsedQueryResult(result);
+                boolean executeNextImmediately = true;
+                boolean hasFailure = false;
+                synchronized (this) {
+                    if (hasFailure()) {
+                        hasFailure = true;
+                    } else {
+                        if (hasAggs) {
+                            try {
+                                addEstimateAndMaybeBreak(aggsSize);
+                            } catch (Exception exc) {
+                                releaseBuffer();
+                                onMergeFailure(exc);
+                                hasFailure = true;
+                            }
+                        }
+                        if (hasFailure == false) {
+                            aggsCurrentBufferSize += aggsSize;
+                            // add one if a partial merge is pending
+                            int size = buffer.size() + (hasPartialReduce ? 1 : 0);
+                            if (size >= batchReduceSize) {
+                                hasPartialReduce = true;
+                                executeNextImmediately = false;
+                                QuerySearchResult[] clone = buffer.toArray(QuerySearchResult[]::new);
+                                MergeTask task = new MergeTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), next);
+                                aggsCurrentBufferSize = 0;
+                                buffer.clear();
+                                emptyResults.clear();
+                                queue.add(task);
+                                tryExecuteNext();
+                            }
+                            buffer.add(result);
+                        }
                         }
-                        aggsCurrentBufferSize += aggsSize;
-                    }
-                    // add one if a partial merge is pending
-                    int size = buffer.size() + (hasPartialReduce ? 1 : 0);
-                    if (size >= batchReduceSize) {
-                        hasPartialReduce = true;
-                        executeNextImmediately = false;
-                        QuerySearchResult[] clone = buffer.toArray(QuerySearchResult[]::new);
-                        MergeTask task = new MergeTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), next);
-                        aggsCurrentBufferSize = 0;
-                        buffer.clear();
-                        emptyResults.clear();
-                        queue.add(task);
-                        tryExecuteNext();
                     }
-                    buffer.add(result);
                 }
-            }
-            if (executeNextImmediately) {
-                next.run();
+                if (hasFailure) {
+                    result.consumeAll();
+                }
+                if (executeNextImmediately) {
+                    next.run();
+                }
             }
         }

         private void releaseBuffer() {
-            buffer.forEach(QuerySearchResult::releaseAggs);
+            for (QuerySearchResult querySearchResult : buffer) {
+                querySearchResult.releaseAggs();
+            }
             buffer.clear();
         }

         private synchronized void onMergeFailure(Exception exc) {
-            if (hasFailure()) {
+            if (failure.compareAndSet(null, exc) == false) {
                 assert circuitBreakerBytes == 0;
                 return;
             }
@@ -401,79 +435,89 @@ private synchronized void onMergeFailure(Exception exc) {
                 circuitBreaker.addWithoutBreaking(-circuitBreakerBytes);
                 circuitBreakerBytes = 0;
             }
-            failure.compareAndSet(null, exc);
-            final List<Releasable> toCancels = new ArrayList<>();
-            toCancels.add(() -> onPartialMergeFailure.accept(exc));
+            onPartialMergeFailure.accept(exc);
             final MergeTask task = runningTask.getAndSet(null);
             if (task != null) {
-                toCancels.add(task::cancel);
+                task.cancel();
             }
             MergeTask mergeTask;
             while ((mergeTask = queue.pollFirst()) != null) {
-                toCancels.add(mergeTask::cancel);
+                mergeTask.cancel();
             }
             mergeResult = null;
-            Releasables.close(toCancels);
-        }
-
-        private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedSize) {
-            synchronized (this) {
-                if (hasFailure()) {
-                    return;
-                }
-                runningTask.compareAndSet(task, null);
-                mergeResult = newResult;
-                if (hasAggs) {
-                    // Update the circuit breaker to remove the size of the source aggregations
-                    // and replace the estimation with the serialized size of the newly reduced result.
-                    long newSize = mergeResult.estimatedSize - estimatedSize;
-                    addWithoutBreaking(newSize);
-                    logger.trace(
-                        "aggs partial reduction [{}->{}] max [{}]",
-                        estimatedSize,
-                        mergeResult.estimatedSize,
-                        maxAggsCurrentBufferSize
-                    );
-                }
-                task.consumeListener();
-            }
         }

         private void tryExecuteNext() {
             final MergeTask task;
             synchronized (this) {
-                if (queue.isEmpty() || hasFailure() || runningTask.get() != null) {
+                if (hasFailure() || runningTask.get() != null) {
                     return;
                 }
                 task = queue.poll();
-                runningTask.compareAndSet(null, task);
+                runningTask.set(task);
+            }
+            if (task == null) {
+                return;
             }
             executor.execute(new AbstractRunnable() {
                 @Override
                 protected void doRun() {
-                    final MergeResult thisMergeResult = mergeResult;
-                    long estimatedTotalSize = (thisMergeResult != null ? thisMergeResult.estimatedSize : 0) + task.aggsBufferSize;
-                    final MergeResult newMerge;
-                    final QuerySearchResult[] toConsume = task.consumeBuffer();
-                    if (toConsume == null) {
-                        return;
-                    }
-                    try {
-                        long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize);
-                        addEstimateAndMaybeBreak(estimatedMergeSize);
-                        estimatedTotalSize += estimatedMergeSize;
-                        ++numReducePhases;
-                        newMerge = partialReduce(toConsume, task.emptyResults, topDocsStats, thisMergeResult, numReducePhases);
-                    } catch (Exception t) {
-                        for (QuerySearchResult result : toConsume) {
-                            result.releaseAggs();
+                    MergeTask mergeTask = task;
+                    QuerySearchResult[] toConsume = mergeTask.consumeBuffer();
+                    while (mergeTask != null) {
+                        final MergeResult thisMergeResult = mergeResult;
+                        long estimatedTotalSize = (thisMergeResult != null ? thisMergeResult.estimatedSize : 0) + mergeTask.aggsBufferSize;
+                        final MergeResult newMerge;
+                        try {
+                            long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize);
+                            addEstimateAndMaybeBreak(estimatedMergeSize);
+                            estimatedTotalSize += estimatedMergeSize;
+                            ++numReducePhases;
+                            newMerge = partialReduce(toConsume, mergeTask.emptyResults, topDocsStats, thisMergeResult, numReducePhases);
+                        } catch (Exception t) {
+                            QueryPhaseResultConsumer.releaseAggs(toConsume);
+                            onMergeFailure(t);
+                            return;
+                        }
+                        synchronized (QueryPhaseResultConsumer.this) {
+                            if (hasFailure()) {
+                                return;
+                            }
+                            mergeResult = newMerge;
+                            if (hasAggs) {
+                                // Update the circuit breaker to remove the size of the source aggregations
+                                // and replace the estimation with the serialized size of the newly reduced result.
+                                long newSize = mergeResult.estimatedSize - estimatedTotalSize;
+                                addWithoutBreaking(newSize);
+                                if (logger.isTraceEnabled()) {
+                                    logger.trace(
+                                        "aggs partial reduction [{}->{}] max [{}]",
+                                        estimatedTotalSize,
+                                        mergeResult.estimatedSize,
+                                        maxAggsCurrentBufferSize
+                                    );
+                                }
+                            }
+                        }
+                        Runnable r = mergeTask.consumeListener();
+                        synchronized (QueryPhaseResultConsumer.this) {
+                            while (true) {
+                                mergeTask = queue.poll();
+                                runningTask.set(mergeTask);
+                                if (mergeTask == null) {
+                                    break;
+                                }
+                                toConsume = mergeTask.consumeBuffer();
+                                if (toConsume != null) {
+                                    break;
+                                }
+                            }
+                        }
+                        if (r != null) {
+                            r.run();
                         }
-                        onMergeFailure(t);
-                        return;
                     }
-                    onAfterMerge(task, newMerge, estimatedTotalSize);
-                    tryExecuteNext();
                 }

                 @Override
@@ -483,43 +527,6 @@ public void onFailure(Exception exc) {
             });
         }

-        public synchronized TopDocsStats consumeTopDocsStats() {
-            for (QuerySearchResult result : buffer) {
-                topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly());
-            }
-            return topDocsStats;
-        }
-
-        public synchronized List<TopDocs> consumeTopDocs() {
-            if (hasTopDocs == false) {
-                return Collections.emptyList();
-            }
-            List<TopDocs> topDocsList = new ArrayList<>();
-            if (mergeResult != null) {
-                topDocsList.add(mergeResult.reducedTopDocs);
-            }
-            for (QuerySearchResult result : buffer) {
-                TopDocsAndMaxScore topDocs = result.consumeTopDocs();
-                setShardIndex(topDocs.topDocs, result.getShardIndex());
-                topDocsList.add(topDocs.topDocs);
-            }
-            return topDocsList;
-        }
-
-        public synchronized List<DelayableWriteable<InternalAggregations>> getAggs() {
-            if (hasAggs == false) {
-                return Collections.emptyList();
-            }
-            List<DelayableWriteable<InternalAggregations>> aggsList = new ArrayList<>();
-            if (mergeResult != null) {
-                aggsList.add(DelayableWriteable.referencing(mergeResult.reducedAggs));
-            }
-            for (QuerySearchResult result : buffer) {
-                aggsList.add(result.getAggs());
-            }
-            return aggsList;
-        }
-
         public synchronized void releaseAggs() {
             if (hasAggs) {
                 for (QuerySearchResult result : buffer) {
@@ -529,6 +536,12 @@ public synchronized void releaseAggs() {
         }
     }

+    private static void releaseAggs(QuerySearchResult... toConsume) {
+        for (QuerySearchResult result : toConsume) {
+            result.releaseAggs();
+        }
+    }
+
     private record MergeResult(
         List<SearchShard> processedShards,
         TopDocs reducedTopDocs,
@@ -555,21 +568,21 @@ public synchronized QuerySearchResult[] consumeBuffer() {
             return toRet;
         }

-        public void consumeListener() {
-            if (next != null) {
-                next.run();
-                next = null;
-            }
+        public synchronized Runnable consumeListener() {
+            Runnable n = next;
+            next = null;
+            return n;
         }

-        public synchronized void cancel() {
+        public void cancel() {
             QuerySearchResult[] buffer = consumeBuffer();
             if (buffer != null) {
-                for (QuerySearchResult result : buffer) {
-                    result.releaseAggs();
-                }
+                releaseAggs(buffer);
+            }
+            Runnable next = consumeListener();
+            if (next != null) {
+                next.run();
             }
-            consumeListener();
         }
     }
 }
diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
index a6acb3ee2a52e..47dde104a28b8 100644
--- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
+++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java
@@ -29,6 +29,7 @@
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.lucene.grouping.TopFieldGroups;
 import org.elasticsearch.search.DocValueFormat;
@@ -188,7 +189,7 @@ public static List<TopDocs> mergeKnnResults(SearchRequest request, List<DfsSear
     static SortedTopDocs sortDocs(
         boolean ignoreFrom,
-        final Collection<TopDocs> topDocs,
+        final List<TopDocs> topDocs,
         int from,
         int size,
         List<CompletionSuggestion> reducedCompletionSuggestions
@@ -231,22 +232,22 @@ static SortedTopDocs sortDocs(
         return new SortedTopDocs(scoreDocs, isSortedByField, sortFields, groupField, groupValues, numSuggestDocs);
     }

-    static TopDocs mergeTopDocs(Collection<TopDocs> results, int topN, int from) {
+    static TopDocs mergeTopDocs(List<TopDocs> results, int topN, int from) {
         if (results.isEmpty()) {
             return null;
         }
-        final TopDocs topDocs = results.stream().findFirst().get();
+        final TopDocs topDocs = results.getFirst();
         final TopDocs mergedTopDocs;
         final int numShards = results.size();
         if (numShards == 1 && from == 0) { // only one shard and no pagination we can just return the topDocs as we got them.
             return topDocs;
         } else if (topDocs instanceof TopFieldGroups firstTopDocs) {
             final Sort sort = new Sort(firstTopDocs.fields);
-            final TopFieldGroups[] shardTopDocs = results.toArray(new TopFieldGroups[numShards]);
+            final TopFieldGroups[] shardTopDocs = results.toArray(new TopFieldGroups[0]);
             mergedTopDocs = TopFieldGroups.merge(sort, from, topN, shardTopDocs, false);
         } else if (topDocs instanceof TopFieldDocs firstTopDocs) {
             final Sort sort = checkSameSortTypes(results, firstTopDocs.fields);
-            final TopFieldDocs[] shardTopDocs = results.toArray(new TopFieldDocs[numShards]);
+            final TopFieldDocs[] shardTopDocs = results.toArray(new TopFieldDocs[0]);
             mergedTopDocs = TopDocs.merge(sort, from, topN, shardTopDocs);
         } else {
             final TopDocs[] shardTopDocs = results.toArray(new TopDocs[numShards]);
@@ -516,17 +517,7 @@ public AggregationReduceContext forFinalReduction() {
                 topDocs.add(td.topDocs);
             }
         }
-        return reducedQueryPhase(
-            queryResults,
-            Collections.emptyList(),
-            topDocs,
-            topDocsStats,
-            0,
-            true,
-            aggReduceContextBuilder,
-            null,
-            true
-        );
+        return reducedQueryPhase(queryResults, null, topDocs, topDocsStats, 0, true, aggReduceContextBuilder, null, true);
     }

     /**
@@ -540,7 +531,7 @@ public AggregationReduceContext forFinalReduction() {
      */
     static ReducedQueryPhase reducedQueryPhase(
         Collection<? extends SearchPhaseResult> queryResults,
-        List<DelayableWriteable<InternalAggregations>> bufferedAggs,
+        @Nullable List<DelayableWriteable<InternalAggregations>> bufferedAggs,
         List<TopDocs> bufferedTopDocs,
         TopDocsStats topDocsStats,
         int numReducePhases,
@@ -634,7 +625,12 @@ static ReducedQueryPhase reducedQueryPhase(
             reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions));
             reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class);
         }
-        final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, bufferedAggs);
+        final InternalAggregations aggregations = bufferedAggs == null
+            ? null
+            : InternalAggregations.topLevelReduceDelayable(
+                bufferedAggs,
+                performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()
+            );
         final SearchProfileResultsBuilder profileBuilder = profileShardResults.isEmpty()
             ? null
             : new SearchProfileResultsBuilder(profileShardResults);
@@ -673,19 +669,6 @@ static ReducedQueryPhase reducedQueryPhase(
         );
     }

-    private static InternalAggregations reduceAggs(
-        AggregationReduceContext.Builder aggReduceContextBuilder,
-        boolean performFinalReduce,
-        List<DelayableWriteable<InternalAggregations>> toReduce
-    ) {
-        return toReduce.isEmpty()
-            ? null
-            : InternalAggregations.topLevelReduceDelayable(
-                toReduce,
-                performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction()
-            );
-    }
-
     /**
      * Checks that query results from all shards have consistent unsigned_long format.
      * Sort queries on a field that has long type in one index, and unsigned_long in another index
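
Note on the scheduling change in tryExecuteNext()/doRun() above: the removed onAfterMerge()/tryExecuteNext() pair re-submitted a fresh AbstractRunnable to the executor after every partial merge, while the new doRun() keeps a single runner alive that drains the merge queue in a loop and only gives up the runningTask slot when no runnable task is left. The following standalone sketch (all names invented for illustration; this is not code from the PR, and the error-handling of the real onMergeFailure path is omitted) shows that drain-loop pattern in isolation:

import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

public class DrainLoopSketch {
    private final Queue<Runnable> queue = new ArrayDeque<>();
    // the single "runner slot"; non-null means a worker is already draining the queue
    private final AtomicReference<Runnable> running = new AtomicReference<>();
    private final ExecutorService executor = Executors.newFixedThreadPool(2);

    public void submit(Runnable task) {
        synchronized (this) {
            queue.add(task);
        }
        tryExecuteNext();
    }

    private void tryExecuteNext() {
        final Runnable task;
        synchronized (this) {
            if (running.get() != null) {
                return; // a runner is already draining the queue
            }
            task = queue.poll();
            running.set(task);
        }
        if (task == null) {
            return;
        }
        executor.execute(() -> {
            Runnable current = task;
            while (current != null) {
                current.run();
                synchronized (this) {
                    // grab the next task while still owning the runner slot; storing
                    // null releases the slot so a later submit() can start a new runner
                    current = queue.poll();
                    running.set(current);
                }
            }
        });
    }

    public static void main(String[] args) throws InterruptedException {
        DrainLoopSketch consumer = new DrainLoopSketch();
        for (int i = 0; i < 5; i++) {
            final int n = i;
            consumer.submit(() -> System.out.println("partial merge " + n));
        }
        consumer.executor.shutdown();
        consumer.executor.awaitTermination(1, TimeUnit.SECONDS);
    }
}

Because the next task is polled while the runner still holds the lock guarding the slot, a concurrent submit() either observes the slot occupied (and merely enqueues) or observes it empty after the queue was drained (and starts a new runner), so no task is lost and at most one runner is ever active.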
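
The breaker arithmetic around partialReduce() follows a reserve-then-settle pattern: addEstimateAndMaybeBreak() reserves a pessimistic estimate before the reduce and may trip, and addWithoutBreaking() afterwards swaps the estimate for the serialized size of the reduced aggregations (the mergeResult.estimatedSize - estimatedTotalSize delta above). A minimal sketch of that bookkeeping, with an invented SimpleBreaker standing in for Elasticsearch's CircuitBreaker:

import java.util.concurrent.atomic.AtomicLong;

public class BreakerSketch {
    static final class SimpleBreaker {
        private final long limit;
        private final AtomicLong used = new AtomicLong();

        SimpleBreaker(long limit) {
            this.limit = limit;
        }

        // reserve bytes, rolling back and throwing if the limit would be exceeded
        void addEstimateAndMaybeBreak(long bytes) {
            long after = used.addAndGet(bytes);
            if (after > limit) {
                used.addAndGet(-bytes);
                throw new IllegalStateException("circuit breaker tripped: " + after + " > " + limit);
            }
        }

        // adjust accounting without enforcing the limit (delta may be negative)
        void addWithoutBreaking(long bytes) {
            used.addAndGet(bytes);
        }

        long used() {
            return used.get();
        }
    }

    public static void main(String[] args) {
        SimpleBreaker breaker = new SimpleBreaker(1024);
        long estimate = 512;                           // pessimistic pre-reduce estimate
        breaker.addEstimateAndMaybeBreak(estimate);
        long actual = 300;                             // serialized size after the reduce
        breaker.addWithoutBreaking(actual - estimate); // replace the estimate with reality
        System.out.println("bytes accounted: " + breaker.used()); // prints 300
    }
}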