From fc203bfb4ee7987a649a04fa0de7eddf71ca79f0 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 3 Dec 2020 14:34:30 -0500 Subject: [PATCH 1/4] Cancel proxy requests when the proxy channel closes --- .../search/ccs/CrossClusterSearchIT.java | 14 ++++++- .../action/search/SearchTransportService.java | 22 +++++----- .../transport/TransportActionProxy.java | 40 +++++++++++++++++-- .../transport/TransportActionProxyTests.java | 20 +++++++--- .../ClearCcrRestoreSessionAction.java | 2 +- .../GetCcrRestoreFileChunkAction.java | 2 +- .../TransportOpenPointInTimeAction.java | 1 + 7 files changed, 77 insertions(+), 24 deletions(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index 0def1c8f6676f..b60d45a5be5be 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -35,10 +35,12 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.test.AbstractMultiClustersTestCase; import org.elasticsearch.test.InternalTestCluster; import org.elasticsearch.test.NodeRoles; import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; +import org.elasticsearch.transport.TransportService; import org.junit.Before; import java.util.Collection; @@ -122,11 +124,21 @@ public void testProxyConnectionDisconnect() throws Exception { PlainActionFuture future = new PlainActionFuture<>(); SearchRequest searchRequest = new SearchRequest("demo", "cluster_a:prod"); searchRequest.allowPartialSearchResults(false); - searchRequest.setCcsMinimizeRoundtrips(false); + searchRequest.setCcsMinimizeRoundtrips(false);CrossClusterSearchIT.java:24:8 searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1000)); client(LOCAL_CLUSTER).search(searchRequest, future); SearchListenerPlugin.waitSearchStarted(); disconnectFromRemoteClusters(); + // Cancellable tasks on the remote cluster should be cancelled + assertBusy(() -> { + final Iterable transportServices = cluster("cluster_a").getInstances(TransportService.class); + for (TransportService transportService : transportServices) { + Collection cancellableTasks = transportService.getTaskManager().getCancellableTasks().values(); + for (CancellableTask cancellableTask : cancellableTasks) { + assertTrue(cancellableTask.getDescription(), cancellableTask.isCancelled()); + } + } + }); assertBusy(() -> assertTrue(future.isDone())); configureAndConnectsToRemoteClusters(); } finally { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java index fb5980dda03b7..b069ad5d54664 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java @@ -297,20 +297,20 @@ public static void registerRequestHandler(TransportService transportService, Sea boolean freed = searchService.freeReaderContext(request.id()); channel.sendResponse(new SearchFreeContextResponse(freed)); }); - TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, SearchFreeContextResponse::new); + TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, false, SearchFreeContextResponse::new); transportService.registerRequestHandler(FREE_CONTEXT_ACTION_NAME, ThreadPool.Names.SAME, SearchFreeContextRequest::new, (request, channel, task) -> { boolean freed = searchService.freeReaderContext(request.id()); channel.sendResponse(new SearchFreeContextResponse(freed)); }); - TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, SearchFreeContextResponse::new); + TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, false, SearchFreeContextResponse::new); transportService.registerRequestHandler(CLEAR_SCROLL_CONTEXTS_ACTION_NAME, ThreadPool.Names.SAME, TransportRequest.Empty::new, (request, channel, task) -> { searchService.freeAllScrollContexts(); channel.sendResponse(TransportResponse.Empty.INSTANCE); }); - TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME, + TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME, false, (in) -> TransportResponse.Empty.INSTANCE); transportService.registerRequestHandler(DFS_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, @@ -319,14 +319,14 @@ public static void registerRequestHandler(TransportService transportService, Sea new ChannelActionListener<>(channel, DFS_ACTION_NAME, request)) ); - TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, DfsSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new); transportService.registerRequestHandler(QUERY_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, (request, channel, task) -> { searchService.executeQueryPhase(request, keepStatesInContext(channel.getVersion()), (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME, + TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME, true, (request) -> ((ShardSearchRequest)request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new); transportService.registerRequestHandler(QUERY_ID_ACTION_NAME, ThreadPool.Names.SAME, QuerySearchRequest::new, @@ -334,42 +334,42 @@ public static void registerRequestHandler(TransportService transportService, Sea searchService.executeQueryPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_ID_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, QuerySearchResult::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, true, QuerySearchResult::new); transportService.registerRequestHandler(QUERY_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new, (request, channel, task) -> { searchService.executeQueryPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_SCROLL_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, ScrollQuerySearchResult::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, true, ScrollQuerySearchResult::new); transportService.registerRequestHandler(QUERY_FETCH_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new, (request, channel, task) -> { searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, QUERY_FETCH_SCROLL_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, ScrollQueryFetchSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new); transportService.registerRequestHandler(FETCH_ID_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, ShardFetchRequest::new, (request, channel, task) -> { searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, FETCH_ID_SCROLL_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, FetchSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, true, FetchSearchResult::new); transportService.registerRequestHandler(FETCH_ID_ACTION_NAME, ThreadPool.Names.SAME, true, true, ShardFetchSearchRequest::new, (request, channel, task) -> { searchService.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel, FETCH_ID_ACTION_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, FetchSearchResult::new); + TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, true, FetchSearchResult::new); // this is cheap, it does not fetch during the rewrite phase, so we can let it quickly execute on a networking thread transportService.registerRequestHandler(QUERY_CAN_MATCH_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new, (request, channel, task) -> { searchService.canMatch(request, new ChannelActionListener<>(channel, QUERY_CAN_MATCH_NAME, request)); }); - TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, SearchService.CanMatchResponse::new); + TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, true, SearchService.CanMatchResponse::new); } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java index beb5cfb8084d1..f5f6376bdb7cc 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java @@ -22,11 +22,14 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.Map; import java.util.function.Function; /** @@ -55,6 +58,8 @@ private static class ProxyRequestHandler implements Tran public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { DiscoveryNode targetNode = request.targetNode; TransportRequest wrappedRequest = request.wrapped; + TaskId taskId = task.taskInfo(service.localNode.getId(), false).getTaskId(); + wrappedRequest.setParentTask(taskId); service.sendRequest(targetNode, action, wrappedRequest, new ProxyResponseHandler<>(channel, responseFunction.apply(wrappedRequest))); } @@ -117,27 +122,54 @@ public void writeTo(StreamOutput out) throws IOException { } } + private static class CancellableProxyRequest extends ProxyRequest { + CancellableProxyRequest(StreamInput in, Writeable.Reader reader) throws IOException { + super(in, reader); + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new CancellableTask(id, type, action, "", parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return true; + } + + @Override + public String getDescription() { + return "proxy task [" + wrapped.getDescription() + "]"; + } + }; + } + } + /** * Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the * response type changes based on the upcoming request (quite rare) */ - public static void registerProxyActionWithDynamicResponseType(TransportService service, String action, + public static void registerProxyActionWithDynamicResponseType(TransportService service, String action, boolean cancellable, Function> responseFunction) { RequestHandlerRegistry requestHandler = service.getRequestHandler(action); service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false, - in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, responseFunction)); + in -> cancellable ? + new CancellableProxyRequest<>(in, requestHandler::newRequest) : + new ProxyRequest<>(in, requestHandler::newRequest), + new ProxyRequestHandler<>(service, action, responseFunction)); } /** * Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the * response type is always the same (most of the cases). */ - public static void registerProxyAction(TransportService service, String action, + public static void registerProxyAction(TransportService service, String action, boolean cancellable, Writeable.Reader reader) { RequestHandlerRegistry requestHandler = service.getRequestHandler(action); service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false, - in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, request -> reader)); + in -> cancellable ? + new CancellableProxyRequest<>(in, requestHandler::newRequest) : + new ProxyRequest<>(in, requestHandler::newRequest), + new ProxyRequestHandler<>(service, action, request -> reader)); } private static final String PROXY_ACTION_PREFIX = "internal:transport/proxy/"; diff --git a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java index 0e71b8427b712..6fe797d931b24 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.internal.io.IOUtils; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; @@ -35,6 +36,8 @@ import java.io.IOException; import java.util.concurrent.CountDownLatch; +import static org.hamcrest.Matchers.equalTo; + public class TransportActionProxyTests extends ESTestCase { protected ThreadPool threadPool; // we use always a non-alpha or beta version here otherwise minimumCompatibilityVersion will be different for the two used versions @@ -89,24 +92,29 @@ public void testSendMessage() throws InterruptedException { SimpleTestResponse response = new SimpleTestResponse("TS_A"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceA, "internal:test", SimpleTestResponse::new); + final boolean cancellable = randomBoolean(); + TransportActionProxy.registerProxyAction(serviceA, "internal:test", cancellable, SimpleTestResponse::new); AbstractSimpleTransportTestCase.connectToNode(serviceA, nodeB); serviceB.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { + assertThat(task instanceof CancellableTask, equalTo(cancellable)); assertEquals(request.sourceNode, "TS_A"); SimpleTestResponse response = new SimpleTestResponse("TS_B"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceB, "internal:test", SimpleTestResponse::new); + final boolean cancellableB = randomBoolean(); + TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellableB, SimpleTestResponse::new); AbstractSimpleTransportTestCase.connectToNode(serviceB, nodeC); serviceC.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { + assertThat(task instanceof CancellableTask, equalTo(cancellable)); assertEquals(request.sourceNode, "TS_A"); SimpleTestResponse response = new SimpleTestResponse("TS_C"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceC, "internal:test", SimpleTestResponse::new); + + TransportActionProxy.registerProxyAction(serviceC, "internal:test", randomBoolean(), SimpleTestResponse::new); CountDownLatch latch = new CountDownLatch(1); serviceA.sendRequest(nodeB, TransportActionProxy.getProxyAction("internal:test"), TransportActionProxy.wrapRequest(nodeC, @@ -144,7 +152,7 @@ public void testException() throws InterruptedException { SimpleTestResponse response = new SimpleTestResponse("TS_A"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceA, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceA, "internal:test", randomBoolean(), SimpleTestResponse::new); AbstractSimpleTransportTestCase.connectToNode(serviceA, nodeB); serviceB.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, @@ -153,13 +161,13 @@ public void testException() throws InterruptedException { SimpleTestResponse response = new SimpleTestResponse("TS_B"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceB, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceB, "internal:test", randomBoolean(), SimpleTestResponse::new); AbstractSimpleTransportTestCase.connectToNode(serviceB, nodeC); serviceC.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { throw new ElasticsearchException("greetings from TS_C"); }); - TransportActionProxy.registerProxyAction(serviceC, "internal:test", SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceC, "internal:test", randomBoolean(), SimpleTestResponse::new); CountDownLatch latch = new CountDownLatch(1); serviceA.sendRequest(nodeB, TransportActionProxy.getProxyAction("internal:test"), TransportActionProxy.wrapRequest(nodeC, diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java index e0ed85883df93..7484fc92e81a6 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/ClearCcrRestoreSessionAction.java @@ -36,7 +36,7 @@ public static class TransportDeleteCcrRestoreSessionAction public TransportDeleteCcrRestoreSessionAction(ActionFilters actionFilters, TransportService transportService, CcrRestoreSourceService ccrRestoreService) { super(NAME, transportService, actionFilters, ClearCcrRestoreSessionRequest::new, ThreadPool.Names.GENERIC); - TransportActionProxy.registerProxyAction(transportService, NAME, in -> ActionResponse.Empty.INSTANCE); + TransportActionProxy.registerProxyAction(transportService, NAME, false, in -> ActionResponse.Empty.INSTANCE); this.ccrRestoreService = ccrRestoreService; } diff --git a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java index b858531d7614f..96f795cfc5f95 100644 --- a/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java +++ b/x-pack/plugin/ccr/src/main/java/org/elasticsearch/xpack/ccr/action/repositories/GetCcrRestoreFileChunkAction.java @@ -45,7 +45,7 @@ public static class TransportGetCcrRestoreFileChunkAction public TransportGetCcrRestoreFileChunkAction(BigArrays bigArrays, TransportService transportService, ActionFilters actionFilters, CcrRestoreSourceService restoreSourceService) { super(NAME, transportService, actionFilters, GetCcrRestoreFileChunkRequest::new, ThreadPool.Names.GENERIC); - TransportActionProxy.registerProxyAction(transportService, NAME, GetCcrRestoreFileChunkResponse::new); + TransportActionProxy.registerProxyAction(transportService, NAME, false, GetCcrRestoreFileChunkResponse::new); this.restoreSourceService = restoreSourceService; this.bigArrays = bigArrays; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java index 75ae1c01a5b5b..d65c1f3b8d256 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/search/action/TransportOpenPointInTimeAction.java @@ -61,6 +61,7 @@ public TransportOpenPointInTimeAction( TransportActionProxy.registerProxyAction( transportService, OPEN_SHARD_READER_CONTEXT_NAME, + false, TransportOpenPointInTimeAction.ShardOpenReaderResponse::new ); } From 5535bc301fc088a256d7f22bf0bc1ff9ea19df18 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Thu, 3 Dec 2020 16:06:10 -0500 Subject: [PATCH 2/4] oops --- .../java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index b60d45a5be5be..e381bb936b221 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -124,7 +124,7 @@ public void testProxyConnectionDisconnect() throws Exception { PlainActionFuture future = new PlainActionFuture<>(); SearchRequest searchRequest = new SearchRequest("demo", "cluster_a:prod"); searchRequest.allowPartialSearchResults(false); - searchRequest.setCcsMinimizeRoundtrips(false);CrossClusterSearchIT.java:24:8 + searchRequest.setCcsMinimizeRoundtrips(false); searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1000)); client(LOCAL_CLUSTER).search(searchRequest, future); SearchListenerPlugin.waitSearchStarted(); From 0a52a5b9f695bdee2648cb197b3aa439586a6981 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Sat, 5 Dec 2020 15:32:33 -0500 Subject: [PATCH 3/4] add assertion --- .../elasticsearch/transport/TransportActionProxy.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java index f5f6376bdb7cc..20b291b038d1a 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java @@ -58,11 +58,20 @@ private static class ProxyRequestHandler implements Tran public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { DiscoveryNode targetNode = request.targetNode; TransportRequest wrappedRequest = request.wrapped; + assert assertConsistentTaskType(task, wrappedRequest); TaskId taskId = task.taskInfo(service.localNode.getId(), false).getTaskId(); wrappedRequest.setParentTask(taskId); service.sendRequest(targetNode, action, wrappedRequest, new ProxyResponseHandler<>(channel, responseFunction.apply(wrappedRequest))); } + + private boolean assertConsistentTaskType(Task proxyTask, TransportRequest wrapped) { + final Task targetTask = wrapped.createTask(0, proxyTask.getType(), proxyTask.getAction(), TaskId.EMPTY_TASK_ID, Map.of()); + assert targetTask instanceof CancellableTask == proxyTask instanceof CancellableTask : + "Cancellable property of proxy action [" + proxyTask.getAction() + "] is configured inconsistently: " + + "expected [" + (targetTask instanceof CancellableTask) + "] actual [" + (proxyTask instanceof CancellableTask) + "]"; + return true; + } } private static class ProxyResponseHandler implements TransportResponseHandler { From 94490b4dc31eb8dcd518d3f862f2390c82602e4c Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Sat, 5 Dec 2020 17:40:33 -0500 Subject: [PATCH 4/4] fix test --- .../transport/TransportActionProxyTests.java | 42 ++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java index 6fe797d931b24..99089d60fa0ba 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportActionProxyTests.java @@ -27,6 +27,8 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; @@ -34,6 +36,7 @@ import org.junit.Before; import java.io.IOException; +import java.util.Map; import java.util.concurrent.CountDownLatch; import static org.hamcrest.Matchers.equalTo; @@ -103,8 +106,7 @@ public void testSendMessage() throws InterruptedException { SimpleTestResponse response = new SimpleTestResponse("TS_B"); channel.sendResponse(response); }); - final boolean cancellableB = randomBoolean(); - TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellableB, SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellable, SimpleTestResponse::new); AbstractSimpleTransportTestCase.connectToNode(serviceB, nodeC); serviceC.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { @@ -114,11 +116,11 @@ public void testSendMessage() throws InterruptedException { channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceC, "internal:test", randomBoolean(), SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceC, "internal:test", cancellable, SimpleTestResponse::new); CountDownLatch latch = new CountDownLatch(1); serviceA.sendRequest(nodeB, TransportActionProxy.getProxyAction("internal:test"), TransportActionProxy.wrapRequest(nodeC, - new SimpleTestRequest("TS_A")), new TransportResponseHandler() { + new SimpleTestRequest("TS_A", cancellable)), new TransportResponseHandler() { @Override public SimpleTestResponse read(StreamInput in) throws IOException { return new SimpleTestResponse(in); @@ -146,13 +148,14 @@ public void handleException(TransportException exp) { } public void testException() throws InterruptedException { + boolean cancellable = randomBoolean(); serviceA.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { assertEquals(request.sourceNode, "TS_A"); SimpleTestResponse response = new SimpleTestResponse("TS_A"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceA, "internal:test", randomBoolean(), SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceA, "internal:test", cancellable, SimpleTestResponse::new); AbstractSimpleTransportTestCase.connectToNode(serviceA, nodeB); serviceB.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, @@ -161,17 +164,17 @@ public void testException() throws InterruptedException { SimpleTestResponse response = new SimpleTestResponse("TS_B"); channel.sendResponse(response); }); - TransportActionProxy.registerProxyAction(serviceB, "internal:test", randomBoolean(), SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceB, "internal:test", cancellable, SimpleTestResponse::new); AbstractSimpleTransportTestCase.connectToNode(serviceB, nodeC); serviceC.registerRequestHandler("internal:test", ThreadPool.Names.SAME, SimpleTestRequest::new, (request, channel, task) -> { throw new ElasticsearchException("greetings from TS_C"); }); - TransportActionProxy.registerProxyAction(serviceC, "internal:test", randomBoolean(), SimpleTestResponse::new); + TransportActionProxy.registerProxyAction(serviceC, "internal:test", cancellable, SimpleTestResponse::new); CountDownLatch latch = new CountDownLatch(1); serviceA.sendRequest(nodeB, TransportActionProxy.getProxyAction("internal:test"), TransportActionProxy.wrapRequest(nodeC, - new SimpleTestRequest("TS_A")), new TransportResponseHandler() { + new SimpleTestRequest("TS_A", cancellable)), new TransportResponseHandler() { @Override public SimpleTestResponse read(StreamInput in) throws IOException { return new SimpleTestResponse(in); @@ -200,22 +203,39 @@ public void handleException(TransportException exp) { } public static class SimpleTestRequest extends TransportRequest { - String sourceNode; + final boolean cancellable; + final String sourceNode; - public SimpleTestRequest(String sourceNode) { + public SimpleTestRequest(String sourceNode, boolean cancellable) { this.sourceNode = sourceNode; + this.cancellable = cancellable; } - public SimpleTestRequest() {} public SimpleTestRequest(StreamInput in) throws IOException { super(in); sourceNode = in.readString(); + cancellable = in.readBoolean(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(sourceNode); + out.writeBoolean(cancellable); + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + if (cancellable) { + return new CancellableTask(id, type, action, "", parentTaskId, headers) { + @Override + public boolean shouldCancelChildrenOnCancellation() { + return randomBoolean(); + } + }; + } else { + return super.createTask(id, type, action, parentTaskId, headers); + } } }