Skip to content

Commit 23d2ff1

Browse files
committed
Cancel search on shard failure when partial results disallowed (#63520)
If the `allow_partial_search_results` parameter (which defaults to true) is set to false, we should cancel the search request as soon as it hits a shard failure, because the caller will not consume partial results. Closes #60278
1 parent 7c7e20d commit 23d2ff1

File tree

13 files changed

+146
-51
lines changed

13 files changed

+146
-51
lines changed

server/src/internalClusterTest/java/org/elasticsearch/search/SearchCancellationIT.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@
3131
import org.elasticsearch.action.search.SearchPhaseExecutionException;
3232
import org.elasticsearch.action.search.SearchResponse;
3333
import org.elasticsearch.action.search.SearchScrollAction;
34+
import org.elasticsearch.action.search.SearchTask;
35+
import org.elasticsearch.action.search.SearchType;
3436
import org.elasticsearch.action.search.ShardSearchFailure;
3537
import org.elasticsearch.action.support.WriteRequest;
38+
import org.elasticsearch.cluster.metadata.IndexMetadata;
3639
import org.elasticsearch.common.Strings;
3740
import org.elasticsearch.common.settings.Settings;
3841
import org.elasticsearch.common.unit.TimeValue;
@@ -42,17 +45,22 @@
4245
import org.elasticsearch.script.Script;
4346
import org.elasticsearch.script.ScriptType;
4447
import org.elasticsearch.search.lookup.LeafFieldsLookup;
48+
import org.elasticsearch.tasks.Task;
4549
import org.elasticsearch.tasks.TaskCancelledException;
4650
import org.elasticsearch.tasks.TaskInfo;
4751
import org.elasticsearch.test.ESIntegTestCase;
52+
import org.elasticsearch.transport.TransportService;
4853

4954
import java.util.ArrayList;
5055
import java.util.Collection;
5156
import java.util.Collections;
5257
import java.util.List;
5358
import java.util.Map;
59+
import java.util.concurrent.CountDownLatch;
60+
import java.util.concurrent.TimeUnit;
5461
import java.util.concurrent.atomic.AtomicBoolean;
5562
import java.util.concurrent.atomic.AtomicInteger;
63+
import java.util.concurrent.atomic.AtomicReference;
5664
import java.util.function.Function;
5765

5866
import static org.elasticsearch.index.query.QueryBuilders.scriptQuery;
@@ -273,13 +281,83 @@ public void testCancelMultiSearch() throws Exception {
273281
}
274282
}
275283

284+
public void testCancelFailedSearchWhenPartialResultDisallowed() throws Exception {
285+
final List<ScriptedBlockPlugin> plugins = initBlockFactory();
286+
int numberOfShards = between(2, 5);
287+
AtomicBoolean failed = new AtomicBoolean();
288+
CountDownLatch queryLatch = new CountDownLatch(1);
289+
CountDownLatch cancelledLatch = new CountDownLatch(1);
290+
for (ScriptedBlockPlugin plugin : plugins) {
291+
plugin.disableBlock();
292+
plugin.setBeforeExecution(() -> {
293+
try {
294+
queryLatch.await(); // block the query until we get a search task
295+
} catch (InterruptedException e) {
296+
throw new AssertionError(e);
297+
}
298+
if (failed.compareAndSet(false, true)) {
299+
throw new IllegalStateException("simulated");
300+
}
301+
try {
302+
cancelledLatch.await(); // block the query until the search is cancelled
303+
} catch (InterruptedException e) {
304+
throw new AssertionError(e);
305+
}
306+
});
307+
}
308+
createIndex("test", Settings.builder()
309+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards)
310+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
311+
.build());
312+
indexTestData();
313+
Thread searchThread = new Thread(() -> {
314+
expectThrows(Exception.class, () -> {
315+
client().prepareSearch("test")
316+
.setSearchType(SearchType.QUERY_THEN_FETCH)
317+
.setQuery(scriptQuery(new Script(ScriptType.INLINE, "mockscript", SCRIPT_NAME, Collections.emptyMap())))
318+
.setAllowPartialSearchResults(false).setSize(1000).get();
319+
});
320+
});
321+
searchThread.start();
322+
try {
323+
assertBusy(() -> assertThat(getSearchTasks(), hasSize(1)));
324+
queryLatch.countDown();
325+
assertBusy(() -> {
326+
final List<SearchTask> searchTasks = getSearchTasks();
327+
assertThat(searchTasks, hasSize(1));
328+
assertTrue(searchTasks.get(0).isCancelled());
329+
}, 30, TimeUnit.SECONDS);
330+
} finally {
331+
for (ScriptedBlockPlugin plugin : plugins) {
332+
plugin.setBeforeExecution(() -> {});
333+
}
334+
cancelledLatch.countDown();
335+
searchThread.join();
336+
}
337+
}
338+
339+
List<SearchTask> getSearchTasks() {
340+
List<SearchTask> tasks = new ArrayList<>();
341+
for (String nodeName : internalCluster().getNodeNames()) {
342+
TransportService transportService = internalCluster().getInstance(TransportService.class, nodeName);
343+
for (Task task : transportService.getTaskManager().getCancellableTasks().values()) {
344+
if (task.getAction().equals(SearchAction.NAME)) {
345+
tasks.add((SearchTask) task);
346+
}
347+
}
348+
}
349+
return tasks;
350+
}
351+
276352
public static class ScriptedBlockPlugin extends MockScriptPlugin {
277353
static final String SCRIPT_NAME = "search_block";
278354

279355
private final AtomicInteger hits = new AtomicInteger();
280356

281357
private final AtomicBoolean shouldBlock = new AtomicBoolean(true);
282358

359+
private final AtomicReference<Runnable> beforeExecution = new AtomicReference<>();
360+
283361
public void reset() {
284362
hits.set(0);
285363
}
@@ -292,9 +370,17 @@ public void enableBlock() {
292370
shouldBlock.set(true);
293371
}
294372

373+
public void setBeforeExecution(Runnable runnable) {
374+
beforeExecution.set(runnable);
375+
}
376+
295377
@Override
296378
public Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
297379
return Collections.singletonMap(SCRIPT_NAME, params -> {
380+
final Runnable runnable = beforeExecution.get();
381+
if (runnable != null) {
382+
runnable.run();
383+
}
298384
LeafFieldsLookup fieldsLookup = (LeafFieldsLookup) params.get("_fields");
299385
LogManager.getLogger(SearchCancellationIT.class).info("Blocking on the document {}", fieldsLookup.get("_id"));
300386
hits.incrementAndGet();

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
9999
private final int maxConcurrentRequestsPerNode;
100100
private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
101101
private final boolean throttleConcurrentRequests;
102+
private final AtomicBoolean requestCancelled = new AtomicBoolean();
102103

103104
private final List<Releasable> releasables = new ArrayList<>();
104105

@@ -393,6 +394,15 @@ private void onShardFailure(final int shardIndex, @Nullable SearchShardTarget sh
393394
logger.debug(() -> new ParameterizedMessage("{}: Failed to execute [{}] lastShard [{}]",
394395
shard != null ? shard : shardIt.shardId(), request, lastShard), e);
395396
if (lastShard) {
397+
if (request.allowPartialSearchResults() == false) {
398+
if (requestCancelled.compareAndSet(false, true)) {
399+
try {
400+
searchTransportService.cancelSearchTask(task, "partial results are not allowed and at least one shard has failed");
401+
} catch (Exception cancelFailure) {
402+
logger.debug("Failed to cancel search request", cancelFailure);
403+
}
404+
}
405+
}
396406
onShardGroupFailure(shardIndex, shard, e);
397407
}
398408
final int totalOps = this.totalOps.incrementAndGet();

server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@
2424
import org.elasticsearch.action.ActionListenerResponseHandler;
2525
import org.elasticsearch.action.IndicesRequest;
2626
import org.elasticsearch.action.OriginalIndices;
27+
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
28+
import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction;
2729
import org.elasticsearch.action.support.ChannelActionListener;
2830
import org.elasticsearch.action.support.IndicesOptions;
31+
import org.elasticsearch.client.OriginSettingClient;
32+
import org.elasticsearch.client.node.NodeClient;
2933
import org.elasticsearch.cluster.node.DiscoveryNode;
3034
import org.elasticsearch.common.Nullable;
3135
import org.elasticsearch.common.io.stream.StreamInput;
@@ -46,6 +50,7 @@
4650
import org.elasticsearch.search.query.QuerySearchRequest;
4751
import org.elasticsearch.search.query.QuerySearchResult;
4852
import org.elasticsearch.search.query.ScrollQuerySearchResult;
53+
import org.elasticsearch.tasks.TaskId;
4954
import org.elasticsearch.threadpool.ThreadPool;
5055
import org.elasticsearch.transport.RemoteClusterService;
5156
import org.elasticsearch.transport.Transport;
@@ -81,12 +86,14 @@ public class SearchTransportService {
8186
public static final String QUERY_CAN_MATCH_NAME = "indices:data/read/search[can_match]";
8287

8388
private final TransportService transportService;
89+
private final NodeClient client;
8490
private final BiFunction<Transport.Connection, SearchActionListener, ActionListener> responseWrapper;
8591
private final Map<String, Long> clientConnections = ConcurrentCollections.newConcurrentMapWithAggressiveConcurrency();
8692

87-
public SearchTransportService(TransportService transportService,
93+
public SearchTransportService(TransportService transportService, NodeClient client,
8894
BiFunction<Transport.Connection, SearchActionListener, ActionListener> responseWrapper) {
8995
this.transportService = transportService;
96+
this.client = client;
9097
this.responseWrapper = responseWrapper;
9198
}
9299

@@ -423,4 +430,12 @@ private boolean assertNodePresent() {
423430
return true;
424431
}
425432
}
433+
434+
public void cancelSearchTask(SearchTask task, String reason) {
435+
CancelTasksRequest req = new CancelTasksRequest()
436+
.setTaskId(new TaskId(client.getLocalNodeId(), task.getId()))
437+
.setReason("Fatal failure during search: " + reason);
438+
// force the origin to execute the cancellation as a system user
439+
new OriginSettingClient(client, GetTaskAction.TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {}));
440+
}
426441
}

server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,13 @@
2121

2222
import org.elasticsearch.action.ActionListener;
2323
import org.elasticsearch.action.OriginalIndices;
24-
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
2524
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
2625
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
2726
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
2827
import org.elasticsearch.action.support.ActionFilters;
2928
import org.elasticsearch.action.support.HandledTransportAction;
3029
import org.elasticsearch.action.support.IndicesOptions;
3130
import org.elasticsearch.client.Client;
32-
import org.elasticsearch.client.OriginSettingClient;
33-
import org.elasticsearch.client.node.NodeClient;
3431
import org.elasticsearch.cluster.ClusterState;
3532
import org.elasticsearch.cluster.block.ClusterBlockException;
3633
import org.elasticsearch.cluster.block.ClusterBlockLevel;
@@ -69,7 +66,6 @@
6966
import org.elasticsearch.search.profile.ProfileShardResult;
7067
import org.elasticsearch.search.profile.SearchProfileShardResults;
7168
import org.elasticsearch.tasks.Task;
72-
import org.elasticsearch.tasks.TaskId;
7369
import org.elasticsearch.threadpool.ThreadPool;
7470
import org.elasticsearch.transport.RemoteClusterAware;
7571
import org.elasticsearch.transport.RemoteClusterService;
@@ -97,7 +93,6 @@
9793
import java.util.stream.Collectors;
9894
import java.util.stream.StreamSupport;
9995

100-
import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
10196
import static org.elasticsearch.action.search.SearchType.DFS_QUERY_THEN_FETCH;
10297
import static org.elasticsearch.action.search.SearchType.QUERY_THEN_FETCH;
10398
import static org.elasticsearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort;
@@ -108,7 +103,6 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
108103
public static final Setting<Long> SHARD_COUNT_LIMIT_SETTING = Setting.longSetting(
109104
"action.search.shard_count.limit", Long.MAX_VALUE, 1L, Property.Dynamic, Property.NodeScope);
110105

111-
private final NodeClient client;
112106
private final ThreadPool threadPool;
113107
private final ClusterService clusterService;
114108
private final SearchTransportService searchTransportService;
@@ -120,8 +114,7 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
120114
private final CircuitBreaker circuitBreaker;
121115

122116
@Inject
123-
public TransportSearchAction(NodeClient client,
124-
ThreadPool threadPool,
117+
public TransportSearchAction(ThreadPool threadPool,
125118
CircuitBreakerService circuitBreakerService,
126119
TransportService transportService,
127120
SearchService searchService,
@@ -132,7 +125,6 @@ public TransportSearchAction(NodeClient client,
132125
IndexNameExpressionResolver indexNameExpressionResolver,
133126
NamedWriteableRegistry namedWriteableRegistry) {
134127
super(SearchAction.NAME, transportService, actionFilters, (Writeable.Reader<SearchRequest>) SearchRequest::new);
135-
this.client = client;
136128
this.threadPool = threadPool;
137129
this.circuitBreaker = circuitBreakerService.getBreaker(CircuitBreaker.REQUEST);
138130
this.searchPhaseController = searchPhaseController;
@@ -802,7 +794,8 @@ public void run() {
802794
}, clusters);
803795
} else {
804796
final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults(executor,
805-
circuitBreaker, task.getProgressListener(), searchRequest, shardIterators.size(), exc -> cancelTask(task, exc));
797+
circuitBreaker, task.getProgressListener(), searchRequest, shardIterators.size(),
798+
exc -> searchTransportService.cancelSearchTask(task, "failed to merge result [" + exc.getMessage() + "]"));
806799
AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction;
807800
switch (searchRequest.searchType()) {
808801
case DFS_QUERY_THEN_FETCH:
@@ -822,15 +815,6 @@ public void run() {
822815
}
823816
}
824817

825-
private void cancelTask(SearchTask task, Exception exc) {
826-
String errorMsg = exc.getMessage() != null ? exc.getMessage() : "";
827-
CancelTasksRequest req = new CancelTasksRequest()
828-
.setTaskId(new TaskId(client.getLocalNodeId(), task.getId()))
829-
.setReason("Fatal failure during search: " + errorMsg);
830-
// force the origin to execute the cancellation as a system user
831-
new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {}));
832-
}
833-
834818
private static void failIfOverShardCountLimit(ClusterService clusterService, int shardCount) {
835819
final long shardCountLimit = clusterService.getClusterSettings().get(SHARD_COUNT_LIMIT_SETTING);
836820
if (shardCount > shardCountLimit) {

server/src/main/java/org/elasticsearch/node/Node.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ protected Node(final Environment initialEnvironment,
585585
networkModule.getTransportInterceptor(), localNodeFactory, settingsModule.getClusterSettings(), taskHeaders);
586586
final GatewayMetaState gatewayMetaState = new GatewayMetaState();
587587
final ResponseCollectorService responseCollectorService = new ResponseCollectorService(clusterService);
588-
final SearchTransportService searchTransportService = new SearchTransportService(transportService,
588+
final SearchTransportService searchTransportService = new SearchTransportService(transportService, client,
589589
SearchExecutionStatsCollector.makeWrapper(responseCollectorService));
590590
final HttpServerTransport httpServerTransport = newHttpTransport(networkModule);
591591
final IndexingPressure indexingLimits = new IndexingPressure(settings);

server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public void testFilterShards() throws InterruptedException {
7272
final boolean shard1 = randomBoolean();
7373
final boolean shard2 = randomBoolean();
7474

75-
SearchTransportService searchTransportService = new SearchTransportService(null, null) {
75+
SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
7676
@Override
7777
public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
7878
ActionListener<SearchService.CanMatchResponse> listener) {
@@ -129,7 +129,7 @@ public void testFilterWithFailure() throws InterruptedException {
129129
lookup.put("node1", new SearchAsyncActionTests.MockConnection(primaryNode));
130130
lookup.put("node2", new SearchAsyncActionTests.MockConnection(replicaNode));
131131
final boolean shard1 = randomBoolean();
132-
SearchTransportService searchTransportService = new SearchTransportService(null, null) {
132+
SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
133133
@Override
134134
public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
135135
ActionListener<SearchService.CanMatchResponse> listener) {
@@ -195,7 +195,7 @@ public void testLotsOfShards() throws InterruptedException {
195195

196196

197197
final SearchTransportService searchTransportService =
198-
new SearchTransportService(null, null) {
198+
new SearchTransportService(null, null, null) {
199199
@Override
200200
public void sendCanMatch(
201201
Transport.Connection connection,
@@ -213,7 +213,7 @@ public void sendCanMatch(
213213
final ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors()));
214214
final SearchRequest searchRequest = new SearchRequest();
215215
searchRequest.allowPartialSearchResults(true);
216-
SearchTransportService transportService = new SearchTransportService(null, null);
216+
SearchTransportService transportService = new SearchTransportService(null, null, null);
217217
ActionListener<SearchResponse> responseListener = ActionListener.wrap(response -> {},
218218
(e) -> { throw new AssertionError("unexpected", e);});
219219
Map<String, AliasFilter> aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY));
@@ -296,7 +296,7 @@ public void testSortShards() throws InterruptedException {
296296
List<MinAndMax<?>> minAndMaxes = new ArrayList<>();
297297
Set<ShardId> shardToSkip = new HashSet<>();
298298

299-
SearchTransportService searchTransportService = new SearchTransportService(null, null) {
299+
SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
300300
@Override
301301
public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
302302
ActionListener<SearchService.CanMatchResponse> listener) {
@@ -369,7 +369,7 @@ public void testInvalidSortShards() throws InterruptedException {
369369
List<ShardId> shardIds = new ArrayList<>();
370370
Set<ShardId> shardToSkip = new HashSet<>();
371371

372-
SearchTransportService searchTransportService = new SearchTransportService(null, null) {
372+
SearchTransportService searchTransportService = new SearchTransportService(null, null, null) {
373373
@Override
374374
public void sendCanMatch(Transport.Connection connection, ShardSearchRequest request, SearchTask task,
375375
ActionListener<SearchService.CanMatchResponse> listener) {

0 commit comments

Comments
 (0)