Performance enhancement for predict action by caching model info (#1472)
* Performance enhancement for predict action by caching model info

Signed-off-by: zane-neo <zaniu@amazon.com>

* Add context.restore() to avoid missing info

Signed-off-by: zane-neo <zaniu@amazon.com>

---------

Signed-off-by: zane-neo <zaniu@amazon.com>
zane-neo authored Oct 10, 2023
1 parent 70f8477 commit a985f6e
Showing 4 changed files with 102 additions and 36 deletions.
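
Before reading the diff, here is a minimal sketch of the caching pattern the first commit applies: consult an in-memory cache, fall back to the (expensive) model-index fetch only on a miss, and populate the cache on the way back so subsequent predict requests skip the round trip. All names below (ModelInfoCache, resolve, fetchFromIndex) are illustrative assumptions, not the plugin's actual API.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

final class ModelInfoCache<M> {
    private final Map<String, M> byId = new ConcurrentHashMap<>();

    // Cache hit: reuse the stored model info and skip the index lookup.
    // Cache miss: fetch once, remember the result, then continue.
    void resolve(String modelId,
                 BiConsumer<String, Consumer<M>> fetchFromIndex,
                 Consumer<M> andThen) {
        M cached = byId.get(modelId);
        if (cached != null) {
            andThen.accept(cached);
            return;
        }
        fetchFromIndex.accept(modelId, model -> {
            byId.put(modelId, model);
            andThen.accept(model);
        });
    }

    // Drop the entry when the model is undeployed everywhere, mirroring
    // the invalidation added to removeWorkerNode(s) in this diff.
    void invalidate(String modelId) {
        byId.remove(modelId);
    }
}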
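The second commit ("Add context.restore() to avoid missing info") addresses a subtlety of this change: the action stashes the thread context before doing privileged work, and on a cache hit the listener now fires synchronously inside that stashed scope, so it must restore the caller's context explicitly or downstream code loses the original request state. Below is a toy illustration using a plain ThreadLocal as a stand-in for OpenSearch's ThreadContext; it is not the real implementation.

final class StashedContextDemo {
    private static final ThreadLocal<String> CTX = ThreadLocal.withInitial(() -> "user");

    interface StoredContext extends AutoCloseable {
        void restore();
        @Override
        default void close() { restore(); }
    }

    // Swap in an elevated context; the returned handle puts the old one back.
    static StoredContext stashContext() {
        String previous = CTX.get();
        CTX.set("system");
        return () -> CTX.set(previous);
    }

    public static void main(String[] args) {
        try (StoredContext stored = stashContext()) {
            Runnable listener = () -> {
                stored.restore();                  // without this line the callback
                System.out.println(CTX.get());     // would still run as "system"
            };
            listener.run();                        // prints "user"
        }
    }
}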
@@ -16,6 +16,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
@@ -87,42 +88,66 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> listener) {

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
mlModelManager.getModel(modelId, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
wrappedListener
.onFailure(
new MLValidationException("User Doesn't have privilege to perform this operation on this model")
);
} else {
String requestId = mlPredictionTaskRequest.getRequestID();
log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
long startTime = System.nanoTime();
mlPredictTaskRunner
.run(
functionName,
mlPredictionTaskRequest,
transportService,
ActionListener.runAfter(wrappedListener, () -> {
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
log.debug("completed predict request " + requestId + " for model " + modelId);
})
);
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
wrappedListener.onFailure(e);
}));
}, e -> {
log.error("Failed to find model " + modelId, e);
wrappedListener.onFailure(e);
}));
MLModel cachedMlModel = modelCacheHelper.getModelInfo(modelId);
ActionListener<MLModel> modelActionListener = new ActionListener<>() {
@Override
public void onResponse(MLModel mlModel) {
context.restore();
modelCacheHelper.setModelInfo(modelId, mlModel);
FunctionName functionName = mlModel.getAlgorithm();
mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName);
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
wrappedListener
.onFailure(
new MLValidationException("User Doesn't have privilege to perform this operation on this model")
);
} else {
executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
wrappedListener.onFailure(e);
}));
}

@Override
public void onFailure(Exception e) {
log.error("Failed to find model " + modelId, e);
wrappedListener.onFailure(e);
}
};

if (cachedMlModel != null) {
modelActionListener.onResponse(cachedMlModel);
} else if (modelAccessControlHelper.skipModelAccessControl(user)) {
executePredict(mlPredictionTaskRequest, wrappedListener, modelId);
} else {
mlModelManager.getModel(modelId, modelActionListener);
}
}
}

private void executePredict(
MLPredictionTaskRequest mlPredictionTaskRequest,
ActionListener<MLTaskResponse> wrappedListener,
String modelId
) {
String requestId = mlPredictionTaskRequest.getRequestID();
log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
long startTime = System.nanoTime();
mlPredictTaskRunner
.run(
mlPredictionTaskRequest.getMlInput().getAlgorithm(),
mlPredictionTaskRequest,
transportService,
ActionListener.runAfter(wrappedListener, () -> {
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
log.debug("completed predict request " + requestId + " for model " + modelId);
})
);
}
}
20 changes: 19 additions & 1 deletion plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java
@@ -14,6 +14,7 @@
import java.util.stream.DoubleStream;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.Predictable;
@@ -34,6 +35,7 @@ public class MLModelCache {
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLExecutable executor;
private final Set<String> targetWorkerNodes;
private final Set<String> workerNodes;
private MLModel modelInfo;
private final Queue<Double> modelInferenceDurationQueue;
private final Queue<Double> predictRequestDurationQueue;
private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationCPU;
@@ -77,12 +79,16 @@ public String[] getTargetWorkerNodes() {
* @param isFromUndeploy
*/
public void removeWorkerNode(String nodeId, boolean isFromUndeploy) {
if ((deployToAllNodes != null && deployToAllNodes) || isFromUndeploy) {
if (this.isDeployToAllNodes() || isFromUndeploy) {
targetWorkerNodes.remove(nodeId);
}
if (isFromUndeploy)
deployToAllNodes = false;
workerNodes.remove(nodeId);
// when the model is not deployed to any node, we should remove the modelInfo from cache
if (targetWorkerNodes.isEmpty() || workerNodes.isEmpty()) {
modelInfo = null;
}
}

public void removeWorkerNodes(Set<String> removedNodes, boolean isFromUndeploy) {
@@ -92,6 +98,9 @@ public void removeWorkerNodes(Set<String> removedNodes, boolean isFromUndeploy)
if (isFromUndeploy)
deployToAllNodes = false;
workerNodes.removeAll(removedNodes);
if (targetWorkerNodes.isEmpty() || workerNodes.isEmpty()) {
modelInfo = null;
}
}

/**
@@ -112,6 +121,14 @@ public String[] getWorkerNodes() {
return workerNodes.toArray(new String[0]);
}

public void setModelInfo(MLModel modelInfo) {
this.modelInfo = modelInfo;
}

public MLModel getCachedModelInfo() {
return modelInfo;
}

public void syncWorkerNode(Set<String> workerNodes) {
this.workerNodes.clear();
this.workerNodes.addAll(workerNodes);
@@ -129,6 +146,7 @@ public void clear() {
modelState = null;
functionName = null;
workerNodes.clear();
modelInfo = null;
modelInferenceDurationQueue.clear();
predictRequestDurationQueue.clear();
if (predictor != null) {
@@ -18,6 +18,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
@@ -429,6 +430,16 @@ public boolean getDeployToAllNodes(String modelId) {
return mlModelCache.isDeployToAllNodes();
}

public void setModelInfo(String modelId, MLModel mlModel) {
MLModelCache mlModelCache = getExistingModelCache(modelId);
mlModelCache.setModelInfo(mlModel);
}

public MLModel getModelInfo(String modelId) {
MLModelCache mlModelCache = getExistingModelCache(modelId);
return mlModelCache.getCachedModelInfo();
}

private MLModelCache getExistingModelCache(String modelId) {
MLModelCache modelCache = modelCaches.get(modelId);
if (modelCache == null) {
@@ -5,6 +5,7 @@

package org.opensearch.ml.model;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
@@ -26,6 +27,7 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.exception.MLLimitExceededException;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.MLModelState;
@@ -251,6 +253,7 @@ public void testSyncWorkerNodes_ModelState() {
cacheHelper.syncWorkerNodes(modelWorkerNodes);
assertEquals(2, cacheHelper.getAllModels().length);
assertEquals(0, cacheHelper.getWorkerNodes(modelId2).length);
assertNull(cacheHelper.getModelInfo(modelId2));
assertArrayEquals(new String[] { newNodeId }, cacheHelper.getWorkerNodes(modelId));
}

@@ -323,6 +326,15 @@ public void test_removeWorkerNodes_with_deployToAllNodesStatus_isTrue() {
cacheHelper.removeWorkerNodes(ImmutableSet.of(nodeId), false);
cacheHelper.removeWorkerNode(modelId, nodeId, false);
assertEquals(0, cacheHelper.getWorkerNodes(modelId).length);
assertNull(cacheHelper.getModelInfo(modelId));
}

public void test_setModelInfo_success() {
cacheHelper.initModelState(modelId, MLModelState.DEPLOYED, FunctionName.TEXT_EMBEDDING, targetWorkerNodes, true);
MLModel model = mock(MLModel.class);
when(model.getModelId()).thenReturn("mockId");
cacheHelper.setModelInfo(modelId, model);
assertEquals("mockId", cacheHelper.getModelInfo(modelId).getModelId());
}

}
