Skip to content

Commit

Permalink
remove nodes that left the cluster from worker nodes
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna committed Oct 14, 2024
1 parent 2f54de1 commit 45c920f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.MLExceptionUtils.toJsonString;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;

import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
Expand Down Expand Up @@ -131,7 +135,25 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
syncModelWorkerNodes(modelId, functionName);
}

if (workNodes == null || workNodes.size() == 0) {
Set<String> workNodesRemovedFromCluster = new HashSet<>();

if (workNodes != null && !workNodes.isEmpty()) {
Set<String> allNodesInCluster = new HashSet<>(List.of(getAllNodes(clusterService)));

workNodesRemovedFromCluster = workNodes.stream()
.filter(node -> !allNodesInCluster.contains(node))
.collect(Collectors.toSet());

if (!workNodesRemovedFromCluster.isEmpty()) {
workNodes.removeAll(workNodesRemovedFromCluster);
}
}

if (workNodes == null || workNodes.isEmpty()) {
if (!workNodesRemovedFromCluster.isEmpty()) {
mlTaskCache.updateWorkerNodeCount(workNodesRemovedFromCluster);
mlModelManager.removeModelWorkerNode(modelId, false, workNodesRemovedFromCluster.toArray(new String[0]));
}
int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
if (mlTaskCache.allNodeFailed()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
import static org.opensearch.ml.plugin.MachineLearningPlugin.REMOTE_PREDICT_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MODEL_AUTO_DEPLOY_ENABLE;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.RestActionUtils.getAllNodes;

import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;

import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
Expand Down Expand Up @@ -158,6 +163,15 @@ public void dispatchTask(
}
}, listener::onFailure);
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true);

if (workerNodes != null && workerNodes.length > 0) {
String[] allNodesInCluster = getAllNodes(clusterService);

workerNodes = Arrays.stream(workerNodes)
.filter(node -> Arrays.asList(allNodesInCluster).contains(node))
.toArray(String[]::new);
}

if (workerNodes == null || workerNodes.length == 0) {
if (FunctionName.isAutoDeployEnabled(autoDeploymentEnabled, functionName)) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
Expand Down
5 changes: 5 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,9 @@ public int errorNodesCount() {
/**
 * Reports whether every expected worker node has a recorded error.
 *
 * @return {@code true} only when a worker-node count has been set and the number of
 *         entries in {@code errors} equals that count; {@code false} when the count
 *         is unset or the sizes differ
 */
public boolean allNodeFailed() {
    // Without a known worker-node count there is nothing to compare against.
    if (workerNodeSize == null) {
        return false;
    }
    return errors.size() == workerNodeSize;
}

/**
 * Removes the given nodes from this cache's tracked worker nodes and lowers the
 * expected worker-node count by the size of the given set.
 *
 * <p>A {@code null} or empty set is a no-op. The count is reduced by
 * {@code nodesRemovedFromCluster.size()} regardless of how many of those entries were
 * still present in {@code workerNodes}, so callers should pass only nodes that were
 * counted as workers for this task.
 *
 * @param nodesRemovedFromCluster nodes that are no longer part of the cluster
 *        (presumably node ids, matching the entries held in {@code workerNodes});
 *        may be {@code null} or empty
 */
public void updateWorkerNodeCount(Set<String> nodesRemovedFromCluster) {
    if (nodesRemovedFromCluster == null || nodesRemovedFromCluster.isEmpty()) {
        return;
    }
    if (this.workerNodes != null) {
        this.workerNodes.removeAll(nodesRemovedFromCluster);
    }
    // workerNodeSize is nullable (allNodeFailed() null-checks it); subtracting from a
    // null boxed value would unbox and throw NullPointerException, so skip when unset.
    if (this.workerNodeSize != null) {
        this.workerNodeSize = this.workerNodeSize - nodesRemovedFromCluster.size();
    }
}
}

0 comments on commit 45c920f

Please sign in to comment.