diff --git a/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java b/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java
index 2b16c26931b86..59da9bee7efe2 100644
--- a/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java
+++ b/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java
@@ -61,8 +61,10 @@
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
/**
* Represents a connection to a single remote cluster. In contrast to a local cluster a remote cluster is not joined such that the
@@ -206,6 +208,53 @@ public String executor() {
});
}
+    /**
+     * Collects all nodes of the connected remote cluster and passes a node ID to {@link DiscoveryNode} lookup function
+     * to the given listener. The lookup function returns {@code null} if the node ID is not found.
+     * <p>
+     * The nodes are fetched with a nodes-only cluster state request that is sent to one node obtained from the node
+     * supplier. If there is currently no connected node, a connection attempt is made first and the request is only sent
+     * once that attempt has succeeded.
+     *
+     * @param listener notified with the lookup function on success, or with the failure if connecting, sending the
+     *                 request or the request itself fails
+     */
+    void collectNodes(ActionListener<Function<String, DiscoveryNode>> listener) {
+        Runnable runnable = () -> {
+            final ClusterStateRequest request = new ClusterStateRequest();
+            request.clear();
+            request.nodes(true);
+            request.local(true); // run this on the node that gets the request, it's as good as any other
+            final DiscoveryNode node = nodeSupplier.get();
+            transportService.sendRequest(node, ClusterStateAction.NAME, request, TransportRequestOptions.EMPTY,
+                new TransportResponseHandler<ClusterStateResponse>() {
+                    @Override
+                    public ClusterStateResponse newInstance() {
+                        return new ClusterStateResponse();
+                    }
+
+                    @Override
+                    public void handleResponse(ClusterStateResponse response) {
+                        DiscoveryNodes nodes = response.getState().nodes();
+                        listener.onResponse(nodes::get);
+                    }
+
+                    @Override
+                    public void handleException(TransportException exp) {
+                        listener.onFailure(exp);
+                    }
+
+                    @Override
+                    public String executor() {
+                        return ThreadPool.Names.SAME;
+                    }
+                });
+        };
+        if (connectedNodes.isEmpty()) {
+            // just in case if we are not connected for some reason we try to connect and if we fail we have to notify the listener
+            // this will cause some back pressure on the search end and eventually will cause rejections but that's fine
+            // we can't proceed with a search on a cluster level.
+            // in the future we might want to just skip the remote nodes in such a case but that can already be implemented on the
+            // caller end since they provide the listener.
+            ensureConnected(ActionListener.wrap((x) -> runnable.run(), listener::onFailure));
+        } else {
+            try {
+                runnable.run();
+            } catch (Exception e) {
+                // anything thrown while picking a node or sending the request must reach the listener instead of
+                // escaping into the caller, otherwise the listener would never be notified
+                listener.onFailure(e);
+            }
+        }
+    }
+
/**
* Returns a connection to the remote cluster. This connection might be a proxy connection that redirects internally to the
* given node.
diff --git a/core/src/main/java/org/elasticsearch/transport/RemoteClusterService.java b/core/src/main/java/org/elasticsearch/transport/RemoteClusterService.java
index 621713c8ab11e..c4b64e860b2b5 100644
--- a/core/src/main/java/org/elasticsearch/transport/RemoteClusterService.java
+++ b/core/src/main/java/org/elasticsearch/transport/RemoteClusterService.java
@@ -18,17 +18,12 @@
*/
package org.elasticsearch.transport;
-import org.apache.logging.log4j.util.Supplier;
import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.OriginalIndices;
-import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
-import org.elasticsearch.action.search.SearchRequest;
-import org.elasticsearch.action.search.SearchShardIterator;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.PlainActionFuture;
@@ -40,15 +35,10 @@
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.CountDown;
-import org.elasticsearch.index.Index;
-import org.elasticsearch.index.shard.ShardId;
-import org.elasticsearch.search.internal.AliasFilter;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
-import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
@@ -59,6 +49,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@@ -346,4 +337,44 @@ public void getRemoteConnectionInfos(ActionListener clusters, ActionListener> listener) {
+ Map remoteClusters = this.remoteClusters;
+ for (String cluster : clusters) {
+ if (remoteClusters.containsKey(cluster) == false) {
+ listener.onFailure(new IllegalArgumentException("no such remote cluster: [" + cluster + "]"));
+ return;
+ }
+ }
+
+ final Map> clusterMap = new HashMap<>();
+ CountDown countDown = new CountDown(clusters.size());
+ Function nullFunction = s -> null;
+ for (final String cluster : clusters) {
+ RemoteClusterConnection connection = remoteClusters.get(cluster);
+ connection.collectNodes(new ActionListener>() {
+ @Override
+ public void onResponse(Function nodeLookup) {
+ synchronized (clusterMap) {
+ clusterMap.put(cluster, nodeLookup);
+ }
+ if (countDown.countDown()) {
+ listener.onResponse((clusterAlias, nodeId)
+ -> clusterMap.getOrDefault(clusterAlias, nullFunction).apply(nodeId));
+ }
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ if (countDown.fastForward()) { // we need to check if it's true since we could have multiple failures
+ listener.onFailure(e);
+ }
+ }
+ });
+ }
+ }
}
diff --git a/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java b/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java
index 3c1181b68258d..44a134857f93f 100644
--- a/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java
+++ b/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java
@@ -55,11 +55,6 @@
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
-import org.elasticsearch.transport.RemoteClusterConnection;
-import org.elasticsearch.transport.RemoteConnectionInfo;
-import org.elasticsearch.transport.RemoteTransportException;
-import org.elasticsearch.transport.TransportConnectionListener;
-import org.elasticsearch.transport.TransportService;
import java.io.IOException;
import java.net.InetAddress;
@@ -78,6 +73,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
@@ -357,7 +353,6 @@ public void run() {
public void testFetchShards() throws Exception {
List knownNodes = new CopyOnWriteArrayList<>();
-
try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT);
MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, Version.CURRENT)) {
DiscoveryNode seedNode = seedTransport.getLocalDiscoNode();
@@ -785,4 +780,42 @@ public void onFailure(Exception e) {
}
}
}
+
+    /**
+     * Collects the nodes of a single-node remote cluster and checks that the returned lookup function resolves the
+     * seed node by its ID and returns null for an unknown ID.
+     */
+    public void testCollectNodes() throws Exception {
+        List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
+        try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT)) {
+            DiscoveryNode seedNode = seedTransport.getLocalDiscoNode();
+            knownNodes.add(seedTransport.getLocalDiscoNode());
+            try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) {
+                service.start();
+                service.acceptIncomingRequests();
+                try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster",
+                    Arrays.asList(seedNode), service, Integer.MAX_VALUE, n -> true)) {
+                    // randomly exercise both paths: already connected, and connecting as part of collectNodes
+                    if (randomBoolean()) {
+                        updateSeedNodes(connection, Arrays.asList(seedNode));
+                    }
+                    CountDownLatch responseLatch = new CountDownLatch(1);
+                    AtomicReference<Function<String, DiscoveryNode>> reference = new AtomicReference<>();
+                    AtomicReference<Exception> failReference = new AtomicReference<>();
+                    ActionListener<Function<String, DiscoveryNode>> nodesListener = ActionListener.wrap(
+                        x -> {
+                            reference.set(x);
+                            responseLatch.countDown();
+                        },
+                        x -> {
+                            failReference.set(x);
+                            responseLatch.countDown();
+                        });
+                    connection.collectNodes(nodesListener);
+                    responseLatch.await();
+                    assertNull(failReference.get());
+                    assertNotNull(reference.get());
+                    Function<String, DiscoveryNode> function = reference.get();
+                    assertEquals(seedNode, function.apply(seedNode.getId()));
+                    assertNull(function.apply(seedNode.getId() + "foo"));
+                    assertTrue(connection.assertNoRunningConnections());
+                }
+            }
+        }
+    }
}
diff --git a/core/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java b/core/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java
index 32a672e1bbc9a..0c4e0c31d6d81 100644
--- a/core/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java
+++ b/core/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java
@@ -18,6 +18,7 @@
*/
package org.elasticsearch.transport;
+import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -34,11 +35,14 @@
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.Collections;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiFunction;
public class RemoteClusterServiceTests extends ESTestCase {
@@ -265,4 +269,151 @@ private ActionListener connectionListener(final CountDownLatch latch) {
return ActionListener.wrap(x -> latch.countDown(), x -> fail());
}
+
+ public void testCollectNodes() throws InterruptedException, IOException {
+ final Settings settings = Settings.EMPTY;
+ final List knownNodes_c1 = new CopyOnWriteArrayList<>();
+ final List knownNodes_c2 = new CopyOnWriteArrayList<>();
+
+ try (MockTransportService c1N1 =
+ startTransport("cluster_1_node_1", knownNodes_c1, Version.CURRENT);
+ MockTransportService c1N2 =
+ startTransport("cluster_1_node_2", knownNodes_c1, Version.CURRENT);
+ MockTransportService c2N1 =
+ startTransport("cluster_2_node_1", knownNodes_c2, Version.CURRENT);
+ MockTransportService c2N2 =
+ startTransport("cluster_2_node_2", knownNodes_c2, Version.CURRENT)) {
+ final DiscoveryNode c1N1Node = c1N1.getLocalDiscoNode();
+ final DiscoveryNode c1N2Node = c1N2.getLocalDiscoNode();
+ final DiscoveryNode c2N1Node = c2N1.getLocalDiscoNode();
+ final DiscoveryNode c2N2Node = c2N2.getLocalDiscoNode();
+ knownNodes_c1.add(c1N1Node);
+ knownNodes_c1.add(c1N2Node);
+ knownNodes_c2.add(c2N1Node);
+ knownNodes_c2.add(c2N2Node);
+ Collections.shuffle(knownNodes_c1, random());
+ Collections.shuffle(knownNodes_c2, random());
+
+ try (MockTransportService transportService = MockTransportService.createNewService(
+ settings,
+ Version.CURRENT,
+ threadPool,
+ null)) {
+ transportService.start();
+ transportService.acceptIncomingRequests();
+ final Settings.Builder builder = Settings.builder();
+ builder.putArray(
+ "search.remote.cluster_1.seeds", c1N1Node.getAddress().toString());
+ builder.putArray(
+ "search.remote.cluster_2.seeds", c2N1Node.getAddress().toString());
+ try (RemoteClusterService service =
+ new RemoteClusterService(settings, transportService)) {
+ assertFalse(service.isCrossClusterSearchEnabled());
+ service.initializeRemoteClusters();
+ assertFalse(service.isCrossClusterSearchEnabled());
+
+ final InetSocketAddress c1N1Address = c1N1Node.getAddress().address();
+ final InetSocketAddress c1N2Address = c1N2Node.getAddress().address();
+ final InetSocketAddress c2N1Address = c2N1Node.getAddress().address();
+ final InetSocketAddress c2N2Address = c2N2Node.getAddress().address();
+
+ final CountDownLatch firstLatch = new CountDownLatch(1);
+ service.updateRemoteCluster(
+ "cluster_1",
+ Arrays.asList(c1N1Address, c1N2Address),
+ connectionListener(firstLatch));
+ firstLatch.await();
+
+ final CountDownLatch secondLatch = new CountDownLatch(1);
+ service.updateRemoteCluster(
+ "cluster_2",
+ Arrays.asList(c2N1Address, c2N2Address),
+ connectionListener(secondLatch));
+ secondLatch.await();
+ CountDownLatch latch = new CountDownLatch(1);
+ service.collectNodes(new HashSet<>(Arrays.asList("cluster_1", "cluster_2")),
+ new ActionListener>() {
+ @Override
+ public void onResponse(BiFunction func) {
+ try {
+ assertEquals(c1N1Node, func.apply("cluster_1", c1N1Node.getId()));
+ assertEquals(c1N2Node, func.apply("cluster_1", c1N2Node.getId()));
+ assertEquals(c2N1Node, func.apply("cluster_2", c2N1Node.getId()));
+ assertEquals(c2N2Node, func.apply("cluster_2", c2N2Node.getId()));
+ } finally {
+ latch.countDown();
+ }
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ try {
+ throw new AssertionError(e);
+ } finally {
+ latch.countDown();
+ }
+ }
+ });
+ latch.await();
+ {
+ CountDownLatch failLatch = new CountDownLatch(1);
+ AtomicReference ex = new AtomicReference<>();
+ service.collectNodes(new HashSet<>(Arrays.asList("cluster_1", "cluster_2", "no such cluster")),
+ new ActionListener>() {
+ @Override
+ public void onResponse(BiFunction stringStringDiscoveryNodeBiFunction) {
+ try {
+ fail("should not be called");
+ } finally {
+ failLatch.countDown();
+ }
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ try {
+ ex.set(e);
+ } finally {
+ failLatch.countDown();
+ }
+ }
+ });
+ failLatch.await();
+ assertNotNull(ex.get());
+ assertTrue(ex.get() instanceof IllegalArgumentException);
+ assertEquals("no such remote cluster: [no such cluster]", ex.get().getMessage());
+ }
+ {
+ // close all targets and check for the transport level failure path
+ IOUtils.close(c1N1, c1N2, c2N1, c2N2);
+ CountDownLatch failLatch = new CountDownLatch(1);
+ AtomicReference ex = new AtomicReference<>();
+ service.collectNodes(new HashSet<>(Arrays.asList("cluster_1", "cluster_2")),
+ new ActionListener>() {
+ @Override
+ public void onResponse(BiFunction stringStringDiscoveryNodeBiFunction) {
+ try {
+ fail("should not be called");
+ } finally {
+ failLatch.countDown();
+ }
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ try {
+ ex.set(e);
+ } finally {
+ failLatch.countDown();
+ }
+ }
+ });
+ failLatch.await();
+ assertNotNull(ex.get());
+ assertTrue(ex.get().getClass().toString(), ex.get() instanceof TransportException);
+ }
+ }
+ }
+ }
+ }
}