diff --git a/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java b/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java index 2b16c26931b86..59da9bee7efe2 100644 --- a/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java +++ b/core/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java @@ -61,8 +61,10 @@ import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Represents a connection to a single remote cluster. In contrast to a local cluster a remote cluster is not joined such that the @@ -206,6 +208,53 @@ public String executor() { }); } + /** + * Collects all nodes on the connected cluster and returns / passes a nodeID to {@link DiscoveryNode} lookup function + * that returns null if the node ID is not found. + */ + void collectNodes(ActionListener> listener) { + Runnable runnable = () -> { + final ClusterStateRequest request = new ClusterStateRequest(); + request.clear(); + request.nodes(true); + request.local(true); // run this on the node that gets the request it's as good as any other + final DiscoveryNode node = nodeSupplier.get(); + transportService.sendRequest(node, ClusterStateAction.NAME, request, TransportRequestOptions.EMPTY, + new TransportResponseHandler() { + @Override + public ClusterStateResponse newInstance() { + return new ClusterStateResponse(); + } + + @Override + public void handleResponse(ClusterStateResponse response) { + DiscoveryNodes nodes = response.getState().nodes(); + listener.onResponse(nodes::get); + } + + @Override + public void handleException(TransportException exp) { + listener.onFailure(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + }); + }; + if (connectedNodes.isEmpty()) { + // just in case if we are not connected for some reason we try to connect and if we fail we have to notify the listener + // this will cause some back pressure on the search end and eventually will cause rejections but that's fine + // we can't proceed with a search on a cluster level. + // in the future we might want to just skip the remote nodes in such a case but that can already be implemented on the + // caller end since they provide the listener. + ensureConnected(ActionListener.wrap((x) -> runnable.run(), listener::onFailure)); + } else { + runnable.run(); + } + } + /** * Returns a connection to the remote cluster. This connection might be a proxy connection that redirects internally to the * given node. diff --git a/core/src/main/java/org/elasticsearch/transport/RemoteClusterService.java b/core/src/main/java/org/elasticsearch/transport/RemoteClusterService.java index 621713c8ab11e..c4b64e860b2b5 100644 --- a/core/src/main/java/org/elasticsearch/transport/RemoteClusterService.java +++ b/core/src/main/java/org/elasticsearch/transport/RemoteClusterService.java @@ -18,17 +18,12 @@ */ package org.elasticsearch.transport; -import org.apache.logging.log4j.util.Supplier; import org.apache.lucene.util.IOUtils; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.action.OriginalIndices; -import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchShardIterator; import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.PlainActionFuture; @@ -40,15 +35,10 @@ import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.CountDown; -import org.elasticsearch.index.Index; -import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.search.internal.AliasFilter; import java.io.Closeable; import java.io.IOException; import java.net.InetSocketAddress; -import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -59,6 +49,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Predicate; import java.util.stream.Collectors; @@ -346,4 +337,44 @@ public void getRemoteConnectionInfos(ActionListener clusters, ActionListener> listener) { + Map remoteClusters = this.remoteClusters; + for (String cluster : clusters) { + if (remoteClusters.containsKey(cluster) == false) { + listener.onFailure(new IllegalArgumentException("no such remote cluster: [" + cluster + "]")); + return; + } + } + + final Map> clusterMap = new HashMap<>(); + CountDown countDown = new CountDown(clusters.size()); + Function nullFunction = s -> null; + for (final String cluster : clusters) { + RemoteClusterConnection connection = remoteClusters.get(cluster); + connection.collectNodes(new ActionListener>() { + @Override + public void onResponse(Function nodeLookup) { + synchronized (clusterMap) { + clusterMap.put(cluster, nodeLookup); + } + if (countDown.countDown()) { + listener.onResponse((clusterAlias, nodeId) + -> clusterMap.getOrDefault(clusterAlias, nullFunction).apply(nodeId)); + } + } + + @Override + public void onFailure(Exception e) { + if (countDown.fastForward()) { // we need to check if it's true since we could have multiple failures + listener.onFailure(e); + } + } + }); + } + } } diff --git a/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java b/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java index 3c1181b68258d..44a134857f93f 100644 --- a/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java +++ b/core/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java @@ -55,11 +55,6 @@ import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.RemoteClusterConnection; -import org.elasticsearch.transport.RemoteConnectionInfo; -import org.elasticsearch.transport.RemoteTransportException; -import org.elasticsearch.transport.TransportConnectionListener; -import org.elasticsearch.transport.TransportService; import java.io.IOException; import java.net.InetAddress; @@ -78,6 +73,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; @@ -357,7 +353,6 @@ public void run() { public void testFetchShards() throws Exception { List knownNodes = new CopyOnWriteArrayList<>(); - try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT); MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, Version.CURRENT)) { DiscoveryNode seedNode = seedTransport.getLocalDiscoNode(); @@ -785,4 +780,42 @@ public void onFailure(Exception e) { } } } + + public void testCollectNodes() throws Exception { + List knownNodes = new CopyOnWriteArrayList<>(); + try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT)) { + DiscoveryNode seedNode = seedTransport.getLocalDiscoNode(); + knownNodes.add(seedTransport.getLocalDiscoNode()); + try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) { + service.start(); + service.acceptIncomingRequests(); + try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster", + Arrays.asList(seedNode), service, Integer.MAX_VALUE, n -> true)) { + if (randomBoolean()) { + updateSeedNodes(connection, Arrays.asList(seedNode)); + } + CountDownLatch responseLatch = new CountDownLatch(1); + AtomicReference> reference = new AtomicReference<>(); + AtomicReference failReference = new AtomicReference<>(); + ActionListener> shardsListener = ActionListener.wrap( + x -> { + reference.set(x); + responseLatch.countDown(); + }, + x -> { + failReference.set(x); + responseLatch.countDown(); + }); + connection.collectNodes(shardsListener); + responseLatch.await(); + assertNull(failReference.get()); + assertNotNull(reference.get()); + Function function = reference.get(); + assertEquals(seedNode, function.apply(seedNode.getId())); + assertNull(function.apply(seedNode.getId() + "foo")); + assertTrue(connection.assertNoRunningConnections()); + } + } + } + } } diff --git a/core/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java b/core/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java index 32a672e1bbc9a..0c4e0c31d6d81 100644 --- a/core/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java +++ b/core/src/test/java/org/elasticsearch/transport/RemoteClusterServiceTests.java @@ -18,6 +18,7 @@ */ package org.elasticsearch.transport; +import org.apache.lucene.util.IOUtils; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -34,11 +35,14 @@ import java.net.InetSocketAddress; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; public class RemoteClusterServiceTests extends ESTestCase { @@ -265,4 +269,151 @@ private ActionListener connectionListener(final CountDownLatch latch) { return ActionListener.wrap(x -> latch.countDown(), x -> fail()); } + + public void testCollectNodes() throws InterruptedException, IOException { + final Settings settings = Settings.EMPTY; + final List knownNodes_c1 = new CopyOnWriteArrayList<>(); + final List knownNodes_c2 = new CopyOnWriteArrayList<>(); + + try (MockTransportService c1N1 = + startTransport("cluster_1_node_1", knownNodes_c1, Version.CURRENT); + MockTransportService c1N2 = + startTransport("cluster_1_node_2", knownNodes_c1, Version.CURRENT); + MockTransportService c2N1 = + startTransport("cluster_2_node_1", knownNodes_c2, Version.CURRENT); + MockTransportService c2N2 = + startTransport("cluster_2_node_2", knownNodes_c2, Version.CURRENT)) { + final DiscoveryNode c1N1Node = c1N1.getLocalDiscoNode(); + final DiscoveryNode c1N2Node = c1N2.getLocalDiscoNode(); + final DiscoveryNode c2N1Node = c2N1.getLocalDiscoNode(); + final DiscoveryNode c2N2Node = c2N2.getLocalDiscoNode(); + knownNodes_c1.add(c1N1Node); + knownNodes_c1.add(c1N2Node); + knownNodes_c2.add(c2N1Node); + knownNodes_c2.add(c2N2Node); + Collections.shuffle(knownNodes_c1, random()); + Collections.shuffle(knownNodes_c2, random()); + + try (MockTransportService transportService = MockTransportService.createNewService( + settings, + Version.CURRENT, + threadPool, + null)) { + transportService.start(); + transportService.acceptIncomingRequests(); + final Settings.Builder builder = Settings.builder(); + builder.putArray( + "search.remote.cluster_1.seeds", c1N1Node.getAddress().toString()); + builder.putArray( + "search.remote.cluster_2.seeds", c2N1Node.getAddress().toString()); + try (RemoteClusterService service = + new RemoteClusterService(settings, transportService)) { + assertFalse(service.isCrossClusterSearchEnabled()); + service.initializeRemoteClusters(); + assertFalse(service.isCrossClusterSearchEnabled()); + + final InetSocketAddress c1N1Address = c1N1Node.getAddress().address(); + final InetSocketAddress c1N2Address = c1N2Node.getAddress().address(); + final InetSocketAddress c2N1Address = c2N1Node.getAddress().address(); + final InetSocketAddress c2N2Address = c2N2Node.getAddress().address(); + + final CountDownLatch firstLatch = new CountDownLatch(1); + service.updateRemoteCluster( + "cluster_1", + Arrays.asList(c1N1Address, c1N2Address), + connectionListener(firstLatch)); + firstLatch.await(); + + final CountDownLatch secondLatch = new CountDownLatch(1); + service.updateRemoteCluster( + "cluster_2", + Arrays.asList(c2N1Address, c2N2Address), + connectionListener(secondLatch)); + secondLatch.await(); + CountDownLatch latch = new CountDownLatch(1); + service.collectNodes(new HashSet<>(Arrays.asList("cluster_1", "cluster_2")), + new ActionListener>() { + @Override + public void onResponse(BiFunction func) { + try { + assertEquals(c1N1Node, func.apply("cluster_1", c1N1Node.getId())); + assertEquals(c1N2Node, func.apply("cluster_1", c1N2Node.getId())); + assertEquals(c2N1Node, func.apply("cluster_2", c2N1Node.getId())); + assertEquals(c2N2Node, func.apply("cluster_2", c2N2Node.getId())); + } finally { + latch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + try { + throw new AssertionError(e); + } finally { + latch.countDown(); + } + } + }); + latch.await(); + { + CountDownLatch failLatch = new CountDownLatch(1); + AtomicReference ex = new AtomicReference<>(); + service.collectNodes(new HashSet<>(Arrays.asList("cluster_1", "cluster_2", "no such cluster")), + new ActionListener>() { + @Override + public void onResponse(BiFunction stringStringDiscoveryNodeBiFunction) { + try { + fail("should not be called"); + } finally { + failLatch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + try { + ex.set(e); + } finally { + failLatch.countDown(); + } + } + }); + failLatch.await(); + assertNotNull(ex.get()); + assertTrue(ex.get() instanceof IllegalArgumentException); + assertEquals("no such remote cluster: [no such cluster]", ex.get().getMessage()); + } + { + // close all targets and check for the transport level failure path + IOUtils.close(c1N1, c1N2, c2N1, c2N2); + CountDownLatch failLatch = new CountDownLatch(1); + AtomicReference ex = new AtomicReference<>(); + service.collectNodes(new HashSet<>(Arrays.asList("cluster_1", "cluster_2")), + new ActionListener>() { + @Override + public void onResponse(BiFunction stringStringDiscoveryNodeBiFunction) { + try { + fail("should not be called"); + } finally { + failLatch.countDown(); + } + } + + @Override + public void onFailure(Exception e) { + try { + ex.set(e); + } finally { + failLatch.countDown(); + } + } + }); + failLatch.await(); + assertNotNull(ex.get()); + assertTrue(ex.get().getClass().toString(), ex.get() instanceof TransportException); + } + } + } + } + } }