Skip to content

Commit

Permalink
Add a listener to track the progress of a search request locally
Browse files Browse the repository at this point in the history
This commit adds a function in NodeClient that allows to track the progress
of a search request locally. Progress is tracked through a SearchProgressListener
that exposes query and fetch responses as well as partial and final reduces.
This new method can be used by modules/plugins inside a node in order to track the
progress of a local search request.

Relates elastic#49091
  • Loading branch information
jimczi committed Nov 21, 2019
1 parent 3c5eb04 commit 931d8d0
Show file tree
Hide file tree
Showing 17 changed files with 657 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
private final SearchResponse.Clusters clusters;

private final GroupShardsIterator<SearchShardIterator> toSkipShardsIts;
private final GroupShardsIterator<SearchShardIterator> shardsIts;
protected final GroupShardsIterator<SearchShardIterator> shardsIts;
private final int expectedTotalOps;
private final AtomicInteger totalOps = new AtomicInteger();
private final int maxConcurrentRequestsPerNode;
Expand Down Expand Up @@ -443,7 +443,7 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
successfulOps.decrementAndGet(); // if this shard was successful before (initial phase) we have to adjust the counter
}
}
results.consumeShardFailure(shardIndex);
results.consumeShardFailure(shardIndex, e);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ boolean hasResult(int shardIndex) {
}

@Override
void consumeShardFailure(int shardIndex) {
void consumeShardFailure(int shardIndex, Exception exc) {
// we have to carry over shard failures in order to account for them in the response.
consumeResult(shardIndex, true, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.dfs.AggregatedDfs;
import org.elasticsearch.search.dfs.DfsSearchResult;
import org.elasticsearch.search.query.QuerySearchRequest;
Expand All @@ -46,13 +47,15 @@ final class DfsQueryPhase extends SearchPhase {
private final Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory;
private final SearchPhaseContext context;
private final SearchTransportService searchTransportService;
private final SearchProgressListener progressListener;

DfsQueryPhase(AtomicArray<DfsSearchResult> dfsSearchResults,
SearchPhaseController searchPhaseController,
Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
SearchPhaseContext context) {
super("dfs_query");
this.queryResult = searchPhaseController.newSearchPhaseResults(context.getRequest(), context.getNumShards());
this.progressListener = context.getTask().getProgressListener();
this.queryResult = searchPhaseController.newSearchPhaseResults(progressListener, context.getRequest(), context.getNumShards());
this.searchPhaseController = searchPhaseController;
this.dfsSearchResults = dfsSearchResults;
this.nextPhaseFactory = nextPhaseFactory;
Expand All @@ -69,6 +72,8 @@ public void run() throws IOException {
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(queryResult::consumeResult,
resultList.size(),
() -> context.executeNextPhase(this, nextPhaseFactory.apply(queryResult)), context);
final SearchSourceBuilder sourceBuilder = context.getRequest().source();
progressListener.onListShards(progressListener.searchShards(resultList), sourceBuilder == null || sourceBuilder.size() != 0);
for (final DfsSearchResult dfsResult : resultList) {
final SearchShardTarget searchShardTarget = dfsResult.getSearchShardTarget();
Transport.Connection connection = context.getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.elasticsearch.transport.Transport;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.function.BiFunction;

Expand All @@ -49,6 +50,7 @@ final class FetchSearchPhase extends SearchPhase {
private final SearchPhaseContext context;
private final Logger logger;
private final SearchPhaseResults<SearchPhaseResult> resultConsumer;
private final SearchProgressListener progressListener;

FetchSearchPhase(SearchPhaseResults<SearchPhaseResult> resultConsumer,
SearchPhaseController searchPhaseController,
Expand All @@ -72,6 +74,7 @@ final class FetchSearchPhase extends SearchPhase {
this.context = context;
this.logger = context.getLogger();
this.resultConsumer = resultConsumer;
this.progressListener = context.getTask().getProgressListener();
}

@Override
Expand Down Expand Up @@ -136,6 +139,8 @@ private void innerRun() throws IOException {
}
// in any case we count down this result since we don't talk to this shard anymore
counter.countDown();
// empty result
progressListener.onFetchResult(queryResult.fetchResult());
} else {
SearchShardTarget searchShardTarget = queryResult.getSearchShardTarget();
Transport.Connection connection = context.getConnection(searchShardTarget.getClusterAlias(),
Expand Down Expand Up @@ -164,11 +169,16 @@ private void executeFetch(final int shardIndex, final SearchShardTarget shardTar
new SearchActionListener<FetchSearchResult>(shardTarget, shardIndex) {
@Override
public void innerOnResponse(FetchSearchResult result) {
boolean success = false;
try {
counter.onResult(result);
success = true;
} catch (Exception e) {
context.onPhaseFailure(FetchSearchPhase.this, "", e);
}
if (success) {
progressListener.onFetchResult(result);
}
}

@Override
Expand All @@ -182,6 +192,7 @@ public void onFailure(Exception e) {
// request to clear the search context.
releaseIrrelevantSearchContext(querySearchResult);
}
progressListener.onFetchFailure(shardIndex, e);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.elasticsearch.search.SearchShardTarget;

/**
* An base action listener that ensures shard target and shard index is set on all responses
* A base action listener that ensures shard target and shard index is set on all responses
* received by this listener.
*/
abstract class SearchActionListener<T extends SearchPhaseResult> implements ActionListener<T> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
private final int bufferSize;
private int index;
private final SearchPhaseController controller;
private final SearchProgressListener progressListener;
private int numReducePhases = 0;
private final TopDocsStats topDocsStats;
private final boolean performFinalReduce;
Expand All @@ -582,8 +583,9 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
* @param bufferSize the size of the reduce buffer. if the buffer size is smaller than the number of expected results
* the buffer is used to incrementally reduce aggregation results before all shards responded.
*/
private QueryPhaseResultConsumer(SearchPhaseController controller, int expectedResultSize, int bufferSize,
boolean hasTopDocs, boolean hasAggs, int trackTotalHitsUpTo, boolean performFinalReduce) {
private QueryPhaseResultConsumer(SearchProgressListener progressListener, SearchPhaseController controller,
int expectedResultSize, int bufferSize, boolean hasTopDocs, boolean hasAggs,
int trackTotalHitsUpTo, boolean performFinalReduce) {
super(expectedResultSize);
if (expectedResultSize != 1 && bufferSize < 2) {
throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
Expand All @@ -595,6 +597,7 @@ private QueryPhaseResultConsumer(SearchPhaseController controller, int expectedR
throw new IllegalArgumentException("either aggs or top docs must be present");
}
this.controller = controller;
this.progressListener = progressListener;
// no need to buffer anything if we have less expected results. in this case we don't consume any results ahead of time.
this.aggsBuffer = new InternalAggregations[hasAggs ? bufferSize : 0];
this.topDocsBuffer = new TopDocs[hasTopDocs ? bufferSize : 0];
Expand All @@ -605,11 +608,17 @@ private QueryPhaseResultConsumer(SearchPhaseController controller, int expectedR
this.performFinalReduce = performFinalReduce;
}

@Override
void consumeShardFailure(int shardIndex, Exception exc) {
progressListener.onQueryFailure(shardIndex, exc);
}

@Override
public void consumeResult(SearchPhaseResult result) {
super.consumeResult(result);
QuerySearchResult queryResult = result.queryResult();
consumeInternal(queryResult);
progressListener.onQueryResult(queryResult);
}

private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
Expand All @@ -629,6 +638,10 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
}
numReducePhases++;
index = 1;
if (hasAggs) {
progressListener.onPartialReduce(progressListener.searchShards(results.asList()),
topDocsStats.getTotalHits(), aggsBuffer[0], numReducePhases);
}
}
final int i = index++;
if (hasAggs) {
Expand All @@ -652,8 +665,10 @@ private synchronized List<TopDocs> getRemainingTopDocs() {

@Override
public ReducedQueryPhase reduce() {
return controller.reducedQueryPhase(results.asList(), getRemainingAggs(), getRemainingTopDocs(), topDocsStats,
numReducePhases, false, performFinalReduce);
ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(),
getRemainingAggs(), getRemainingTopDocs(), topDocsStats, numReducePhases, false, performFinalReduce);
progressListener.onReduce(progressListener.searchShards(results.asList()), reducePhase.totalHits, reducePhase.aggregations);
return reducePhase;
}

/**
Expand All @@ -678,7 +693,9 @@ private int resolveTrackTotalHits(SearchRequest request) {
/**
* Returns a new ArraySearchPhaseResults instance. This might return an instance that reduces search responses incrementally.
*/
ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(SearchRequest request, int numShards) {
ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(SearchProgressListener listener,
SearchRequest request,
int numShards) {
SearchSourceBuilder source = request.source();
boolean isScrollRequest = request.scroll() != null;
final boolean hasAggs = source != null && source.aggregations() != null;
Expand All @@ -688,14 +705,30 @@ ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(SearchRequest r
// no incremental reduce if scroll is used - we only hit a single shard or sometimes more...
if (request.getBatchedReduceSize() < numShards) {
// only use this if there are aggs and if there are more shards than we should reduce at once
return new QueryPhaseResultConsumer(this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs,
return new QueryPhaseResultConsumer(listener, this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs,
trackTotalHitsUpTo, request.isFinalReduce());
}
}
return new ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
@Override
void consumeResult(SearchPhaseResult result) {
super.consumeResult(result);
listener.onQueryResult(result.queryResult());
}

@Override
void consumeShardFailure(int shardIndex, Exception exc) {
super.consumeShardFailure(shardIndex, exc);
listener.onQueryFailure(shardIndex, exc);
}

@Override
ReducedQueryPhase reduce() {
return reducedQueryPhase(results.asList(), isScrollRequest, trackTotalHitsUpTo, request.isFinalReduce());
List<SearchPhaseResult> resultList = results.asList();
final ReducedQueryPhase reducePhase =
reducedQueryPhase(resultList, isScrollRequest, trackTotalHitsUpTo, request.isFinalReduce());
listener.onReduce(listener.searchShards(resultList), reducePhase.totalHits, reducePhase.aggregations);
return reducePhase;
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ final int getNumShards() {
*/
abstract boolean hasResult(int shardIndex);

void consumeShardFailure(int shardIndex) {}
void consumeShardFailure(int shardIndex, Exception exc) {}

AtomicArray<Result> getAtomicArray() {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.action.search;

import org.elasticsearch.action.ActionListener;

/**
* An {@link ActionListener} for search requests that allows to track progress of the {@link SearchAction}.
* See {@link SearchProgressListener}.
*/
public abstract class SearchProgressActionListener extends SearchProgressListener implements ActionListener<SearchResponse> {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.action.search;

import org.apache.lucene.search.TotalHits;
import org.elasticsearch.cluster.routing.GroupShardsIterator;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.query.QuerySearchResult;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
* A listener that allows to track progress of the {@link SearchAction}.
*/
abstract class SearchProgressListener {
/**
* Executed when shards are ready to be queried.
*
* @param shards The list of shards to query.
* @param fetchPhase <code>true</code> if the search needs a fetch phase, <code>false</code> otherwise.
**/
public void onListShards(List<SearchShard> shards, boolean fetchPhase) {}

/**
* Executed when a shard returns a query result.
*
* @param result The query result.
*/
public void onQueryResult(QuerySearchResult result) {}

/**
* Executed when a shard reports a query failure.
*
* @param shardIndex The index of the shard in the list provided by onListShards.
* @param exc The cause of the failure.
*/
public void onQueryFailure(int shardIndex, Exception exc) { }

/**
* Executed when a partial reduce is created. The number of partial reduce can be controlled via
* {@link SearchRequest#setBatchedReduceSize(int)}.
*
* @param shards The list of shards that are part of this reduce.
* @param totalHits The total number of hits in this reduce.
* @param aggs The partial result for aggregations.
* @param version The version number for this reduce.
*/
public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int version) {}

/**
* Executed once when the final reduce is created.
*
* @param shards The list of shards that are part of this reduce.
* @param totalHits The total number of hits in this reduce.
* @param aggs The final result for aggregations.
*/
public void onReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs) {}

/**
* Executed when a shard returns a query result.
*
* @param result The fetch result.
*/
public void onFetchResult(FetchSearchResult result) {}

/**
* Executed when a shard reports a fetch failure.
*
* @param shardIndex The index of the shard in the list provided by onListShards.
* @param exc The cause of the failure.
*/
public void onFetchFailure(int shardIndex, Exception exc) {}

final List<SearchShard> searchShards(List<? extends SearchPhaseResult> results) {
return results.stream()
.filter(Objects::nonNull)
.map(SearchPhaseResult::getSearchShardTarget)
.map(e -> new SearchShard(e.getClusterAlias(), e.getShardId()))
.collect(Collectors.toList());
}

final List<SearchShard> searchShards(GroupShardsIterator<SearchShardIterator> its) {
return StreamSupport.stream(its.spliterator(), false)
.map(e -> new SearchShard(e.getClusterAlias(), e.shardId()))
.collect(Collectors.toList());
}

public static final SearchProgressListener NOOP = new SearchProgressListener() {};
}
Loading

0 comments on commit 931d8d0

Please sign in to comment.