Skip to content

Commit

Permalink
Fix model delete failed after model undeployed
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Jul 10, 2024
1 parent 91d0127 commit f2406a9
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 348 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.UNDEPLOYED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import org.opensearch.action.FailedNodeException;
Expand All @@ -29,6 +32,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
Expand All @@ -44,7 +48,6 @@
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -80,7 +83,7 @@ public TransportUndeployModelAction(
actionFilters,
MLUndeployModelNodesRequest::new,
MLUndeployModelNodeRequest::new,
ThreadPool.Names.MANAGEMENT,
DEPLOY_THREAD_POOL,
MLUndeployModelNodeResponse.class
);
this.mlModelManager = mlModelManager;
Expand All @@ -92,34 +95,25 @@ public TransportUndeployModelAction(
}

/**
 * Executes the undeploy request via the parent transport action, intercepting the successful
 * aggregated response before it reaches the caller.
 *
 * @param task     the task passed through unchanged to {@code super.doExecute}
 * @param request  the nodes-level undeploy request, passed through unchanged
 * @param listener the caller's listener; it is not handed to the parent action directly but is
 *                 given to {@code processUndeployModelResponseAndUpdate} on success, or notified
 *                 through {@code onFailure} on error
 */
@Override
protected void doExecute(Task task, MLUndeployModelNodesRequest request, ActionListener<MLUndeployModelNodesResponse> listener) {
// Success path: post-process the aggregated response; that method receives the caller's listener
// and is the one that answers it. Failure path: forward the exception to the caller as-is.
ActionListener<MLUndeployModelNodesResponse> wrappedListener = ActionListener.wrap(undeployModelNodesResponse -> {
processUndeployModelResponseAndUpdate(undeployModelNodesResponse, listener);
}, listener::onFailure);
// The parent action only ever sees the wrapped listener.
super.doExecute(task, request, wrappedListener);
}

void processUndeployModelResponseAndUpdate(
MLUndeployModelNodesResponse undeployModelNodesResponse,
ActionListener<MLUndeployModelNodesResponse> listener
protected MLUndeployModelNodesResponse newResponse(
MLUndeployModelNodesRequest nodesRequest,
List<MLUndeployModelNodeResponse> responses,
List<FailedNodeException> failures
) {
List<MLUndeployModelNodeResponse> responses = undeployModelNodesResponse.getNodes();
if (responses == null || responses.isEmpty()) {
listener.onResponse(undeployModelNodesResponse);
return;
if (CollectionUtils.isEmpty(responses)) {
return new MLUndeployModelNodesResponse(clusterService.getClusterName(), responses, failures);
}

Map<String, List<String>> actualRemovedNodesMap = new HashMap<>();
Map<String, String[]> modelWorkNodesBeforeRemoval = new HashMap<>();
responses.forEach(r -> {
Map<String, String[]> nodeCounts = r.getModelWorkerNodeBeforeRemoval();

if (nodeCounts != null) {
for (Map.Entry<String, String[]> entry : nodeCounts.entrySet()) {
// when undeploy an undeployed model, the entry.getvalue() is null
// when undeploying an already-undeployed model, entry.getValue() is null
if (entry.getValue() != null
&& (!modelWorkNodesBeforeRemoval.containsKey(entry.getKey())
|| modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) {
|| modelWorkNodesBeforeRemoval.get(entry.getKey()).length < entry.getValue().length)) {
modelWorkNodesBeforeRemoval.put(entry.getKey(), entry.getValue());
}
}
Expand All @@ -144,8 +138,9 @@ void processUndeployModelResponseAndUpdate(
.build();

MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(nodeFilter.getAllNodes(), syncUpInput);
CountDownLatch countDownLatch = new CountDownLatch(1);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (actualRemovedNodesMap.size() > 0) {
if (!actualRemovedNodesMap.isEmpty()) {
BulkRequest bulkRequest = new BulkRequest();
Map<String, Boolean> deployToAllNodes = new HashMap<>();
for (String modelId : actualRemovedNodesMap.keySet()) {
Expand Down Expand Up @@ -188,24 +183,32 @@ void processUndeployModelResponseAndUpdate(
"updated model state as undeployed for : {}",
Arrays.toString(actualRemovedNodesMap.keySet().toArray(new String[0]))
);
}, e -> { log.error("Failed to update model state as undeployed", e); });
countDownLatch.countDown();
}, e -> {
log.error("Failed to update model state as undeployed", e);
countDownLatch.countDown();
});
client.bulk(bulkRequest, ActionListener.runAfter(actionListener, () -> {
syncUpUndeployedModels(syncUpRequest);
listener.onResponse(undeployModelNodesResponse);
context.restore();
}));
} else {
syncUpUndeployedModels(syncUpRequest);
listener.onResponse(undeployModelNodesResponse);
context.restore();
countDownLatch.countDown();
}
}
if (countDownLatch.getCount() != 0) {
try {
boolean success = countDownLatch.await(1000, TimeUnit.MILLISECONDS);
if (!success) {
log.error("Failed to update model state as undeployed in model index after waiting for 1 second, please check model status manually");
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.error("Failed to update model state as undeployed in model index, current thread is interrupted", e);
}
}
}

/**
 * Builds the aggregated nodes-level response from the per-node results.
 *
 * @param nodesRequest the originating nodes request (not read by this method)
 * @param responses    the per-node undeploy responses, passed through unchanged
 * @param failures     the per-node failures, passed through unchanged
 * @return a new {@code MLUndeployModelNodesResponse} carrying the cluster name obtained from
 *         {@code clusterService} together with the given responses and failures
 */
@Override
protected MLUndeployModelNodesResponse newResponse(
MLUndeployModelNodesRequest nodesRequest,
List<MLUndeployModelNodeResponse> responses,
List<FailedNodeException> failures
) {
return new MLUndeployModelNodesResponse(clusterService.getClusterName(), responses, failures);
}

Expand Down
Loading

0 comments on commit f2406a9

Please sign in to comment.