Skip to content

Commit 5074f2c

Browse files
authored
[ML] explicitly disallow partial results in datafeed extractors (#55537)
Instead of doing our own checks against REST status, shard counts, and shard failures, this commit changes all our extractor search requests to set `.setAllowPartialSearchResults(false)`.

- Scrolls are automatically cleared when a search failure occurs with `.setAllowPartialSearchResults(false)` set.
- Code error handling is simplified.

Closes #40793
1 parent ed57adb commit 5074f2c

File tree

10 files changed

+111
-217
lines changed

10 files changed

+111
-217
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/extractor/ExtractorUtils.java

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,20 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.datafeed.extractor;
77

8-
import org.apache.logging.log4j.LogManager;
9-
import org.apache.logging.log4j.Logger;
108
import org.elasticsearch.ElasticsearchException;
11-
import org.elasticsearch.action.search.SearchResponse;
12-
import org.elasticsearch.action.search.ShardSearchFailure;
139
import org.elasticsearch.common.Rounding;
1410
import org.elasticsearch.common.unit.TimeValue;
1511
import org.elasticsearch.index.query.BoolQueryBuilder;
1612
import org.elasticsearch.index.query.QueryBuilder;
1713
import org.elasticsearch.index.query.RangeQueryBuilder;
18-
import org.elasticsearch.rest.RestStatus;
1914
import org.elasticsearch.search.aggregations.AggregationBuilder;
2015
import org.elasticsearch.search.aggregations.AggregatorFactories;
2116
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder;
2217
import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder;
2318
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
2419
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2520

26-
import java.io.IOException;
2721
import java.time.ZoneOffset;
28-
import java.util.Arrays;
2922
import java.util.Collection;
3023
import java.util.concurrent.TimeUnit;
3124

@@ -34,7 +27,6 @@
3427
*/
3528
public final class ExtractorUtils {
3629

37-
private static final Logger LOGGER = LogManager.getLogger(ExtractorUtils.class);
3830
private static final String EPOCH_MILLIS = "epoch_millis";
3931

4032
private ExtractorUtils() {}
@@ -47,25 +39,6 @@ public static QueryBuilder wrapInTimeRangeQuery(QueryBuilder userQuery, String t
4739
return new BoolQueryBuilder().filter(userQuery).filter(timeQuery);
4840
}
4941

50-
/**
51-
* Checks that a {@link SearchResponse} has an OK status code and no shard failures
52-
*/
53-
public static void checkSearchWasSuccessful(String jobId, SearchResponse searchResponse) throws IOException {
54-
if (searchResponse.status() != RestStatus.OK) {
55-
throw new IOException("[" + jobId + "] Search request returned status code: " + searchResponse.status()
56-
+ ". Response was:\n" + searchResponse.toString());
57-
}
58-
ShardSearchFailure[] shardFailures = searchResponse.getShardFailures();
59-
if (shardFailures != null && shardFailures.length > 0) {
60-
LOGGER.error("[{}] Search request returned shard failures: {}", jobId, Arrays.toString(shardFailures));
61-
throw new IOException(ExceptionsHelper.shardFailuresToErrorMsg(jobId, shardFailures));
62-
}
63-
int unavailableShards = searchResponse.getTotalShards() - searchResponse.getSuccessfulShards();
64-
if (unavailableShards > 0) {
65-
throw new IOException("[" + jobId + "] Search request encountered [" + unavailableShards + "] unavailable shards");
66-
}
67-
}
68-
6942
/**
7043
* Find the (date) histogram in {@code aggFactory} and extract its interval.
7144
* Throws if there is no (date) histogram or if the histogram has sibling

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/aggregation/AbstractAggregationDataExtractor.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,13 @@ public Optional<InputStream> next() throws IOException {
107107
return Optional.ofNullable(processNextBatch());
108108
}
109109

110-
private Aggregations search() throws IOException {
110+
private Aggregations search() {
111111
LOGGER.debug("[{}] Executing aggregated search", context.jobId);
112-
SearchResponse searchResponse = executeSearchRequest(buildSearchRequest(buildBaseSearchSource()));
112+
T searchRequest = buildSearchRequest(buildBaseSearchSource());
113+
assert searchRequest.request().allowPartialSearchResults() == false;
114+
SearchResponse searchResponse = executeSearchRequest(searchRequest);
113115
LOGGER.debug("[{}] Search response was obtained", context.jobId);
114116
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
115-
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
116117
return validateAggs(searchResponse.getAggregations());
117118
}
118119

@@ -166,10 +167,6 @@ private InputStream processNextBatch() throws IOException {
166167
return new ByteArrayInputStream(outputStream.toByteArray());
167168
}
168169

169-
protected long getHistogramInterval() {
170-
return ExtractorUtils.getHistogramIntervalMillis(context.aggs);
171-
}
172-
173170
public AggregationDataExtractorContext getContext() {
174171
return context;
175172
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/aggregation/AggregationDataExtractor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ protected SearchRequestBuilder buildSearchRequest(SearchSourceBuilder searchSour
2929
return new SearchRequestBuilder(client, SearchAction.INSTANCE)
3030
.setSource(searchSourceBuilder)
3131
.setIndicesOptions(context.indicesOptions)
32+
.setAllowPartialSearchResults(false)
3233
.setIndices(context.indices);
3334
}
3435
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/aggregation/AggregationToJsonProcessor.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ class AggregationToJsonProcessor {
6363
* @param includeDocCount whether to include the doc_count
6464
* @param startTime buckets with a timestamp before this time are discarded
6565
*/
66-
AggregationToJsonProcessor(String timeField, Set<String> fields, boolean includeDocCount, long startTime)
67-
throws IOException {
66+
AggregationToJsonProcessor(String timeField, Set<String> fields, boolean includeDocCount, long startTime) {
6867
this.timeField = Objects.requireNonNull(timeField);
6968
this.fields = Objects.requireNonNull(fields);
7069
this.includeDocCount = includeDocCount;
@@ -279,7 +278,7 @@ private void processBucket(MultiBucketsAggregation bucketAgg, boolean addField)
279278
* Adds a leaf key-value. It returns {@code true} if the key added or {@code false} when nothing was added.
280279
* Non-finite metric values are not added.
281280
*/
282-
private boolean processLeaf(Aggregation agg) throws IOException {
281+
private boolean processLeaf(Aggregation agg) {
283282
if (agg instanceof NumericMetricsAggregation.SingleValue) {
284283
return processSingleValue((NumericMetricsAggregation.SingleValue) agg);
285284
} else if (agg instanceof Percentiles) {
@@ -291,7 +290,7 @@ private boolean processLeaf(Aggregation agg) throws IOException {
291290
}
292291
}
293292

294-
private boolean processSingleValue(NumericMetricsAggregation.SingleValue singleValue) throws IOException {
293+
private boolean processSingleValue(NumericMetricsAggregation.SingleValue singleValue) {
295294
return addMetricIfFinite(singleValue.getName(), singleValue.value());
296295
}
297296

@@ -311,7 +310,7 @@ private boolean processGeoCentroid(GeoCentroid agg) {
311310
return false;
312311
}
313312

314-
private boolean processPercentiles(Percentiles percentiles) throws IOException {
313+
private boolean processPercentiles(Percentiles percentiles) {
315314
Iterator<Percentile> percentileIterator = percentiles.iterator();
316315
boolean aggregationAdded = addMetricIfFinite(percentiles.getName(), percentileIterator.next().getValue());
317316
if (percentileIterator.hasNext()) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/aggregation/RollupDataExtractor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class RollupDataExtractor extends AbstractAggregationDataExtractor<RollupSearchA
2828
protected RollupSearchAction.RequestBuilder buildSearchRequest(SearchSourceBuilder searchSourceBuilder) {
2929
SearchRequest searchRequest = new SearchRequest().indices(context.indices)
3030
.indicesOptions(context.indicesOptions)
31+
.allowPartialSearchResults(false)
3132
.source(searchSourceBuilder);
3233

3334
return new RollupSearchAction.RequestBuilder(client, searchRequest);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/chunked/ChunkedDataExtractor.java

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public Optional<InputStream> next() throws IOException {
114114
return getNextStream();
115115
}
116116

117-
private void setUpChunkedSearch() throws IOException {
117+
private void setUpChunkedSearch() {
118118
DataSummary dataSummary = dataSummaryFactory.buildDataSummary();
119119
if (dataSummary.hasData()) {
120120
currentStart = context.timeAligner.alignToFloor(dataSummary.earliestTime());
@@ -196,21 +196,18 @@ private class DataSummaryFactory {
196196
* So, if we need to gather an appropriate chunked time for aggregations, we can utilize the AggregatedDataSummary
197197
*
198198
* @return DataSummary object
199-
* @throws IOException when timefield range search fails
200199
*/
201-
private DataSummary buildDataSummary() throws IOException {
200+
private DataSummary buildDataSummary() {
202201
return context.hasAggregations ? newAggregatedDataSummary() : newScrolledDataSummary();
203202
}
204203

205-
private DataSummary newScrolledDataSummary() throws IOException {
204+
private DataSummary newScrolledDataSummary() {
206205
SearchRequestBuilder searchRequestBuilder = rangeSearchRequest();
207206

208207
SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
209208
LOGGER.debug("[{}] Scrolling Data summary response was obtained", context.jobId);
210209
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
211210

212-
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
213-
214211
Aggregations aggregations = searchResponse.getAggregations();
215212
long earliestTime = 0;
216213
long latestTime = 0;
@@ -224,16 +221,14 @@ private DataSummary newScrolledDataSummary() throws IOException {
224221
return new ScrolledDataSummary(earliestTime, latestTime, totalHits);
225222
}
226223

227-
private DataSummary newAggregatedDataSummary() throws IOException {
224+
private DataSummary newAggregatedDataSummary() {
228225
// TODO: once RollupSearchAction is changed from indices:admin* to indices:data/read/* this branch is not needed
229226
ActionRequestBuilder<SearchRequest, SearchResponse> searchRequestBuilder =
230227
dataExtractorFactory instanceof RollupDataExtractorFactory ? rollupRangeSearchRequest() : rangeSearchRequest();
231228
SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
232229
LOGGER.debug("[{}] Aggregating Data summary response was obtained", context.jobId);
233230
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
234231

235-
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
236-
237232
Aggregations aggregations = searchResponse.getAggregations();
238233
Min min = aggregations.get(EARLIEST_TIME);
239234
Max max = aggregations.get(LATEST_TIME);
@@ -253,12 +248,14 @@ private SearchRequestBuilder rangeSearchRequest() {
253248
.setIndices(context.indices)
254249
.setIndicesOptions(context.indicesOptions)
255250
.setSource(rangeSearchBuilder())
251+
.setAllowPartialSearchResults(false)
256252
.setTrackTotalHits(true);
257253
}
258254

259255
private RollupSearchAction.RequestBuilder rollupRangeSearchRequest() {
260256
SearchRequest searchRequest = new SearchRequest().indices(context.indices)
261257
.indicesOptions(context.indicesOptions)
258+
.allowPartialSearchResults(false)
262259
.source(rangeSearchBuilder());
263260
return new RollupSearchAction.RequestBuilder(client, searchRequest);
264261
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractor.java

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,13 @@ private Optional<InputStream> tryNextStream() throws IOException {
102102
return scrollId == null ?
103103
Optional.ofNullable(initScroll(context.start)) : Optional.ofNullable(continueScroll());
104104
} catch (Exception e) {
105-
// In case of error make sure we clear the scroll context
106-
clearScroll();
107-
throw e;
105+
scrollId = null;
106+
if (searchHasShardFailure) {
107+
throw e;
108+
}
109+
LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
110+
markScrollAsErrored();
111+
return Optional.ofNullable(initScroll(lastTimestamp == null ? context.start : lastTimestamp));
108112
}
109113
}
110114

@@ -127,6 +131,7 @@ private SearchRequestBuilder buildSearchRequest(long start) {
127131
.setIndices(context.indices)
128132
.setIndicesOptions(context.indicesOptions)
129133
.setSize(context.scrollSize)
134+
.setAllowPartialSearchResults(false)
130135
.setQuery(ExtractorUtils.wrapInTimeRangeQuery(
131136
context.query, context.extractedFields.timeField(), start, context.end));
132137

@@ -147,14 +152,6 @@ private SearchRequestBuilder buildSearchRequest(long start) {
147152
private InputStream processSearchResponse(SearchResponse searchResponse) throws IOException {
148153

149154
scrollId = searchResponse.getScrollId();
150-
151-
if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) {
152-
LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
153-
markScrollAsErrored();
154-
return initScroll(lastTimestamp == null ? context.start : lastTimestamp);
155-
}
156-
157-
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
158155
if (searchResponse.getHits().getHits().length == 0) {
159156
hasNext = false;
160157
clearScroll();
@@ -190,24 +187,23 @@ private InputStream continueScroll() throws IOException {
190187
try {
191188
searchResponse = executeSearchScrollRequest(scrollId);
192189
} catch (SearchPhaseExecutionException searchExecutionException) {
193-
if (searchHasShardFailure == false) {
194-
LOGGER.debug("[{}] Reinitializing scroll due to SearchPhaseExecutionException", context.jobId);
195-
markScrollAsErrored();
196-
searchResponse =
197-
executeSearchRequest(buildSearchRequest(lastTimestamp == null ? context.start : lastTimestamp));
198-
} else {
190+
if (searchHasShardFailure) {
199191
throw searchExecutionException;
200192
}
193+
LOGGER.debug("[{}] search failed due to SearchPhaseExecutionException. Will attempt again with new scroll",
194+
context.jobId);
195+
markScrollAsErrored();
196+
searchResponse = executeSearchRequest(buildSearchRequest(lastTimestamp == null ? context.start : lastTimestamp));
201197
}
202198
LOGGER.debug("[{}] Search response was obtained", context.jobId);
203199
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
204200
return processSearchResponse(searchResponse);
205201
}
206202

207-
private void markScrollAsErrored() {
203+
void markScrollAsErrored() {
208204
// This could be a transient error with the scroll Id.
209205
// Reinitialise the scroll and try again but only once.
210-
clearScroll();
206+
scrollId = null;
211207
if (lastTimestamp != null) {
212208
lastTimestamp++;
213209
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/aggregation/AggregationDataExtractorTests.java

Lines changed: 12 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
package org.elasticsearch.xpack.ml.datafeed.extractor.aggregation;
77

8+
import org.elasticsearch.action.search.SearchPhaseExecutionException;
89
import org.elasticsearch.action.search.SearchRequest;
910
import org.elasticsearch.action.search.SearchRequestBuilder;
1011
import org.elasticsearch.action.search.SearchResponse;
@@ -64,6 +65,7 @@ public class AggregationDataExtractorTests extends ESTestCase {
6465
private class TestDataExtractor extends AggregationDataExtractor {
6566

6667
private SearchResponse nextResponse;
68+
private SearchPhaseExecutionException ex;
6769

6870
TestDataExtractor(long start, long end) {
6971
super(testClient, createContext(start, end), timingStatsReporter);
@@ -72,12 +74,19 @@ private class TestDataExtractor extends AggregationDataExtractor {
7274
@Override
7375
protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
7476
capturedSearchRequests.add(searchRequestBuilder);
77+
if (ex != null) {
78+
throw ex;
79+
}
7580
return nextResponse;
7681
}
7782

7883
void setNextResponse(SearchResponse searchResponse) {
7984
nextResponse = searchResponse;
8085
}
86+
87+
void setNextResponseToError(SearchPhaseExecutionException ex) {
88+
this.ex = ex;
89+
}
8190
}
8291

8392
@Before
@@ -246,29 +255,12 @@ public void testExtractionGivenCancelHalfWay() throws IOException {
246255
assertThat(capturedSearchRequests.size(), equalTo(1));
247256
}
248257

249-
public void testExtractionGivenSearchResponseHasError() throws IOException {
250-
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
251-
extractor.setNextResponse(createErrorResponse());
252-
253-
assertThat(extractor.hasNext(), is(true));
254-
expectThrows(IOException.class, extractor::next);
255-
}
256-
257-
public void testExtractionGivenSearchResponseHasShardFailures() {
258+
public void testExtractionGivenSearchResponseHasError() {
258259
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
259-
extractor.setNextResponse(createResponseWithShardFailures());
260+
extractor.setNextResponseToError(new SearchPhaseExecutionException("phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));
260261

261262
assertThat(extractor.hasNext(), is(true));
262-
expectThrows(IOException.class, extractor::next);
263-
}
264-
265-
public void testExtractionGivenInitSearchResponseEncounteredUnavailableShards() {
266-
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
267-
extractor.setNextResponse(createResponseWithUnavailableShards(2));
268-
269-
assertThat(extractor.hasNext(), is(true));
270-
IOException e = expectThrows(IOException.class, extractor::next);
271-
assertThat(e.getMessage(), equalTo("[" + jobId + "] Search request encountered [2] unavailable shards"));
263+
expectThrows(SearchPhaseExecutionException.class, extractor::next);
272264
}
273265

274266
private AggregationDataExtractorContext createContext(long start, long end) {
@@ -295,29 +287,6 @@ private SearchResponse createSearchResponse(Aggregations aggregations) {
295287
return searchResponse;
296288
}
297289

298-
private SearchResponse createErrorResponse() {
299-
SearchResponse searchResponse = mock(SearchResponse.class);
300-
when(searchResponse.status()).thenReturn(RestStatus.INTERNAL_SERVER_ERROR);
301-
return searchResponse;
302-
}
303-
304-
private SearchResponse createResponseWithShardFailures() {
305-
SearchResponse searchResponse = mock(SearchResponse.class);
306-
when(searchResponse.status()).thenReturn(RestStatus.OK);
307-
when(searchResponse.getShardFailures()).thenReturn(
308-
new ShardSearchFailure[] { new ShardSearchFailure(new RuntimeException("shard failed"))});
309-
return searchResponse;
310-
}
311-
312-
private SearchResponse createResponseWithUnavailableShards(int unavailableShards) {
313-
SearchResponse searchResponse = mock(SearchResponse.class);
314-
when(searchResponse.status()).thenReturn(RestStatus.OK);
315-
when(searchResponse.getSuccessfulShards()).thenReturn(3);
316-
when(searchResponse.getTotalShards()).thenReturn(3 + unavailableShards);
317-
when(searchResponse.getTook()).thenReturn(TimeValue.timeValueMillis(randomNonNegativeLong()));
318-
return searchResponse;
319-
}
320-
321290
private static String asString(InputStream inputStream) throws IOException {
322291
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
323292
return reader.lines().collect(Collectors.joining("\n"));

0 commit comments

Comments (0)