Skip to content

Commit

Permalink
Fix for race condition in node-join/node-left loop (opensearch-projec…
Browse files Browse the repository at this point in the history
…t#15521)

* Add custom connect to node for handleJoinRequest

Signed-off-by: Rahul Karajgikar <karajgik@amazon.com>
  • Loading branch information
rahulkarajgikar authored and dk2k committed Oct 21, 2024
1 parent 0309e73 commit 7e3024e
Show file tree
Hide file tree
Showing 21 changed files with 844 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix search_as_you_type not supporting multi-fields ([#15988](https://github.com/opensearch-project/OpenSearch/pull/15988))
- Avoid infinite loop when `flat_object` field contains invalid token ([#15985](https://github.com/opensearch-project/OpenSearch/pull/15985))
- Fix infinite loop in nested agg ([#15931](https://github.com/opensearch-project/OpenSearch/pull/15931))
- Fix race condition in node-join and node-left ([#15521](https://github.com/opensearch-project/OpenSearch/pull/15521))

### Security

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ public class NodeConnectionsService extends AbstractLifecycleComponent {

// contains an entry for every node in the latest cluster state, as well as for nodes from which we are in the process of
// disconnecting
private final Map<DiscoveryNode, ConnectionTarget> targetsByNode = new HashMap<>();
protected final Map<DiscoveryNode, ConnectionTarget> targetsByNode = new HashMap<>();

private final TimeValue reconnectInterval;
private volatile ConnectionChecker connectionChecker;
protected volatile ConnectionChecker connectionChecker;

@Inject
public NodeConnectionsService(Settings settings, ThreadPool threadPool, TransportService transportService) {
Expand All @@ -115,6 +115,11 @@ public NodeConnectionsService(Settings settings, ThreadPool threadPool, Transpor
this.reconnectInterval = NodeConnectionsService.CLUSTER_NODE_RECONNECT_INTERVAL_SETTING.get(settings);
}

/**
 * Creates a new {@link ConnectionTarget} for the given node. Exposed (protected) for testing.
 *
 * @param discoveryNode the node the connection target is created for
 * @return a new {@code ConnectionTarget} constructed from {@code discoveryNode}
 */
protected ConnectionTarget createConnectionTarget(DiscoveryNode discoveryNode) {
return new ConnectionTarget(discoveryNode);
}

/**
* Connect to all the given nodes, but do not disconnect from any extra nodes. Calls the completion handler on completion of all
* connection attempts to _new_ nodes, but not on attempts to re-establish connections to nodes that are already known.
Expand Down Expand Up @@ -159,6 +164,14 @@ public void connectToNodes(DiscoveryNodes discoveryNodes, Runnable onCompletion)
runnables.forEach(Runnable::run);
}

/**
 * Marks every node in the given set as pending disconnection on the transport service.
 *
 * @param nodes the nodes to mark as pending disconnection
 */
public void setPendingDisconnections(Set<DiscoveryNode> nodes) {
    for (final DiscoveryNode pendingNode : nodes) {
        transportService.setPendingDisconnection(pendingNode);
    }
}

/**
 * Clears all pending disconnections by delegating to the transport service.
 */
public void clearPendingDisconnections() {
transportService.clearPendingDisconnections();
}

/**
* Disconnect from any nodes to which we are currently connected which do not appear in the given nodes. Does not wait for the
* disconnections to complete, because they might have to wait for ongoing connection attempts first.
Expand Down Expand Up @@ -211,7 +224,7 @@ private void awaitPendingActivity(Runnable onCompletion) {
* nodes which are in the process of disconnecting. The onCompletion handler is called after all ongoing connection/disconnection
* attempts have completed.
*/
private void connectDisconnectedTargets(Runnable onCompletion) {
protected void connectDisconnectedTargets(Runnable onCompletion) {
final List<Runnable> runnables = new ArrayList<>();
synchronized (mutex) {
final Collection<ConnectionTarget> connectionTargets = targetsByNode.values();
Expand Down Expand Up @@ -321,7 +334,7 @@ private enum ActivityType {
*
* @opensearch.internal
*/
private class ConnectionTarget {
protected class ConnectionTarget {
private final DiscoveryNode discoveryNode;

private PlainListenableActionFuture<Void> future = PlainListenableActionFuture.newListenableFuture();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.opensearch.cluster.ClusterStateTaskConfig;
import org.opensearch.cluster.ClusterStateUpdateTask;
import org.opensearch.cluster.LocalClusterUpdateTask;
import org.opensearch.cluster.NodeConnectionsService;
import org.opensearch.cluster.block.ClusterBlocks;
import org.opensearch.cluster.coordination.ClusterFormationFailureHelper.ClusterFormationState;
import org.opensearch.cluster.coordination.CoordinationMetadata.VotingConfigExclusion;
Expand Down Expand Up @@ -187,6 +188,7 @@ public class Coordinator extends AbstractLifecycleComponent implements Discovery
private final NodeHealthService nodeHealthService;
private final PersistedStateRegistry persistedStateRegistry;
private final RemoteStoreNodeService remoteStoreNodeService;
private NodeConnectionsService nodeConnectionsService;

/**
* @param nodeName The name of the node, used to name the {@link java.util.concurrent.ExecutorService} of the {@link SeedHostsResolver}.
Expand Down Expand Up @@ -418,7 +420,11 @@ PublishWithJoinResponse handlePublishRequest(PublishRequest publishRequest) {

synchronized (mutex) {
final DiscoveryNode sourceNode = publishRequest.getAcceptedState().nodes().getClusterManagerNode();
logger.trace("handlePublishRequest: handling [{}] from [{}]", publishRequest, sourceNode);
logger.debug(
"handlePublishRequest: handling version [{}] from [{}]",
publishRequest.getAcceptedState().getVersion(),
sourceNode
);

if (sourceNode.equals(getLocalNode()) && mode != Mode.LEADER) {
// Rare case in which we stood down as leader between starting this publication and receiving it ourselves. The publication
Expand Down Expand Up @@ -630,7 +636,6 @@ private void handleJoinRequest(JoinRequest joinRequest, JoinHelper.JoinCallback

transportService.connectToNode(joinRequest.getSourceNode(), ActionListener.wrap(ignore -> {
final ClusterState stateForJoinValidation = getStateForClusterManagerService();

if (stateForJoinValidation.nodes().isLocalNodeElectedClusterManager()) {
onJoinValidators.forEach(a -> a.accept(joinRequest.getSourceNode(), stateForJoinValidation));
if (stateForJoinValidation.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false) {
Expand Down Expand Up @@ -814,6 +819,10 @@ public void onFailure(String source, Exception e) {
public ClusterTasksResult<LocalClusterUpdateTask> execute(ClusterState currentState) {
if (currentState.nodes().isLocalNodeElectedClusterManager() == false) {
allocationService.cleanCaches();
// This set only needs to be maintained on active cluster-manager
// This is cleaned up to avoid stale entries which would block future reconnections
logger.trace("Removing all pending disconnections as part of cluster-manager cleanup");
nodeConnectionsService.clearPendingDisconnections();
}
return unchanged();
}
Expand Down Expand Up @@ -914,11 +923,18 @@ public DiscoveryStats stats() {
/**
 * Starts the initial join: while holding {@code mutex}, logs at trace level and becomes a candidate, then asks the
 * cluster bootstrap service to schedule an unconfigured bootstrap.
 */
@Override
public void startInitialJoin() {
synchronized (mutex) {
logger.trace("Starting initial join, becoming candidate");
becomeCandidate("startInitialJoin");
}
// the bootstrap scheduling happens after the mutex has been released
clusterBootstrapService.scheduleUnconfiguredBootstrap();
}

/**
 * Stores the {@link NodeConnectionsService} this coordinator uses to mark and clear pending disconnections.
 * Intended to be called once: with assertions enabled, a second call trips the assertion below.
 *
 * @param nodeConnectionsService the service to use for connection management
 */
// NOTE(review): the backing field is neither final nor volatile — confirm that every reader is guaranteed to run
// after this setter (e.g. it is invoked during node start before the coordinator handles any requests).
@Override
public void setNodeConnectionsService(NodeConnectionsService nodeConnectionsService) {
assert this.nodeConnectionsService == null : "nodeConnectionsService is already set";
this.nodeConnectionsService = nodeConnectionsService;
}

@Override
protected void doStop() {
configuredHostsResolver.stop();
Expand Down Expand Up @@ -1356,6 +1372,9 @@ assert getLocalNode().equals(clusterState.getNodes().get(getLocalNode().getId())
currentPublication = Optional.of(publication);

final DiscoveryNodes publishNodes = publishRequest.getAcceptedState().nodes();
// Mark the removed nodes as pending disconnect before publishing.
// If a node tries to send a joinRequest while it is pending disconnect, the request should fail.
nodeConnectionsService.setPendingDisconnections(new HashSet<>(clusterChangedEvent.nodesDelta().removedNodes()));
leaderChecker.setCurrentNodes(publishNodes);
followersChecker.setCurrentNodes(publishNodes);
lagDetector.setTrackedNodes(publishNodes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public Publication(PublishRequest publishRequest, AckListener ackListener, LongS
}

public void start(Set<DiscoveryNode> faultyNodes) {
logger.trace("publishing {} to {}", publishRequest, publicationTargets);
logger.debug("publishing version {} to {}", publishRequest.getAcceptedState().getVersion(), publicationTargets);

for (final DiscoveryNode faultyNode : faultyNodes) {
onFaultyNode(faultyNode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ public String executor() {
}

public void sendClusterState(DiscoveryNode destination, ActionListener<PublishWithJoinResponse> listener) {
logger.debug("sending cluster state over transport to node: {}", destination.getName());
logger.trace("sending cluster state over transport to node: {}", destination.getName());
if (sendFullVersion || previousState.nodes().nodeExists(destination) == false) {
logger.trace("sending full cluster state version [{}] to [{}]", newState.version(), destination);
sendFullClusterState(destination, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,7 @@ private void runTask(UpdateTask task) {
try {
applyChanges(task, previousClusterState, newClusterState, stopWatch);
TimeValue executionTime = TimeValue.timeValueMillis(Math.max(0, currentTimeInMillis() - startTimeMS));
// At this point, cluster state appliers and listeners are completed
logger.debug(
"processing [{}]: took [{}] done applying updated cluster state (version: {}, uuid: {})",
task.source,
Expand All @@ -510,6 +511,7 @@ private void runTask(UpdateTask task) {
newClusterState.stateUUID()
);
warnAboutSlowTaskIfNeeded(executionTime, task.source, stopWatch);
// Then we call the ClusterApplyListener of the task
task.listener.onSuccess(task.source);
} catch (Exception e) {
TimeValue executionTime = TimeValue.timeValueMillis(Math.max(0, currentTimeInMillis() - startTimeMS));
Expand Down Expand Up @@ -578,6 +580,7 @@ private void applyChanges(UpdateTask task, ClusterState previousClusterState, Cl

logger.debug("apply cluster state with version {}", newClusterState.version());
callClusterStateAppliers(clusterChangedEvent, stopWatch);
logger.debug("completed calling appliers of cluster state for version {}", newClusterState.version());

nodeConnectionsService.disconnectFromNodesExcept(newClusterState.nodes());

Expand All @@ -594,6 +597,7 @@ private void applyChanges(UpdateTask task, ClusterState previousClusterState, Cl
state.set(newClusterState);

callClusterStateListeners(clusterChangedEvent, stopWatch);
logger.debug("completed calling listeners of cluster state for version {}", newClusterState.version());
}

protected void connectToNodesAndWait(ClusterState newClusterState) {
Expand Down
5 changes: 5 additions & 0 deletions server/src/main/java/org/opensearch/discovery/Discovery.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

package org.opensearch.discovery;

import org.opensearch.cluster.NodeConnectionsService;
import org.opensearch.cluster.coordination.ClusterStatePublisher;
import org.opensearch.common.lifecycle.LifecycleComponent;

Expand All @@ -54,4 +55,8 @@ public interface Discovery extends LifecycleComponent, ClusterStatePublisher {
*/
void startInitialJoin();

/**
* Sets the NodeConnectionsService which is an abstraction used for connection management
*/
void setNodeConnectionsService(NodeConnectionsService nodeConnectionsService);
}
1 change: 1 addition & 0 deletions server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -1602,6 +1602,7 @@ public Node start() throws NodeValidationException {

injector.getInstance(GatewayService.class).start();
Discovery discovery = injector.getInstance(Discovery.class);
discovery.setNodeConnectionsService(nodeConnectionsService);
clusterService.getClusterManagerService().setClusterStatePublisher(discovery::publish);

// Start the transport service now so the publish address will be added to the local disco node in ClusterService
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ public class ClusterConnectionManager implements ConnectionManager {

private final ConcurrentMap<DiscoveryNode, Transport.Connection> connectedNodes = ConcurrentCollections.newConcurrentMap();
private final ConcurrentMap<DiscoveryNode, ListenableFuture<Void>> pendingConnections = ConcurrentCollections.newConcurrentMap();
/**
This set is used only by cluster-manager nodes.
Nodes are marked as pending disconnect right before the cluster state publish phase.
They are cleared as part of the cluster state apply commit phase.
This prevents connections from being made to nodes that are in the process of leaving the cluster.
Note: this set does not handle the case where a disconnect is initiated while a connect is already in progress.
Callers need to ensure that connects and disconnects are sequenced.
*/
private final Set<DiscoveryNode> pendingDisconnections = ConcurrentCollections.newConcurrentSet();
private final AbstractRefCounted connectingRefCounter = new AbstractRefCounted("connection manager") {
@Override
protected void closeInternal() {
Expand Down Expand Up @@ -122,12 +131,19 @@ public void connectToNode(
ConnectionValidator connectionValidator,
ActionListener<Void> listener
) throws ConnectTransportException {
logger.trace("connecting to node [{}]", node);
ConnectionProfile resolvedProfile = ConnectionProfile.resolveConnectionProfile(connectionProfile, defaultProfile);
if (node == null) {
listener.onFailure(new ConnectTransportException(null, "can't connect to a null node"));
return;
}

// if node-left is still in progress, we fail the connect request early
if (pendingDisconnections.contains(node)) {
listener.onFailure(new IllegalStateException("cannot make a new connection as disconnect to node [" + node + "] is pending"));
return;
}

if (connectingRefCounter.tryIncRef() == false) {
listener.onFailure(new IllegalStateException("connection manager is closed"));
return;
Expand Down Expand Up @@ -170,6 +186,7 @@ public void connectToNode(
conn.addCloseListener(ActionListener.wrap(() -> {
logger.trace("unregistering {} after connection close and marking as disconnected", node);
connectedNodes.remove(node, finalConnection);
pendingDisconnections.remove(node);
connectionListener.onNodeDisconnected(node, conn);
}));
}
Expand Down Expand Up @@ -226,6 +243,19 @@ public void disconnectFromNode(DiscoveryNode node) {
// if we found it and removed it we close
nodeChannels.close();
}
pendingDisconnections.remove(node);
logger.trace("Removed node [{}] from pending disconnections list", node);
}

/**
 * Marks the given node as pending disconnection. While a node is marked, {@code connectToNode} fails the listener
 * early with an {@link IllegalStateException} instead of opening a new connection.
 *
 * @param node the node to mark as pending disconnection
 */
@Override
public void setPendingDisconnection(DiscoveryNode node) {
logger.trace("marking disconnection as pending for node: [{}]", node);
pendingDisconnections.add(node);
}

/**
 * Removes every node from the pending disconnections set.
 */
@Override
public void clearPendingDisconnections() {
pendingDisconnections.clear();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ void connectToNode(

void disconnectFromNode(DiscoveryNode node);

/**
 * Marks the given node as pending disconnection.
 *
 * @param node the node to mark as pending disconnection
 */
void setPendingDisconnection(DiscoveryNode node);

/**
 * Clears all nodes previously marked as pending disconnection.
 */
void clearPendingDisconnections();

Set<DiscoveryNode> getAllConnectedNodes();

int size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,16 @@ public void disconnectFromNode(DiscoveryNode node) {
delegate.disconnectFromNode(node);
}

/**
 * Forwards the pending-disconnection mark for the given node to the delegate connection manager.
 */
@Override
public void setPendingDisconnection(DiscoveryNode node) {
delegate.setPendingDisconnection(node);
}

/**
 * Forwards the request to clear all pending disconnections to the delegate connection manager.
 */
@Override
public void clearPendingDisconnections() {
delegate.clearPendingDisconnections();
}

@Override
public ConnectionProfile getConnectionProfile() {
return delegate.getConnectionProfile();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -773,6 +773,18 @@ public void disconnectFromNode(DiscoveryNode node) {
connectionManager.disconnectFromNode(node);
}

/**
 * Marks the given node as pending disconnection by delegating to the connection manager.
 *
 * @param node the node to mark as pending disconnection
 */
public void setPendingDisconnection(DiscoveryNode node) {
connectionManager.setPendingDisconnection(node);
}

/**
 * Wipes out all pending disconnections by delegating to the connection manager.
 * This is called on cluster-manager failover to remove stale entries.
 */
public void clearPendingDisconnections() {
connectionManager.clearPendingDisconnections();
}

public void addMessageListener(TransportMessageListener listener) {
messageListener.listeners.add(listener);
}
Expand Down
Loading

0 comments on commit 7e3024e

Please sign in to comment.