RemoteClusterConnection.java
@@ -61,8 +61,10 @@
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Represents a connection to a single remote cluster. In contrast to a local cluster a remote cluster is not joined such that the
@@ -206,6 +208,53 @@ public String executor() {
});
}

/**
* Collects all nodes on the connected cluster and passes a nodeId to {@link DiscoveryNode} lookup function to the given
* listener. The function returns <code>null</code> if the node ID is not found.
*/
void collectNodes(ActionListener<Function<String, DiscoveryNode>> listener) {
Runnable runnable = () -> {
final ClusterStateRequest request = new ClusterStateRequest();
request.clear();
request.nodes(true);
request.local(true); // run this on the node that gets the request; it's as good as any other
final DiscoveryNode node = nodeSupplier.get();
transportService.sendRequest(node, ClusterStateAction.NAME, request, TransportRequestOptions.EMPTY,
new TransportResponseHandler<ClusterStateResponse>() {
@Override
public ClusterStateResponse newInstance() {
return new ClusterStateResponse();
}

@Override
public void handleResponse(ClusterStateResponse response) {
DiscoveryNodes nodes = response.getState().nodes();
listener.onResponse(nodes::get);
}

@Override
public void handleException(TransportException exp) {
listener.onFailure(exp);
}

@Override
public String executor() {
return ThreadPool.Names.SAME;
}
});
};
if (connectedNodes.isEmpty()) {
// in case we are not connected for some reason, try to connect first; if that fails we have to notify the listener.
// this will cause some back pressure on the search end and eventually cause rejections, but that's fine since
// we can't proceed with the search at the cross-cluster level anyway.
// in the future we might want to just skip the remote nodes in such a case, but that can already be implemented on the
// caller end since they provide the listener.
ensureConnected(ActionListener.wrap((x) -> runnable.run(), listener::onFailure));
} else {
runnable.run();
[inline review comment]
Member: Question: why use a runnable here? Couldn't we do transportService.sendRequest straight away?
Contributor (author): We could, but that means we duplicate the code, so in this case I have it all in one runnable and just pass it on.
}
}
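For reference, here is a minimal caller-side sketch (not part of this change) of how the lookup function handed to the listener might be consumed. The helper name, the nodeId parameter and the logger are assumptions for illustration, and since collectNodes is package private the caller is assumed to live in org.elasticsearch.transport:

// hypothetical helper, illustrative only; not part of this change
static void resolveRemoteNode(RemoteClusterConnection remoteConnection, String nodeId, Logger logger) {
    remoteConnection.collectNodes(ActionListener.wrap(
        nodeLookup -> {
            // nodeLookup is the Function<String, DiscoveryNode> passed on success; null means the ID is unknown
            DiscoveryNode node = nodeLookup.apply(nodeId);
            if (node == null) {
                logger.warn("node [{}] is not part of the remote cluster", nodeId);
            } else {
                logger.info("resolved node [{}] to address [{}]", nodeId, node.getAddress());
            }
        },
        e -> logger.warn("failed to collect nodes from the remote cluster", e)));
}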

/**
* Returns a connection to the remote cluster. This connection might be a proxy connection that redirects internally to the
* given node.
RemoteClusterService.java
@@ -18,17 +18,12 @@
*/
package org.elasticsearch.transport;

import org.apache.logging.log4j.util.Supplier;
import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.LatchedActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest;
import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchShardIterator;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.PlainActionFuture;
@@ -40,15 +35,10 @@
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.search.internal.AliasFilter;

import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
@@ -59,6 +49,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
@@ -346,4 +337,44 @@ public void getRemoteConnectionInfos(ActionListener<Collection<RemoteConnectionI
}
}
}

/**
* Collects all nodes of the given clusters and passes a (clusterAlias, nodeId) to {@link DiscoveryNode}
* lookup function to the listener on success.
*/
public void collectNodes(Set<String> clusters, ActionListener<BiFunction<String, String, DiscoveryNode>> listener) {
Map<String, RemoteClusterConnection> remoteClusters = this.remoteClusters;
for (String cluster : clusters) {
if (remoteClusters.containsKey(cluster) == false) {
listener.onFailure(new IllegalArgumentException("no such remote cluster: [" + cluster + "]"));
return;
}
}

final Map<String, Function<String, DiscoveryNode>> clusterMap = new HashMap<>();
CountDown countDown = new CountDown(clusters.size());
Function<String, DiscoveryNode> nullFunction = s -> null;
for (final String cluster : clusters) {
RemoteClusterConnection connection = remoteClusters.get(cluster);
connection.collectNodes(new ActionListener<Function<String, DiscoveryNode>>() {
@Override
public void onResponse(Function<String, DiscoveryNode> nodeLookup) {
synchronized (clusterMap) {
clusterMap.put(cluster, nodeLookup);
}
if (countDown.countDown()) {
listener.onResponse((clusterAlias, nodeId)
-> clusterMap.getOrDefault(clusterAlias, nullFunction).apply(nodeId));
}
}

@Override
public void onFailure(Exception e) {
if (countDown.fastForward()) { // check the return value since multiple clusters can fail but the listener must be notified only once
listener.onFailure(e);
}
}
});
}
}
}
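Similarly, a sketch of how the new service-level API might be used to resolve (clusterAlias, nodeId) pairs across several remote clusters. The helper name, cluster aliases, node ID and logger are illustrative assumptions rather than code from this change:

// hypothetical helper, illustrative only; RemoteClusterService#collectNodes is public
static void resolveAcrossClusters(RemoteClusterService remoteClusterService, Logger logger) {
    Set<String> clusters = new HashSet<>(Arrays.asList("cluster_1", "cluster_2"));
    remoteClusterService.collectNodes(clusters, ActionListener.wrap(
        clusterNodeLookup -> {
            // clusterNodeLookup: (clusterAlias, nodeId) -> DiscoveryNode, or null if either is unknown
            DiscoveryNode node = clusterNodeLookup.apply("cluster_1", "node_id");
            if (node != null) {
                logger.info("cluster_1/node_id resolves to [{}]", node.getAddress());
            }
        },
        e -> logger.warn("failed to collect nodes from remote clusters", e)));
}

Because failures are funnelled through CountDown#fastForward(), the listener should be notified at most once even if several of the requested clusters fail.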
RemoteClusterConnectionTests.java
@@ -55,11 +55,6 @@
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterConnection;
import org.elasticsearch.transport.RemoteConnectionInfo;
import org.elasticsearch.transport.RemoteTransportException;
import org.elasticsearch.transport.TransportConnectionListener;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.net.InetAddress;
@@ -78,6 +73,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
@@ -357,7 +353,6 @@ public void run() {

public void testFetchShards() throws Exception {
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();

try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT);
MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, Version.CURRENT)) {
DiscoveryNode seedNode = seedTransport.getLocalDiscoNode();
@@ -785,4 +780,42 @@ public void onFailure(Exception e) {
}
}
}

public void testCollectNodes() throws Exception {
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT)) {
DiscoveryNode seedNode = seedTransport.getLocalDiscoNode();
knownNodes.add(seedTransport.getLocalDiscoNode());
try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) {
service.start();
service.acceptIncomingRequests();
try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster",
Arrays.asList(seedNode), service, Integer.MAX_VALUE, n -> true)) {
if (randomBoolean()) {
updateSeedNodes(connection, Arrays.asList(seedNode));
}
CountDownLatch responseLatch = new CountDownLatch(1);
AtomicReference<Function<String, DiscoveryNode>> reference = new AtomicReference<>();
AtomicReference<Exception> failReference = new AtomicReference<>();
ActionListener<Function<String, DiscoveryNode>> shardsListener = ActionListener.wrap(
x -> {
reference.set(x);
responseLatch.countDown();
},
x -> {
failReference.set(x);
responseLatch.countDown();
});
connection.collectNodes(shardsListener);
responseLatch.await();
assertNull(failReference.get());
assertNotNull(reference.get());
Function<String, DiscoveryNode> function = reference.get();
assertEquals(seedNode, function.apply(seedNode.getId()));
assertNull(function.apply(seedNode.getId() + "foo"));
assertTrue(connection.assertNoRunningConnections());
}
}
}
}
}
RemoteClusterServiceTests.java
@@ -18,6 +18,7 @@
*/
package org.elasticsearch.transport;

import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -34,11 +35,14 @@
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiFunction;

public class RemoteClusterServiceTests extends ESTestCase {

@@ -265,4 +269,151 @@ private ActionListener<Void> connectionListener(final CountDownLatch latch) {
return ActionListener.wrap(x -> latch.countDown(), x -> fail());
}


public void testCollectNodes() throws InterruptedException, IOException {
final Settings settings = Settings.EMPTY;
final List<DiscoveryNode> knownNodes_c1 = new CopyOnWriteArrayList<>();
final List<DiscoveryNode> knownNodes_c2 = new CopyOnWriteArrayList<>();

try (MockTransportService c1N1 =
startTransport("cluster_1_node_1", knownNodes_c1, Version.CURRENT);
MockTransportService c1N2 =
startTransport("cluster_1_node_2", knownNodes_c1, Version.CURRENT);
MockTransportService c2N1 =
startTransport("cluster_2_node_1", knownNodes_c2, Version.CURRENT);
MockTransportService c2N2 =
startTransport("cluster_2_node_2", knownNodes_c2, Version.CURRENT)) {
final DiscoveryNode c1N1Node = c1N1.getLocalDiscoNode();
final DiscoveryNode c1N2Node = c1N2.getLocalDiscoNode();
final DiscoveryNode c2N1Node = c2N1.getLocalDiscoNode();
final DiscoveryNode c2N2Node = c2N2.getLocalDiscoNode();
knownNodes_c1.add(c1N1Node);
knownNodes_c1.add(c1N2Node);
knownNodes_c2.add(c2N1Node);
knownNodes_c2.add(c2N2Node);
Collections.shuffle(knownNodes_c1, random());
Collections.shuffle(knownNodes_c2, random());

try (MockTransportService transportService = MockTransportService.createNewService(
settings,
Version.CURRENT,
threadPool,
null)) {
transportService.start();
transportService.acceptIncomingRequests();
final Settings.Builder builder = Settings.builder();
builder.putArray(
"search.remote.cluster_1.seeds", c1N1Node.getAddress().toString());
builder.putArray(
"search.remote.cluster_2.seeds", c2N1Node.getAddress().toString());
try (RemoteClusterService service =
new RemoteClusterService(settings, transportService)) {
assertFalse(service.isCrossClusterSearchEnabled());
service.initializeRemoteClusters();
assertFalse(service.isCrossClusterSearchEnabled());

final InetSocketAddress c1N1Address = c1N1Node.getAddress().address();
final InetSocketAddress c1N2Address = c1N2Node.getAddress().address();
final InetSocketAddress c2N1Address = c2N1Node.getAddress().address();
final InetSocketAddress c2N2Address = c2N2Node.getAddress().address();

final CountDownLatch firstLatch = new CountDownLatch(1);
service.updateRemoteCluster(
"cluster_1",
Arrays.asList(c1N1Address, c1N2Address),
connectionListener(firstLatch));
firstLatch.await();

final CountDownLatch secondLatch = new CountDownLatch(1);
service.updateRemoteCluster(
"cluster_2",
Arrays.asList(c2N1Address, c2N2Address),
connectionListener(secondLatch));
secondLatch.await();
CountDownLatch latch = new CountDownLatch(1);
service.collectNodes(new HashSet<>(Arrays.asList("cluster_1", "cluster_2")),
new ActionListener<BiFunction<String, String, DiscoveryNode>>() {
@Override
public void onResponse(BiFunction<String, String, DiscoveryNode> func) {
try {
assertEquals(c1N1Node, func.apply("cluster_1", c1N1Node.getId()));
assertEquals(c1N2Node, func.apply("cluster_1", c1N2Node.getId()));
assertEquals(c2N1Node, func.apply("cluster_2", c2N1Node.getId()));
assertEquals(c2N2Node, func.apply("cluster_2", c2N2Node.getId()));
} finally {
latch.countDown();
}
}

@Override
public void onFailure(Exception e) {
try {
throw new AssertionError(e);
} finally {
latch.countDown();
}
}
});
latch.await();
{
CountDownLatch failLatch = new CountDownLatch(1);
AtomicReference<Exception> ex = new AtomicReference<>();
service.collectNodes(new HashSet<>(Arrays.asList("cluster_1", "cluster_2", "no such cluster")),
new ActionListener<BiFunction<String, String, DiscoveryNode>>() {
@Override
public void onResponse(BiFunction<String, String, DiscoveryNode> stringStringDiscoveryNodeBiFunction) {
try {
fail("should not be called");
} finally {
failLatch.countDown();
}
}

@Override
public void onFailure(Exception e) {
try {
ex.set(e);
} finally {
failLatch.countDown();
}
}
});
failLatch.await();
assertNotNull(ex.get());
assertTrue(ex.get() instanceof IllegalArgumentException);
assertEquals("no such remote cluster: [no such cluster]", ex.get().getMessage());
}
{
// close all targets and check for the transport level failure path
IOUtils.close(c1N1, c1N2, c2N1, c2N2);
CountDownLatch failLatch = new CountDownLatch(1);
AtomicReference<Exception> ex = new AtomicReference<>();
service.collectNodes(new HashSet<>(Arrays.asList("cluster_1", "cluster_2")),
new ActionListener<BiFunction<String, String, DiscoveryNode>>() {
@Override
public void onResponse(BiFunction<String, String, DiscoveryNode> stringStringDiscoveryNodeBiFunction) {
try {
fail("should not be called");
} finally {
failLatch.countDown();
}
}

@Override
public void onFailure(Exception e) {
try {
ex.set(e);
} finally {
failLatch.countDown();
}
}
});
failLatch.await();
assertNotNull(ex.get());
assertTrue(ex.get().getClass().toString(), ex.get() instanceof TransportException);
}
}
}
}
}
}