Skip to content

Commit

Permalink
Expand the lifecycle of the AggregationContext (#94023)
Browse files Browse the repository at this point in the history
Relates to #89437

This PR enables ref counting on QuerySearchResult and moves responsibility for releasing the BigArrays used during aggregation collection to QuerySearchResult.  This means the collection-time circuit-breaker accounting will not be cleaned up until we have serialized the aggregations.

The AggregationContext currently manages all of the memory used by aggregations. Rather than change that, this PR extends the life cycle of the AggregationContext so it isn't closed until the QuerySearchResult is closed, at which point we have serialized the aggregation information back to the coordinating node.
---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
  • Loading branch information
not-napoleon and elasticmachine authored Apr 20, 2023
1 parent dd968f5 commit cb04885
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ private static Map<String, Object> prepareMap(SearchContext context, long tookIn
messageFields.put("elasticsearch.slowlog.message", context.indexShard().shardId());
messageFields.put("elasticsearch.slowlog.took", TimeValue.timeValueNanos(tookInNanos).toString());
messageFields.put("elasticsearch.slowlog.took_millis", TimeUnit.NANOSECONDS.toMillis(tookInNanos));
if (context.queryResult().getTotalHits() != null) {
messageFields.put("elasticsearch.slowlog.total_hits", context.queryResult().getTotalHits());
if (context.getTotalHits() != null) {
messageFields.put("elasticsearch.slowlog.total_hits", context.getTotalHits());
} else {
messageFields.put("elasticsearch.slowlog.total_hits", "-1");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.IndexSettings;
Expand Down Expand Up @@ -178,6 +179,7 @@ public void addFetchResult() {
@Override
public void addQueryResult() {
// Create the query result for this shard-level request.
this.queryResult = new QuerySearchResult(this.readerContext.id(), this.shardTarget, this.request);
// Tie the result's lifetime to this context: when the context's releasables run,
// drop the reference this context holds on the (ref-counted) QuerySearchResult.
// Other holders (e.g. the transport layer) may keep it alive longer via their own refs.
addReleasable(queryResult::decRef);
}

@Override
Expand Down Expand Up @@ -713,6 +715,10 @@ public QuerySearchResult queryResult() {
return queryResult;
}

/**
 * Registers a {@link Releasable} whose lifetime is tied to the {@link QuerySearchResult}
 * rather than to this search context. Per the intent of this change, such releasables are
 * only released once the query result itself is closed — i.e. after the result has been
 * serialized back to the coordinating node — which is how the {@code AggregationContext}
 * outlives the search context.
 */
public void addQuerySearchResultReleasable(Releasable releasable) {
queryResult.addReleasable(releasable);
}

@Override
public TotalHits getTotalHits() {
if (queryResult != null) {
Expand Down
72 changes: 59 additions & 13 deletions server/src/main/java/org/elasticsearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,8 @@ private DfsSearchResult executeDfsPhase(ShardSearchRequest request, SearchShardT
try (
Releasable scope = tracer.withScope(task);
Releasable ignored = readerContext.markAsUsed(getKeepAlive(request));
SearchContext context = createContext(readerContext, request, task, true)
SearchContext context = createContext(readerContext, request, task, ResultsType.DFS, false)
) {
context.addDfsResult();
dfsPhase.execute(context);
return context.dfsResult();
} catch (Exception e) {
Expand Down Expand Up @@ -623,15 +622,19 @@ private static <T> void runAsync(Executor executor, CheckedSupplier<T, Exception
executor.execute(ActionRunnable.supply(listener, executable::get));
}

/**
* The returned {@link SearchPhaseResult} will have had its ref count incremented by this method.
* It is the responsibility of the caller to ensure that the ref count is correctly decremented
* when the object is no longer needed.
*/
private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchShardTask task) throws Exception {
final ReaderContext readerContext = createOrGetReaderContext(request);
try (
Releasable scope = tracer.withScope(task);
Releasable ignored = readerContext.markAsUsed(getKeepAlive(request));
SearchContext context = createContext(readerContext, request, task, true)
SearchContext context = createContext(readerContext, request, task, ResultsType.QUERY, true)
) {
tracer.startTrace("executeQueryPhase", Map.of());
context.addQueryResult();
final long afterQueryTime;
try (SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(context)) {
loadOrExecuteQueryPhase(request, context);
Expand All @@ -643,6 +646,7 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh
tracer.stopTrace();
}
if (request.numberOfShards() == 1) {
// we already have query results, but we can run fetch at the same time
context.addFetchResult();
return executeFetchPhase(readerContext, context, afterQueryTime);
} else {
Expand All @@ -651,6 +655,7 @@ private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchSh
final RescoreDocIds rescoreDocIds = context.rescoreDocIds();
context.queryResult().setRescoreDocIds(rescoreDocIds);
readerContext.setRescoreDocIds(rescoreDocIds);
context.queryResult().incRef();
return context.queryResult();
}
} catch (Exception e) {
Expand Down Expand Up @@ -678,6 +683,7 @@ private QueryFetchSearchResult executeFetchPhase(ReaderContext reader, SearchCon
}
executor.success();
}
// This will incRef the QuerySearchResult when it gets created
return new QueryFetchSearchResult(context.queryResult(), context.fetchResult());
}

Expand All @@ -698,15 +704,15 @@ public void executeQueryPhase(
runAsync(getExecutor(readerContext.indexShard()), () -> {
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null);
try (
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false);
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.QUERY, false);
SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)
) {
searchContext.addQueryResult();
searchContext.searcher().setAggregatedDfs(readerContext.getAggregatedDfs(null));
processScroll(request, readerContext, searchContext);
QueryPhase.execute(searchContext);
executor.success();
readerContext.setRescoreDocIds(searchContext.rescoreDocIds());
// ScrollQuerySearchResult will incRef the QuerySearchResult when it gets constructed.
return new ScrollQuerySearchResult(searchContext.queryResult(), searchContext.shardTarget());
} catch (Exception e) {
logger.trace("Query phase failed", e);
Expand All @@ -716,17 +722,21 @@ public void executeQueryPhase(
}, wrapFailureListener(listener, readerContext, markAsUsed));
}

/**
* The returned {@link SearchPhaseResult} will have had its ref count incremented by this method.
* It is the responsibility of the caller to ensure that the ref count is correctly decremented
* when the object is no longer needed.
*/
public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task, ActionListener<QuerySearchResult> listener) {
final ReaderContext readerContext = findReaderContext(request.contextId(), request.shardSearchRequest());
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.shardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
runAsync(getExecutor(readerContext.indexShard()), () -> {
readerContext.setAggregatedDfs(request.dfs());
try (
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, true);
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.QUERY, true);
SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)
) {
searchContext.addQueryResult();
searchContext.searcher().setAggregatedDfs(request.dfs());
QueryPhase.execute(searchContext);
if (searchContext.queryResult().hasSearchContext() == false && readerContext.singleSession()) {
Expand All @@ -739,6 +749,7 @@ public void executeQueryPhase(QuerySearchRequest request, SearchShardTask task,
final RescoreDocIds rescoreDocIds = searchContext.rescoreDocIds();
searchContext.queryResult().setRescoreDocIds(rescoreDocIds);
readerContext.setRescoreDocIds(rescoreDocIds);
searchContext.queryResult().incRef();
return searchContext.queryResult();
} catch (Exception e) {
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
Expand Down Expand Up @@ -779,10 +790,9 @@ public void executeFetchPhase(
runAsync(getExecutor(readerContext.indexShard()), () -> {
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(null);
try (
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false);
SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.FETCH, false);
SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext)
) {
searchContext.addFetchResult();
searchContext.assignRescoreDocIds(readerContext.getRescoreDocIds(null));
searchContext.searcher().setAggregatedDfs(readerContext.getAggregatedDfs(null));
processScroll(request, readerContext, searchContext);
Expand All @@ -805,8 +815,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
runAsync(getExecutor(readerContext.indexShard()), () -> {
try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, false)) {
searchContext.addFetchResult();
try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.FETCH, false)) {
if (request.lastEmittedDoc() != null) {
searchContext.scrollContext().lastEmittedDoc = request.lastEmittedDoc();
}
Expand Down Expand Up @@ -983,10 +992,12 @@ protected SearchContext createContext(
ReaderContext readerContext,
ShardSearchRequest request,
SearchShardTask task,
ResultsType resultsType,
boolean includeAggregations
) throws IOException {
checkCancelled(task);
final DefaultSearchContext context = createSearchContext(readerContext, request, defaultSearchTimeout);
resultsType.addResultsObject(context);
try {
if (request.scroll() != null) {
context.scrollContext().scroll = request.scroll();
Expand Down Expand Up @@ -1246,7 +1257,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
enableRewriteAggsToFilterByFilter,
source.aggregations().isInSortOrderExecutionRequired()
);
context.addReleasable(aggContext);
context.addQuerySearchResultReleasable(aggContext);
try {
AggregatorFactories factories = source.aggregations().build(aggContext, null);
context.aggregations(new SearchContextAggregations(factories));
Expand Down Expand Up @@ -1447,6 +1458,41 @@ public ResponseCollectorService getResponseCollectorService() {
return this.responseCollectorService;
}

/**
* Used to indicate which result object should be instantiated when creating a search context
*/
enum ResultsType {
DFS {
@Override
void addResultsObject(SearchContext context) {
context.addDfsResult();
}
},
QUERY {
@Override
void addResultsObject(SearchContext context) {
context.addQueryResult();
}
},
FETCH {
@Override
void addResultsObject(SearchContext context) {
context.addFetchResult();
}
},
/**
* None is intended for use in testing, when we might not progress all the way to generating results
*/
NONE {
@Override
void addResultsObject(SearchContext context) {
// this space intentionally left blank
}
};

abstract void addResultsObject(SearchContext context);
}

class Reaper implements Runnable {
@Override
public void run() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.internal.ShardSearchContextId;
Expand All @@ -21,17 +23,29 @@ public final class QueryFetchSearchResult extends SearchPhaseResult {

private final QuerySearchResult queryResult;
private final FetchSearchResult fetchResult;
private final RefCounted refCounted;

/**
 * Deserializing constructor. The sub-results are created with an initial ref count
 * of 1, so no extra incRef is needed here — this wrapper simply assumes ownership
 * of that initial reference.
 */
public QueryFetchSearchResult(StreamInput in) throws IOException {
super(in);
queryResult = new QuerySearchResult(in);
fetchResult = new FetchSearchResult(in);
// Releasing the last reference to this wrapper releases the single reference
// we hold on each sub-result.
refCounted = AbstractRefCounted.of(() -> {
queryResult.decRef();
fetchResult.decRef();
});
}

/**
 * Wraps caller-owned query and fetch results. Since we're acquiring a copy of
 * references the caller still owns, we incRef each sub-result here and give those
 * references back when the last reference to this wrapper is released.
 */
public QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) {
this.queryResult = queryResult;
this.fetchResult = fetchResult;
// We're acquiring a copy, we should incRef it
this.queryResult.incRef();
this.fetchResult.incRef();
// On final release, return the references taken above.
refCounted = AbstractRefCounted.of(() -> {
queryResult.decRef();
fetchResult.decRef();
});
}

@Override
Expand Down Expand Up @@ -73,4 +87,24 @@ public void writeTo(StreamOutput out) throws IOException {
queryResult.writeTo(out);
fetchResult.writeTo(out);
}

// Ref counting is delegated to the shared AbstractRefCounted built in the constructors,
// which releases both sub-results once the count reaches zero.
@Override
public void incRef() {
refCounted.incRef();
}

// Delegates to the shared counter; fails (returns false) if already fully released.
@Override
public boolean tryIncRef() {
return refCounted.tryIncRef();
}

// Delegates to the shared counter; the final decRef triggers release of both sub-results.
@Override
public boolean decRef() {
return refCounted.decRef();
}

// True while at least one reference to this wrapper is still held.
@Override
public boolean hasReferences() {
return refCounted.hasReferences();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,24 @@ public void writeTo(StreamOutput out) throws IOException {
getSearchShardTarget().writeTo(out);
result.writeTo(out);
}

// Ref counting is delegated to the wrapped result so this wrapper shares its lifecycle.
// NOTE(review): presumably `result` is the underlying ref-counted shard result — the
// field declaration is outside this hunk; confirm against the full class.
@Override
public void incRef() {
result.incRef();
}

// Delegates to the wrapped result; fails if it has already been fully released.
@Override
public boolean tryIncRef() {
return result.tryIncRef();
}

// Delegates to the wrapped result; its final decRef performs the actual release.
@Override
public boolean decRef() {
return result.decRef();
}

// True while the wrapped result still has live references.
@Override
public boolean hasReferences() {
return result.hasReferences();
}
}
Loading

0 comments on commit cb04885

Please sign in to comment.