Skip to content

Commit

Permalink
Skip final reduction if SearchRequest holds a cluster alias (#37000)
Browse files Browse the repository at this point in the history
With #36997 we added the ability to provide a cluster alias with a
SearchRequest.

The next step is to disable the final reduction whenever a cluster alias
is provided with the SearchRequest. A cluster alias will be provided
when executing a cross-cluster search (CCS) request with the alternate
execution mode, where each cluster does its own reduction locally. In
order for the CCS coordinating node to be able to later perform an
additional reduction of the results, we need to make sure that all the
needed information stays available. This means that terms aggregations
can be reduced but not pruned, and pipeline aggregations should not be
executed. The final reduction will happen later on the CCS coordinating
node.

Relates to #36997 & #32125
  • Loading branch information
javanna authored Dec 28, 2018
1 parent 34d22f3 commit cb6bac3
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -407,17 +407,18 @@ private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFr
* Reduces the given query results and consumes all aggregations and profile results.
* @param queryResults a list of non-null query shard results
*/
public ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
return reducedQueryPhase(queryResults, true, true);
ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
return reducedQueryPhase(queryResults, true, true, true);
}

/**
* Reduces the given query results and consumes all aggregations and profile results.
* @param queryResults a list of non-null query shard results
*/
public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
boolean isScrollRequest, boolean trackTotalHits) {
return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHits), 0, isScrollRequest);
ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
boolean isScrollRequest, boolean trackTotalHits, boolean performFinalReduce) {
return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHits), 0, isScrollRequest,
performFinalReduce);
}

/**
Expand All @@ -433,7 +434,8 @@ public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResul
*/
private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
List<InternalAggregations> bufferedAggs, List<TopDocs> bufferedTopDocs,
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest) {
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
boolean performFinalReduce) {
assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases;
numReducePhases++; // increment for this phase
boolean timedOut = false;
Expand Down Expand Up @@ -499,15 +501,15 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
}
}
final Suggest suggest = groupedSuggestions.isEmpty() ? null : new Suggest(Suggest.reduce(groupedSuggestions));
ReduceContext reduceContext = reduceContextFunction.apply(true);
ReduceContext reduceContext = reduceContextFunction.apply(performFinalReduce);
final InternalAggregations aggregations = aggregationsList.isEmpty() ? null : reduceAggs(aggregationsList,
firstResult.pipelineAggregators(), reduceContext);
final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size);
final TotalHits totalHits = topDocsStats.getTotalHits();
return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.maxScore,
timedOut, terminatedEarly, suggest, aggregations, shardResults, sortedTopDocs,
firstResult.sortValueFormats(), numReducePhases, size, from, firstResult == null);
firstResult.sortValueFormats(), numReducePhases, size, from, false);
}

/**
Expand Down Expand Up @@ -617,6 +619,7 @@ static final class QueryPhaseResultConsumer extends InitialSearchPhase.ArraySear
private final SearchPhaseController controller;
private int numReducePhases = 0;
private final TopDocsStats topDocsStats = new TopDocsStats();
private final boolean performFinalReduce;

/**
* Creates a new {@link QueryPhaseResultConsumer}
Expand All @@ -626,7 +629,7 @@ static final class QueryPhaseResultConsumer extends InitialSearchPhase.ArraySear
* the buffer is used to incrementally reduce aggregation results before all shards responded.
*/
private QueryPhaseResultConsumer(SearchPhaseController controller, int expectedResultSize, int bufferSize,
boolean hasTopDocs, boolean hasAggs) {
boolean hasTopDocs, boolean hasAggs, boolean performFinalReduce) {
super(expectedResultSize);
if (expectedResultSize != 1 && bufferSize < 2) {
throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
Expand All @@ -644,6 +647,7 @@ private QueryPhaseResultConsumer(SearchPhaseController controller, int expectedR
this.hasTopDocs = hasTopDocs;
this.hasAggs = hasAggs;
this.bufferSize = bufferSize;
this.performFinalReduce = performFinalReduce;
}

@Override
Expand Down Expand Up @@ -693,7 +697,7 @@ private synchronized List<TopDocs> getRemainingTopDocs() {
@Override
public ReducedQueryPhase reduce() {
return controller.reducedQueryPhase(results.asList(), getRemainingAggs(), getRemainingTopDocs(), topDocsStats,
numReducePhases, false);
numReducePhases, false, performFinalReduce);
}

/**
Expand All @@ -715,18 +719,19 @@ InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResu
final boolean hasAggs = source != null && source.aggregations() != null;
final boolean hasTopDocs = source == null || source.size() != 0;
final boolean trackTotalHits = source == null || source.trackTotalHits();
final boolean finalReduce = request.getLocalClusterAlias() == null;

if (isScrollRequest == false && (hasAggs || hasTopDocs)) {
// no incremental reduce if scroll is used - we only hit a single shard or sometimes more...
if (request.getBatchedReduceSize() < numShards) {
// only use this if there are aggs and if there are more shards than we should reduce at once
return new QueryPhaseResultConsumer(this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs);
return new QueryPhaseResultConsumer(this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs, finalReduce);
}
}
return new InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
@Override
ReducedQueryPhase reduce() {
return reducedQueryPhase(results.asList(), isScrollRequest, trackTotalHits);
return reducedQueryPhase(results.asList(), isScrollRequest, trackTotalHits, finalReduce);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,34 @@
import org.junit.Before;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;

public class SearchPhaseControllerTests extends ESTestCase {
private SearchPhaseController searchPhaseController;
private List<Boolean> reductions;

@Before
public void setup() {
reductions = new CopyOnWriteArrayList<>();
searchPhaseController = new SearchPhaseController(
(b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
(finalReduce) -> {
reductions.add(finalReduce);
return new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, finalReduce);
});
}

public void testSort() {
Expand Down Expand Up @@ -158,7 +164,7 @@ public void testMerge() {
AtomicArray<SearchPhaseResult> queryResults = generateQueryResults(nShards, suggestions, queryResultSize, false);
for (boolean trackTotalHits : new boolean[] {true, false}) {
SearchPhaseController.ReducedQueryPhase reducedQueryPhase =
searchPhaseController.reducedQueryPhase(queryResults.asList(), false, trackTotalHits);
searchPhaseController.reducedQueryPhase(queryResults.asList(), false, trackTotalHits, true);
AtomicArray<SearchPhaseResult> fetchResults = generateFetchResults(nShards,
reducedQueryPhase.sortedTopDocs.scoreDocs, reducedQueryPhase.suggest);
InternalSearchResponse mergedResponse = searchPhaseController.merge(false,
Expand Down Expand Up @@ -308,14 +314,15 @@ private static AtomicArray<SearchPhaseResult> generateFetchResults(int nShards,

public void testConsumer() {
int bufferSize = randomIntBetween(2, 3);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")));
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(request, 3);
assertEquals(0, reductions.size());
QuerySearchResult result = new QuerySearchResult(0, new SearchShardTarget("node", new Index("a", "b"), 0, null));
result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN),
new DocValueFormat[0]);
InternalAggregations aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", 1.0D, DocValueFormat.RAW,
InternalAggregations aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", 1.0D, DocValueFormat.RAW,
Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(0);
Expand All @@ -324,7 +331,7 @@ public void testConsumer() {
result = new QuerySearchResult(1, new SearchShardTarget("node", new Index("a", "b"), 0, null));
result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN),
new DocValueFormat[0]);
aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", 3.0D, DocValueFormat.RAW,
aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", 3.0D, DocValueFormat.RAW,
Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(2);
Expand All @@ -333,23 +340,29 @@ public void testConsumer() {
result = new QuerySearchResult(1, new SearchShardTarget("node", new Index("a", "b"), 0, null));
result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN),
new DocValueFormat[0]);
aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", 2.0D, DocValueFormat.RAW,
aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", 2.0D, DocValueFormat.RAW,
Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(1);
consumer.consumeResult(result);
int numTotalReducePhases = 1;
final int numTotalReducePhases;
if (bufferSize == 2) {
assertThat(consumer, instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class));
assertEquals(1, ((SearchPhaseController.QueryPhaseResultConsumer)consumer).getNumReducePhases());
assertEquals(2, ((SearchPhaseController.QueryPhaseResultConsumer)consumer).getNumBuffered());
numTotalReducePhases++;
assertEquals(1, reductions.size());
assertEquals(false, reductions.get(0));
numTotalReducePhases = 2;
} else {
assertThat(consumer, not(instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class)));
assertEquals(0, reductions.size());
numTotalReducePhases = 1;
}

SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertEquals(numTotalReducePhases, reduce.numReducePhases);
assertEquals(numTotalReducePhases, reductions.size());
assertFinalReduction(request);
InternalMax max = (InternalMax) reduce.aggregations.asList().get(0);
assertEquals(3.0D, max.getValue(), 0.0D);
assertFalse(reduce.sortedTopDocs.isSortedByField);
Expand All @@ -362,7 +375,7 @@ public void testConsumerConcurrently() throws InterruptedException {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);

SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")));
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
Expand All @@ -378,7 +391,7 @@ public void testConsumerConcurrently() throws InterruptedException {
result.topDocs(new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] {new ScoreDoc(0, number)}), number),
new DocValueFormat[0]);
InternalAggregations aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", (double) number,
InternalAggregations aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", (double) number,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(id);
Expand All @@ -392,6 +405,7 @@ public void testConsumerConcurrently() throws InterruptedException {
threads[i].join();
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
assertEquals(max.get(), internalMax.getValue(), 0.0D);
assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
Expand All @@ -407,7 +421,7 @@ public void testConsumerConcurrently() throws InterruptedException {
public void testConsumerOnlyAggs() {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
Expand All @@ -419,14 +433,15 @@ public void testConsumerOnlyAggs() {
QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null));
result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), number),
new DocValueFormat[0]);
InternalAggregations aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", (double) number,
InternalAggregations aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", (double) number,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(i);
result.size(1);
consumer.consumeResult(result);
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
assertEquals(max.get(), internalMax.getValue(), 0.0D);
assertEquals(0, reduce.sortedTopDocs.scoreDocs.length);
Expand All @@ -441,7 +456,7 @@ public void testConsumerOnlyAggs() {
public void testConsumerOnlyHits() {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
if (randomBoolean()) {
request.source(new SearchSourceBuilder().size(randomIntBetween(1, 10)));
}
Expand All @@ -460,6 +475,7 @@ public void testConsumerOnlyHits() {
consumer.consumeResult(result);
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
assertEquals(max.get(), reduce.maxScore, 0.0f);
assertEquals(expectedNumResults, reduce.totalHits.value);
Expand All @@ -470,6 +486,12 @@ public void testConsumerOnlyHits() {
assertNull(reduce.sortedTopDocs.collapseValues);
}

/**
 * Checks the recorded reduce-context flags: at least one reduction must have been recorded, and the
 * most recent one must have been flagged as final exactly when the given request holds no local
 * cluster alias.
 *
 * @param searchRequest the request whose local cluster alias decides whether a final reduction is expected
 */
private void assertFinalReduction(SearchRequest searchRequest) {
    final int recordedReductions = reductions.size();
    assertThat(recordedReductions, greaterThanOrEqualTo(1));
    // a request that carries a cluster alias must leave the final reduction to a later step
    final boolean expectFinalReduce = searchRequest.getLocalClusterAlias() == null;
    assertEquals(expectFinalReduce, reductions.get(recordedReductions - 1));
}

public void testNewSearchPhaseResults() {
for (int i = 0; i < 10; i++) {
int expectedNumResults = randomIntBetween(1, 10);
Expand Down Expand Up @@ -540,7 +562,7 @@ public void testReduceTopNWithFromOffset() {
public void testConsumerSortByField() {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
int size = randomIntBetween(1, 10);
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
Expand All @@ -560,6 +582,7 @@ public void testConsumerSortByField() {
consumer.consumeResult(result);
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
assertEquals(Math.min(expectedNumResults, size), reduce.sortedTopDocs.scoreDocs.length);
assertEquals(expectedNumResults, reduce.totalHits.value);
assertEquals(max.get(), ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]);
Expand All @@ -574,7 +597,7 @@ public void testConsumerSortByField() {
public void testConsumerFieldCollapsing() {
int expectedNumResults = randomIntBetween(30, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
int size = randomIntBetween(5, 10);
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
Expand All @@ -596,6 +619,7 @@ public void testConsumerFieldCollapsing() {
consumer.consumeResult(result);
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
assertEquals(3, reduce.sortedTopDocs.scoreDocs.length);
assertEquals(expectedNumResults, reduce.totalHits.value);
assertEquals(a, ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]);
Expand Down

0 comments on commit cb6bac3

Please sign in to comment.