Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cancel proxy requests when the proxy channel closes #65850

Merged
merged 5 commits into from
Dec 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.test.AbstractMultiClustersTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.test.NodeRoles;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.transport.TransportService;
import org.junit.Before;

import java.util.Collection;
Expand Down Expand Up @@ -127,6 +129,16 @@ public void testProxyConnectionDisconnect() throws Exception {
client(LOCAL_CLUSTER).search(searchRequest, future);
SearchListenerPlugin.waitSearchStarted();
disconnectFromRemoteClusters();
// Cancellable tasks on the remote cluster should be cancelled
assertBusy(() -> {
final Iterable<TransportService> transportServices = cluster("cluster_a").getInstances(TransportService.class);
for (TransportService transportService : transportServices) {
Collection<CancellableTask> cancellableTasks = transportService.getTaskManager().getCancellableTasks().values();
for (CancellableTask cancellableTask : cancellableTasks) {
assertTrue(cancellableTask.getDescription(), cancellableTask.isCancelled());
}
}
});
assertBusy(() -> assertTrue(future.isDone()));
configureAndConnectsToRemoteClusters();
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,20 +297,20 @@ public static void registerRequestHandler(TransportService transportService, Sea
boolean freed = searchService.freeReaderContext(request.id());
channel.sendResponse(new SearchFreeContextResponse(freed));
});
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, SearchFreeContextResponse::new);
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_SCROLL_ACTION_NAME, false, SearchFreeContextResponse::new);
transportService.registerRequestHandler(FREE_CONTEXT_ACTION_NAME, ThreadPool.Names.SAME, SearchFreeContextRequest::new,
(request, channel, task) -> {
boolean freed = searchService.freeReaderContext(request.id());
channel.sendResponse(new SearchFreeContextResponse(freed));
});
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, SearchFreeContextResponse::new);
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, false, SearchFreeContextResponse::new);
transportService.registerRequestHandler(CLEAR_SCROLL_CONTEXTS_ACTION_NAME, ThreadPool.Names.SAME,
TransportRequest.Empty::new,
(request, channel, task) -> {
searchService.freeAllScrollContexts();
channel.sendResponse(TransportResponse.Empty.INSTANCE);
});
TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME,
TransportActionProxy.registerProxyAction(transportService, CLEAR_SCROLL_CONTEXTS_ACTION_NAME, false,
(in) -> TransportResponse.Empty.INSTANCE);

transportService.registerRequestHandler(DFS_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new,
Expand All @@ -319,57 +319,57 @@ public static void registerRequestHandler(TransportService transportService, Sea
new ChannelActionListener<>(channel, DFS_ACTION_NAME, request))
);

TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, DfsSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new);

transportService.registerRequestHandler(QUERY_ACTION_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new,
(request, channel, task) -> {
searchService.executeQueryPhase(request, keepStatesInContext(channel.getVersion()), (SearchShardTask) task,
new ChannelActionListener<>(channel, QUERY_ACTION_NAME, request));
});
TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME,
TransportActionProxy.registerProxyActionWithDynamicResponseType(transportService, QUERY_ACTION_NAME, true,
(request) -> ((ShardSearchRequest)request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new);

transportService.registerRequestHandler(QUERY_ID_ACTION_NAME, ThreadPool.Names.SAME, QuerySearchRequest::new,
(request, channel, task) -> {
searchService.executeQueryPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, QUERY_ID_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, QuerySearchResult::new);
TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, true, QuerySearchResult::new);

transportService.registerRequestHandler(QUERY_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new,
(request, channel, task) -> {
searchService.executeQueryPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, QUERY_SCROLL_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, ScrollQuerySearchResult::new);
TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, true, ScrollQuerySearchResult::new);

transportService.registerRequestHandler(QUERY_FETCH_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, InternalScrollSearchRequest::new,
(request, channel, task) -> {
searchService.executeFetchPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, QUERY_FETCH_SCROLL_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, ScrollQueryFetchSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new);

transportService.registerRequestHandler(FETCH_ID_SCROLL_ACTION_NAME, ThreadPool.Names.SAME, ShardFetchRequest::new,
(request, channel, task) -> {
searchService.executeFetchPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, FETCH_ID_SCROLL_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, FetchSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, true, FetchSearchResult::new);

transportService.registerRequestHandler(FETCH_ID_ACTION_NAME, ThreadPool.Names.SAME, true, true, ShardFetchSearchRequest::new,
(request, channel, task) -> {
searchService.executeFetchPhase(request, (SearchShardTask) task,
new ChannelActionListener<>(channel, FETCH_ID_ACTION_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, FetchSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, true, FetchSearchResult::new);

// this is cheap, it does not fetch during the rewrite phase, so we can let it quickly execute on a networking thread
transportService.registerRequestHandler(QUERY_CAN_MATCH_NAME, ThreadPool.Names.SAME, ShardSearchRequest::new,
(request, channel, task) -> {
searchService.canMatch(request, new ChannelActionListener<>(channel, QUERY_CAN_MATCH_NAME, request));
});
TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, SearchService.CanMatchResponse::new);
TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NAME, true, SearchService.CanMatchResponse::new);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.function.Function;

/**
Expand Down Expand Up @@ -55,9 +58,20 @@ private static class ProxyRequestHandler<T extends ProxyRequest> implements Tran
public void messageReceived(T request, TransportChannel channel, Task task) throws Exception {
DiscoveryNode targetNode = request.targetNode;
TransportRequest wrappedRequest = request.wrapped;
assert assertConsistentTaskType(task, wrappedRequest);
TaskId taskId = task.taskInfo(service.localNode.getId(), false).getTaskId();
wrappedRequest.setParentTask(taskId);
service.sendRequest(targetNode, action, wrappedRequest,
new ProxyResponseHandler<>(channel, responseFunction.apply(wrappedRequest)));
}

private boolean assertConsistentTaskType(Task proxyTask, TransportRequest wrapped) {
final Task targetTask = wrapped.createTask(0, proxyTask.getType(), proxyTask.getAction(), TaskId.EMPTY_TASK_ID, Map.of());
assert targetTask instanceof CancellableTask == proxyTask instanceof CancellableTask :
"Cancellable property of proxy action [" + proxyTask.getAction() + "] is configured inconsistently: " +
"expected [" + (targetTask instanceof CancellableTask) + "] actual [" + (proxyTask instanceof CancellableTask) + "]";
return true;
}
}

private static class ProxyResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
Expand Down Expand Up @@ -117,27 +131,54 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

private static class CancellableProxyRequest<T extends TransportRequest> extends ProxyRequest<T> {
CancellableProxyRequest(StreamInput in, Writeable.Reader<T> reader) throws IOException {
super(in, reader);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, "", parentTaskId, headers) {
DaveCTurner marked this conversation as resolved.
Show resolved Hide resolved
@Override
public boolean shouldCancelChildrenOnCancellation() {
return true;
}

@Override
public String getDescription() {
return "proxy task [" + wrapped.getDescription() + "]";
}
};
}
}

/**
* Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the
* response type changes based on the upcoming request (quite rare)
*/
public static void registerProxyActionWithDynamicResponseType(TransportService service, String action,
public static void registerProxyActionWithDynamicResponseType(TransportService service, String action, boolean cancellable,
Function<TransportRequest,
Writeable.Reader<? extends TransportResponse>> responseFunction) {
RequestHandlerRegistry<? extends TransportRequest> requestHandler = service.getRequestHandler(action);
service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false,
in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, responseFunction));
in -> cancellable ?
new CancellableProxyRequest<>(in, requestHandler::newRequest) :
new ProxyRequest<>(in, requestHandler::newRequest),
new ProxyRequestHandler<>(service, action, responseFunction));
}

/**
* Registers a proxy request handler that allows to forward requests for the given action to another node. To be used when the
* response type is always the same (most of the cases).
*/
public static void registerProxyAction(TransportService service, String action,
public static void registerProxyAction(TransportService service, String action, boolean cancellable,
Writeable.Reader<? extends TransportResponse> reader) {
RequestHandlerRegistry<? extends TransportRequest> requestHandler = service.getRequestHandler(action);
service.registerRequestHandler(getProxyAction(action), ThreadPool.Names.SAME, true, false,
in -> new ProxyRequest<>(in, requestHandler::newRequest), new ProxyRequestHandler<>(service, action, request -> reader));
in -> cancellable ?
new CancellableProxyRequest<>(in, requestHandler::newRequest) :
new ProxyRequest<>(in, requestHandler::newRequest),
new ProxyRequestHandler<>(service, action, request -> reader));
}

private static final String PROXY_ACTION_PREFIX = "internal:transport/proxy/";
Expand Down
Loading