Skip to content

Commit

Permalink
Cancel proxy requests when the proxy channel closes (#65850)
Browse files Browse the repository at this point in the history
Since #43332 and #56327 we cancel rest requests when the rest channel 
closes and transport requests when the transport channel closes. This
commit cancels proxy requests and its descendant requests when the proxy
channel closes. This change is also required to support cross-clusters
task cancellation.

Relates #43332
Relates #56327
  • Loading branch information
dnhatn authored Dec 8, 2020
1 parent 12c14ce commit 95936eb
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.test.AbstractMultiClustersTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.test.NodeRoles;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.transport.TransportService;
import org.junit.Before;

import java.util.Collection;
Expand Down Expand Up @@ -127,6 +129,16 @@ public void testProxyConnectionDisconnect() throws Exception {
client(LOCAL_CLUSTER).search(searchRequest, future);
SearchListenerPlugin.waitSearchStarted();
disconnectFromRemoteClusters();
// Cancellable tasks on the remote cluster should be cancelled
assertBusy(() -> {
final Iterable<TransportService> transportServices = cluster("cluster_a").getInstances(TransportService.class);
for (TransportService transportService : transportServices) {
Collection<CancellableTask> cancellableTasks = transportService.getTaskManager().getCancellableTasks().values();
for (CancellableTask cancellableTask : cancellableTasks) {
assertTrue(cancellableTask.getDescription(), cancellableTask.isCancelled());
}
}
});
assertBusy(() -> assertTrue(future.isDone()));
configureAndConnectsToRemoteClusters();
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,20 +297,20 @@ public static void registerRequestHandler(TransportService transportService, Sea
boolean freed = searchService.freeReaderContext(request.id());
channel.sendResponse(new SearchFreeContextResponse(freed));
});
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, SearchFreeContextResponse::new);
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, false, SearchFreeContextResponse::new);
transportService.registerRequestHandler(FREE_CONTEXT_ACTION_NAME, ThreadPool.Names.SAME, SearchFreeContextRequest::new,
(request, channel, task) -> {
boolean freed = searchService.freeReaderContext(request.id());
channel.sendResponse(new SearchFreeContextResponse(freed));
});
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, SearchFreeContextResponse::new);
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, false, SearchFreeContextResponse::new);
transportService.registerRequestHandler(CLEAR_SCROLL_CONTEXTS_ACTION_NAME, ThreadPool.Names.SAME,
TransportRequest.Empty::new,
(request, channel, task) -> {
searchService.freeAllScrollContexts();
channel.sendResponse(TransportResponse.Empty.INSTANCE);
});
TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME,
TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME, false,
(in) -> TransportResponse.Empty.INSTANCE);

transportService.registerRequestHandler(DFS_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new,
Expand All @@ -319,57 +319,57 @@ public static void registerRequestHandler(TransportService transportService, Sea
new ChannelActionListener<>(channel, DFS_ACTION_NAME, request))
);

TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, DfsSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new);

transportService.registerRequestHandler(QUERY_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new,
(request, channel, task) -> {
searchService.executeQueryPhase(request, keepStatesInContext(channel.getVersion()), (SearchShardTask) task,
new ChannelActionListener<>(channel, QUERY_ACTION_NAME, request));
});
TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME,
TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME, true,
(request) -> ((ShardSearchRequest)request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new);

transportService.registerRequestHandler(QUERY_ID_ACTION_NAME, ThreadPool.Names.SAME, QuerySearchRequest::new,
(request, channel, task) -> {
searchService.executeQueryPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, QUERY_ID_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, QuerySearchResult::new);
TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, true, QuerySearchResult::new);

transportService.registerRequestHandler(QUERY_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new,
(request, channel, task) -> {
searchService.executeQueryPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, QUERY_SCROLL_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, ScrollQuerySearchResult::new);
TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, true, ScrollQuerySearchResult::new);

transportService.registerRequestHandler(QUERY_FETCH_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new,
(request, channel, task) -> {
searchService.executeFetchPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, QUERY_FETCH_SCROLL_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, ScrollQueryFetchSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new);

transportService.registerRequestHandler(FETCH_ID_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, ShardFetchRequest::new,
(request, channel, task) -> {
searchService.executeFetchPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, FETCH_ID_SCROLL_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, FetchSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, true, FetchSearchResult::new);

transportService.registerRequestHandler(FETCH_ID_ACTION_NAME, ThreadPool.Names.SAME, true, true, ShardFetchSearchRequest::new,
(request, channel, task) -> {
searchService.executeFetchPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, FETCH_ID_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, FetchSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, true, FetchSearchResult::new);

// this is cheap, it does not fetch during the rewrite phase, so we can let it quickly execute on a networking thread
transportService.registerRequestHandler(QUERY_CAN_MATCH_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new,
(request, channel, task) -> {
searchService.canMatch(request, new ChannelActionListener<>(channel, QUERY_CAN_MATCH_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, SearchService.CanMatchResponse::new);
TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, true, SearchService.CanMatchResponse::new);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -55,9 +58,20 @@ private static class ProxyRequestHandler<T extends ProxyRequest> implements Tran
public void messageReceived(T request, TransportChannel channel, Task task) throws Exception {
DiscoveryNode targetNode = request.targetNode;
TransportRequest wrappedRequest = request.wrapped;
assert assertConsistentTaskType(task, wrappedRequest);
TaskId taskId = task.taskInfo(service.localNode.getId(), false).getTaskId();
wrappedRequest.setParentTask(taskId);
service.sendRequest(targetNode, action, wrappedRequest,
new ProxyResponseHandler<>(channel, responseFunction.apply(wrappedRequest)));
}

private boolean assertConsistentTaskType(Task proxyTask, TransportRequest wrapped) {
final Task targetTask = wrapped.createTask(0, proxyTask.getType(), proxyTask.getAction(), TaskId.EMPTY_TASK_ID, Map.of());
assert targetTask instanceof CancellableTask == proxyTask instanceof CancellableTask :
"Cancellable property of proxy action [" + proxyTask.getAction() + "] is configured inconsistently: " +
"expected [" + (targetTask instanceof CancellableTask) + "] actual [" + (proxyTask instanceof CancellableTask) + "]";
return true;
}
}

private static class ProxyResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
Expand Down Expand Up @@ -117,27 +131,54 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

private static class CancellableProxyRequest<T extends TransportRequest> extends ProxyRequest<T> {
CancellableProxyRequest(StreamInput in, Writeable.Reader<T> reader) throws IOException {
super(in, reader);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, "", parentTaskId, headers) {
@Override
public boolean shouldCancelChildrenOnCancellation() {
return true;
}

@Override
public String getDescription() {
return "proxy task [" + wrapped.getDescription() + "]";
}
};
}
}

/**
* Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the
* response type changes based on the upcoming request (quite rare)
*/
public static void registerProxyActionWithDynamicResponseType(TransportService service, String action,
public static void registerProxyActionWithDynamicResponseType(TransportService service, String action, boolean cancellable,
Function<TransportRequest,
Writeable.Reader<? extends TransportResponse>> responseFunction) {
RequestHandlerRegistry<? extends TransportRequest> requestHandler = service.getRequestHandler(action);
service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false,
in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, responseFunction));
in -> cancellable ?
new CancellableProxyRequest<>(in, requestHandler::newRequest) :
new ProxyRequest<>(in, requestHandler::newRequest),
new ProxyRequestHandler<>(service, action, responseFunction));
}

/**
* Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the
* response type is always the same (most of the cases).
*/
public static void registerProxyAction(TransportService service, String action,
public static void registerProxyAction(TransportService service, String action, boolean cancellable,
Writeable.Reader<? extends TransportResponse> reader) {
RequestHandlerRegistry<? extends TransportRequest> requestHandler = service.getRequestHandler(action);
service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false,
in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, request -> reader));
in -> cancellable ?
new CancellableProxyRequest<>(in, requestHandler::newRequest) :
new ProxyRequest<>(in, requestHandler::newRequest),
new ProxyRequestHandler<>(service, action, request -> reader));
}

private static final String PROXY_ACTION_PREFIX = "internal:transport/proxy/";
Expand Down
Loading

0 comments on commit 95936eb

Please sign in to comment.