Skip to content

Commit

Permalink
add eligible node role settings (#1197)
Browse files Browse the repository at this point in the history
* add eligible node role settings

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* add more comment

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Aug 10, 2023
1 parent deb0008 commit b8c73fd
Show file tree
Hide file tree
Showing 28 changed files with 287 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
listener
Expand All @@ -141,7 +142,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
}
// mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();
DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes();
DiscoveryNode[] allEligibleNodes = nodeFilter.getEligibleNodes(functionName);
Map<String, DiscoveryNode> nodeMapping = new HashMap<>();
for (DiscoveryNode node : allEligibleNodes) {
nodeMapping.put(node.getId(), node);
Expand All @@ -161,7 +162,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
nodeIds.add(nodeId);
}
}
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName);
if (workerNodes != null && workerNodes.length > 0) {
Set<String> difference = new HashSet<String>(Arrays.asList(workerNodes));
difference.removeAll(Arrays.asList(targetNodeIds));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
Expand Down Expand Up @@ -42,6 +43,7 @@ public TransportExecuteTaskAction(
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLExecuteTaskResponse> listener) {
MLExecuteTaskRequest mlPredictionTaskRequest = MLExecuteTaskRequest.fromActionRequest(request);
mlExecuteTaskRunner.run(mlPredictionTaskRequest, transportService, listener);
FunctionName functionName = mlPredictionTaskRequest.getFunctionName();
mlExecuteTaskRunner.run(functionName, mlPredictionTaskRequest, transportService, listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
Expand Down Expand Up @@ -116,26 +117,26 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLForw
switch (requestType) {
case DEPLOY_MODEL_DONE:
Set<String> workNodes = mlTaskManager.getWorkNodes(taskId);
MLTaskCache mlTaskCache = mlTaskManager.getMLTaskCache(taskId);
FunctionName functionName = mlTaskCache.getMlTask().getFunctionName();
if (workNodes != null) {
workNodes.remove(workerNodeId);
}

if (error != null) {
mlTaskManager.addNodeError(taskId, workerNodeId, error);
} else {
mlModelManager.addModelWorkerNode(modelId, workerNodeId);
syncModelWorkerNodes(modelId);
syncModelWorkerNodes(modelId, functionName);
}

if (workNodes == null || workNodes.size() == 0) {
MLTaskCache mlTaskCache = mlTaskManager.getMLTaskCache(taskId);
int currentWorkerNodeCount = mlTaskCache.getWorkerNodeSize();
MLTaskState taskState = mlTaskCache.hasError() ? MLTaskState.COMPLETED_WITH_ERROR : MLTaskState.COMPLETED;
if (mlTaskCache.allNodeFailed()) {
taskState = MLTaskState.FAILED;
currentWorkerNodeCount = 0;
} else {
syncModelWorkerNodes(modelId);
syncModelWorkerNodes(modelId, functionName);
}
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put(MLTask.STATE_FIELD, taskState);
Expand Down Expand Up @@ -196,9 +197,9 @@ private boolean triggerNextModelDeployAndCheckIfRestRetryTimes(Set<String> workN
return false;
}

private void syncModelWorkerNodes(String modelId) {
private void syncModelWorkerNodes(String modelId, FunctionName functionName) {
DiscoveryNode[] allNodes = nodeHelper.getAllNodes();
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName);
if (allNodes.length > 1 && workerNodes != null && workerNodes.length > 0) {
log.debug("Sync to other nodes about worker nodes of model {}: {}", modelId, Arrays.toString(workerNodes));
MLSyncUpInput syncUpInput = MLSyncUpInput.builder().addedWorkerNodes(ImmutableMap.of(modelId, workerNodes)).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
Expand Down Expand Up @@ -86,6 +87,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlModelManager.getModel(modelId, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
Expand All @@ -97,12 +99,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
String requestId = mlPredictionTaskRequest.getRequestID();
log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
long startTime = System.nanoTime();
mlPredictTaskRunner.run(mlPredictionTaskRequest, transportService, ActionListener.runAfter(listener, () -> {
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
log.debug("completed predict request " + requestId + " for model " + modelId);
}));
mlPredictTaskRunner
.run(functionName, mlPredictionTaskRequest, transportService, ActionListener.runAfter(listener, () -> {
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
log.debug("completed predict request " + requestId + " for model " + modelId);
}));
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
}));
return;
}
mlTaskDispatcher.dispatch(ActionListener.wrap(node -> {
mlTaskDispatcher.dispatch(registerModelInput.getFunctionName(), ActionListener.wrap(node -> {
String nodeId = node.getId();
mlTask.setWorkerNodes(ImmutableList.of(nodeId));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,6 @@ public TransportTrainingTaskAction(
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
MLTrainingTaskRequest trainingRequest = MLTrainingTaskRequest.fromActionRequest(request);
mlTrainingTaskRunner.run(trainingRequest, transportService, listener);
mlTrainingTaskRunner.run(trainingRequest.getMlInput().getFunctionName(), trainingRequest, transportService, listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ public TransportTrainAndPredictionTaskAction(
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {
MLTrainingTaskRequest trainingRequest = MLTrainingTaskRequest.fromActionRequest(request);
mlTrainAndPredictTaskRunner.run(trainingRequest, transportService, listener);
mlTrainAndPredictTaskRunner.run(trainingRequest.getMlInput().getFunctionName(), trainingRequest, transportService, listener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
Expand Down Expand Up @@ -238,7 +239,8 @@ private MLUndeployModelNodeResponse createUndeployModelNodeResponse(MLUndeployMo
String[] removedModelIds = specifiedModelIds ? modelIds : mlModelManager.getAllModelIds();
if (removedModelIds != null) {
for (String modelId : removedModelIds) {
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
FunctionName functionName = mlModelManager.getModelFunctionName(modelId);
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName);
modelWorkerNodesMap.put(modelId, workerNodes);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
package org.opensearch.ml.cluster;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
Expand All @@ -21,6 +22,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.Strings;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.utils.MLNodeUtils;

import lombok.extern.log4j.Log4j2;
Expand All @@ -31,6 +33,8 @@ public class DiscoveryNodeHelper {
private final HotDataNodePredicate eligibleNodeFilter;
private volatile Boolean onlyRunOnMLNode;
private volatile Set<String> excludedNodeNames;
private volatile Set<String> remoteModelEligibleNodeRoles;
private volatile Set<String> localModelEligibleNodeRoles;

public DiscoveryNodeHelper(ClusterService clusterService, Settings settings) {
this.clusterService = clusterService;
Expand All @@ -41,44 +45,61 @@ public DiscoveryNodeHelper(ClusterService clusterService, Settings settings) {
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_EXCLUDE_NODE_NAMES, it -> excludedNodeNames = Strings.commaDelimitedListToSet(it));
remoteModelEligibleNodeRoles = new HashSet<>();
remoteModelEligibleNodeRoles.addAll(ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES.get(settings));
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES, it -> {
remoteModelEligibleNodeRoles = new HashSet<>(it);
});
localModelEligibleNodeRoles = new HashSet<>();
localModelEligibleNodeRoles.addAll(ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES.get(settings));
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES, it -> {
localModelEligibleNodeRoles = new HashSet<>(it);
});
}

public String[] getEligibleNodeIds() {
DiscoveryNode[] nodes = getEligibleNodes();
public String[] getEligibleNodeIds(FunctionName functionName) {
DiscoveryNode[] nodes = getEligibleNodes(functionName);
String[] nodeIds = new String[nodes.length];
for (int i = 0; i < nodes.length; i++) {
nodeIds[i] = nodes[i].getId();
}
return nodeIds;
}

public DiscoveryNode[] getEligibleNodes() {
public DiscoveryNode[] getEligibleNodes(FunctionName functionName) {
ClusterState state = this.clusterService.state();
final List<DiscoveryNode> eligibleMLNodes = new ArrayList<>();
final List<DiscoveryNode> eligibleDataNodes = new ArrayList<>();
final List<DiscoveryNode> eligibleNodes = new ArrayList<>();
for (DiscoveryNode node : state.nodes()) {
if (excludedNodeNames != null && excludedNodeNames.contains(node.getName())) {
continue;
}
if (MLNodeUtils.isMLNode(node)) {
eligibleMLNodes.add(node);
}
if (!onlyRunOnMLNode && node.isDataNode() && isEligibleDataNode(node)) {
eligibleDataNodes.add(node);
if (functionName == FunctionName.REMOTE) {// remote model
getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node);
} else { // local model
if (onlyRunOnMLNode) {
if (MLNodeUtils.isMLNode(node)) {
eligibleNodes.add(node);
}
} else {
getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node);
}
}
}
if (eligibleMLNodes.size() > 0) {
DiscoveryNode[] mlNodes = eligibleMLNodes.toArray(new DiscoveryNode[0]);
log.debug("Find {} dedicated ML nodes: {}", eligibleMLNodes.size(), Arrays.toString(mlNodes));
return mlNodes;
} else {
DiscoveryNode[] dataNodes = eligibleDataNodes.toArray(new DiscoveryNode[0]);
log.debug("Find no dedicated ML nodes. But have {} data nodes: {}", eligibleDataNodes.size(), Arrays.toString(dataNodes));
return dataNodes;
return eligibleNodes.toArray(new DiscoveryNode[0]);
}

private void getEligibleNodes(Set<String> allowedNodeRoles, List<DiscoveryNode> eligibleNodes, DiscoveryNode node) {
if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) {
eligibleNodes.add(node);
}
for (String nodeRole : allowedNodeRoles) {
if (!"data".equals(nodeRole) && node.getRoles().stream().anyMatch(r -> r.roleName().equals(nodeRole))) {
eligibleNodes.add(node);
}
}
}

public String[] filterEligibleNodes(String[] nodeIds) {
public String[] filterEligibleNodes(FunctionName functionName, String[] nodeIds) {
if (nodeIds == null || nodeIds.length == 0) {
return nodeIds;
}
Expand All @@ -88,14 +109,30 @@ public String[] filterEligibleNodes(String[] nodeIds) {
if (excludedNodeNames != null && excludedNodeNames.contains(node.getName())) {
continue;
}
if (MLNodeUtils.isMLNode(node)) {
eligibleNodes.add(node.getId());
if (functionName == FunctionName.REMOTE) {// remote model
getEligibleNodes(remoteModelEligibleNodeRoles, eligibleNodes, node);
} else { // local model
if (onlyRunOnMLNode) {
if (MLNodeUtils.isMLNode(node)) {
eligibleNodes.add(node.getId());
}
} else {
getEligibleNodes(localModelEligibleNodeRoles, eligibleNodes, node);
}
}
if (!onlyRunOnMLNode && node.isDataNode() && isEligibleDataNode(node)) {
}
return eligibleNodes.toArray(new String[0]);
}

private void getEligibleNodes(Set<String> allowedNodeRoles, Set<String> eligibleNodes, DiscoveryNode node) {
if (allowedNodeRoles.contains("data") && isEligibleDataNode(node)) {
eligibleNodes.add(node.getId());
}
for (String nodeRole : allowedNodeRoles) {
if (!"data".equals(nodeRole) && node.getRoles().stream().anyMatch(r -> r.roleName().equals(nodeRole))) {
eligibleNodes.add(node.getId());
}
}
return eligibleNodes.toArray(new String[0]);
}

public DiscoveryNode[] getAllNodes() {
Expand Down
Loading

0 comments on commit b8c73fd

Please sign in to comment.