Skip to content

Commit

Permalink
Fix issue where models are not auto-redeployed after a cluster-level restart
Browse files (browse the repository at this point in the history)
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Nov 22, 2023
1 parent 56f1663 commit 640bd96
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@
import org.opensearch.common.lifecycle.LifecycleListener;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.threadpool.Scheduler;
import org.opensearch.threadpool.ThreadPool;

import lombok.extern.log4j.Log4j2;

import java.util.Arrays;
import java.util.List;

@Log4j2
public class MLCommonsClusterManagerEventListener implements LocalNodeClusterManagerListener {

Expand All @@ -35,14 +39,17 @@ public class MLCommonsClusterManagerEventListener implements LocalNodeClusterMan

private volatile Integer jobInterval;

private final MLModelAutoReDeployer mlModelAutoReDeployer;

public MLCommonsClusterManagerEventListener(
ClusterService clusterService,
Client client,
Settings settings,
ThreadPool threadPool,
DiscoveryNodeHelper nodeHelper,
MLIndicesHandler mlIndicesHandler,
Encryptor encryptor
Encryptor encryptor,
MLModelAutoReDeployer modelAutoReDeployer
) {
this.clusterService = clusterService;
this.client = client;
Expand All @@ -51,6 +58,7 @@ public MLCommonsClusterManagerEventListener(
this.nodeHelper = nodeHelper;
this.mlIndicesHandler = mlIndicesHandler;
this.encryptor = encryptor;
this.mlModelAutoReDeployer = modelAutoReDeployer;

this.jobInterval = ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS, it -> {
Expand All @@ -62,6 +70,8 @@ public MLCommonsClusterManagerEventListener(

@Override
public void onClusterManager() {
String localNodeId = clusterService.localNode().getId();
mlModelAutoReDeployer.buildAutoReloadArrangement(List.of(localNodeId), localNodeId);
if (syncModelRoutingCron == null) {
startSyncModelRoutingCron();
}
Expand Down
29 changes: 18 additions & 11 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
Expand Down Expand Up @@ -63,6 +65,7 @@ public class MLSyncUpCron implements Runnable {
private volatile Boolean mlConfigInited;
@VisibleForTesting
Semaphore updateModelStateSemaphore;
private MLModelAutoReDeployer mlModelAutoReDeployer;

public MLSyncUpCron(
Client client,
Expand Down Expand Up @@ -116,6 +119,8 @@ public void run() {
Set<String> workerNodes = deployingModels.computeIfAbsent(modelId, it -> new HashSet<>());
workerNodes.add(nodeId);
}
} else {

}

String[] runningDeployModelTaskIds = response.getRunningDeployModelTaskIds();
Expand Down Expand Up @@ -270,17 +275,19 @@ void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Se
newPlanningWorkerNodes.put(modelId, eligibleNodeIds);
}
}
MLModelState mlModelState = getNewModelState(
deployingModels,
modelWorkerNodes,
modelId,
state,
lastUpdateTime,
planningWorkerNodeCount,
currentWorkerNodeCountInIndex
);
if (mlModelState != null) {
newModelStates.put(modelId, mlModelState);
if (modelWorkerNodes != null && modelWorkerNodes.size() != 0) {
MLModelState mlModelState = getNewModelState(
deployingModels,
modelWorkerNodes,
modelId,
state,
lastUpdateTime,
planningWorkerNodeCount,
currentWorkerNodeCountInIndex
);
if (mlModelState != null) {
newModelStates.put(modelId, mlModelState);
}
}
}
bulkUpdateModelState(modelWorkerNodes, newModelStates, newPlanningWorkerNodes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,8 @@ public Collection<Object> createComponents(
threadPool,
nodeHelper,
mlIndicesHandler,
encryptor
encryptor,
mlModelAutoRedeployer
);

// TODO move this into MLFeatureEnabledSetting
Expand Down

0 comments on commit 640bd96

Please sign in to comment.