Skip to content

Commit

Permalink
Preserve thread context when connecting to remote cluster (#31574)
Browse files Browse the repository at this point in the history
Establishing remote cluster connections uses a queue to coordinate multiple concurrent connect
attempts. Connect attempts can be initiated by user triggered searches as well as by system events 
(e.g. when nodes disconnect). Multiple such concurrent events can lead to the connectListener of
one event to be called under the thread context of another connect attempt. This can lead to the
situation as seen in #31462 where the connect listener is executed under the system context, which 
breaks when fetching the search shards from the remote cluster.

Closes #31462
  • Loading branch information
ywelsch authored Jun 27, 2018
1 parent 4fc833b commit b724619
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
Expand Down Expand Up @@ -369,9 +370,11 @@ void forceConnect() {
private void connect(ActionListener<Void> connectListener, boolean forceRun) {
final boolean runConnect;
final Collection<ActionListener<Void>> toNotify;
final ActionListener<Void> listener = connectListener == null ? null :
ContextPreservingActionListener.wrapPreservingContext(connectListener, transportService.getThreadPool().getThreadContext());
synchronized (queue) {
if (connectListener != null && queue.offer(connectListener) == false) {
connectListener.onFailure(new RejectedExecutionException("connect queue is full"));
if (listener != null && queue.offer(listener) == false) {
listener.onFailure(new RejectedExecutionException("connect queue is full"));
return;
}
if (forceRun == false && queue.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.CancellableThreads;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.core.internal.io.IOUtils;
Expand Down Expand Up @@ -555,6 +556,64 @@ public void testFetchShards() throws Exception {
}
}

public void testFetchShardsThreadContextHeader() throws Exception {
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT);
MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, Version.CURRENT)) {
DiscoveryNode seedNode = seedTransport.getLocalDiscoNode();
knownNodes.add(seedTransport.getLocalDiscoNode());
knownNodes.add(discoverableTransport.getLocalDiscoNode());
Collections.shuffle(knownNodes, random());
try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) {
service.start();
service.acceptIncomingRequests();
List<DiscoveryNode> nodes = Collections.singletonList(seedNode);
try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster",
nodes, service, Integer.MAX_VALUE, n -> true)) {
SearchRequest request = new SearchRequest("test-index");
Thread[] threads = new Thread[10];
for (int i = 0; i < threads.length; i++) {
final String threadId = Integer.toString(i);
threads[i] = new Thread(() -> {
ThreadContext threadContext = seedTransport.threadPool.getThreadContext();
threadContext.putHeader("threadId", threadId);
AtomicReference<ClusterSearchShardsResponse> reference = new AtomicReference<>();
AtomicReference<Exception> failReference = new AtomicReference<>();
final ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest("test-index")
.indicesOptions(request.indicesOptions()).local(true).preference(request.preference())
.routing(request.routing());
CountDownLatch responseLatch = new CountDownLatch(1);
connection.fetchSearchShards(searchShardsRequest,
new LatchedActionListener<>(ActionListener.wrap(
resp -> {
reference.set(resp);
assertEquals(threadId, seedTransport.threadPool.getThreadContext().getHeader("threadId"));
},
failReference::set), responseLatch));
try {
responseLatch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
assertNull(failReference.get());
assertNotNull(reference.get());
ClusterSearchShardsResponse clusterSearchShardsResponse = reference.get();
assertEquals(knownNodes, Arrays.asList(clusterSearchShardsResponse.getNodes()));
});
}
for (int i = 0; i < threads.length; i++) {
threads[i].start();
}

for (int i = 0; i < threads.length; i++) {
threads[i].join();
}
assertTrue(connection.assertNoRunningConnections());
}
}
}
}

public void testFetchShardsSkipUnavailable() throws Exception {
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT)) {
Expand Down

0 comments on commit b724619

Please sign in to comment.