diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 26a5a66de3..53aa0fa267 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -16,6 +16,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; @@ -87,42 +88,66 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); - mlModelManager.getModel(modelId, ActionListener.wrap(mlModel -> { - FunctionName functionName = mlModel.getAlgorithm(); - mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName); - modelAccessControlHelper - .validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { - if (!access) { - wrappedListener - .onFailure( - new MLValidationException("User Doesn't have privilege to perform this operation on this model") - ); - } else { - String requestId = mlPredictionTaskRequest.getRequestID(); - log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId()); - long startTime = System.nanoTime(); - mlPredictTaskRunner - .run( - functionName, - mlPredictionTaskRequest, - transportService, - ActionListener.runAfter(wrappedListener, () -> { - long endTime = System.nanoTime(); - double durationInMs = (endTime - startTime) / 1e6; - modelCacheHelper.addPredictRequestDuration(modelId, durationInMs); - log.debug("completed predict request " + requestId + " for model " + modelId); - }) - ); - } - }, e -> { - log.error("Failed to Validate Access for ModelId " + modelId, e); - wrappedListener.onFailure(e); - })); - }, e -> { - log.error("Failed to find model " + modelId, e); - wrappedListener.onFailure(e); - })); + MLModel cachedMlModel = modelCacheHelper.getModelInfo(modelId); + ActionListener modelActionListener = new ActionListener<>() { + @Override + public void onResponse(MLModel mlModel) { + context.restore(); + modelCacheHelper.setModelInfo(modelId, mlModel); + FunctionName functionName = mlModel.getAlgorithm(); + mlPredictionTaskRequest.getMlInput().setAlgorithm(functionName); + modelAccessControlHelper + .validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { + if (!access) { + wrappedListener + .onFailure( + new MLValidationException("User Doesn't have privilege to perform this operation on this model") + ); + } else { + executePredict(mlPredictionTaskRequest, wrappedListener, modelId); + } + }, e -> { + log.error("Failed to Validate Access for ModelId " + modelId, e); + wrappedListener.onFailure(e); + })); + } + @Override + public void onFailure(Exception e) { + log.error("Failed to find model " + modelId, e); + wrappedListener.onFailure(e); + } + }; + + if (cachedMlModel != null) { + modelActionListener.onResponse(cachedMlModel); + } else if (modelAccessControlHelper.skipModelAccessControl(user)) { + executePredict(mlPredictionTaskRequest, wrappedListener, modelId); + } else { + mlModelManager.getModel(modelId, modelActionListener); + } } } + + private void executePredict( + MLPredictionTaskRequest mlPredictionTaskRequest, + ActionListener wrappedListener, + String modelId + ) { + String requestId = mlPredictionTaskRequest.getRequestID(); + log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId()); + long startTime = System.nanoTime(); + mlPredictTaskRunner + .run( + mlPredictionTaskRequest.getMlInput().getAlgorithm(), + mlPredictionTaskRequest, + transportService, + ActionListener.runAfter(wrappedListener, () -> { + long endTime = System.nanoTime(); + double durationInMs = (endTime - startTime) / 1e6; + modelCacheHelper.addPredictRequestDuration(modelId, durationInMs); + log.debug("completed predict request " + requestId + " for model " + modelId); + }) + ); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java index 19ca890df5..5fd7d71ce0 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCache.java @@ -14,6 +14,7 @@ import java.util.stream.DoubleStream; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.Predictable; @@ -34,6 +35,7 @@ public class MLModelCache { private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) MLExecutable executor; private final Set targetWorkerNodes; private final Set workerNodes; + private MLModel modelInfo; private final Queue modelInferenceDurationQueue; private final Queue predictRequestDurationQueue; private @Setter(AccessLevel.PROTECTED) @Getter(AccessLevel.PROTECTED) Long memSizeEstimationCPU; @@ -77,12 +79,16 @@ public String[] getTargetWorkerNodes() { * @param isFromUndeploy */ public void removeWorkerNode(String nodeId, boolean isFromUndeploy) { - if ((deployToAllNodes != null && deployToAllNodes) || isFromUndeploy) { + if (this.isDeployToAllNodes() || isFromUndeploy) { targetWorkerNodes.remove(nodeId); } if (isFromUndeploy) deployToAllNodes = false; workerNodes.remove(nodeId); + // when the model is not deployed to any node, we should remove the modelInfo from cache + if (targetWorkerNodes.isEmpty() || workerNodes.isEmpty()) { + modelInfo = null; + } } public void removeWorkerNodes(Set removedNodes, boolean isFromUndeploy) { @@ -92,6 +98,9 @@ public void removeWorkerNodes(Set removedNodes, boolean isFromUndeploy) if (isFromUndeploy) deployToAllNodes = false; workerNodes.removeAll(removedNodes); + if (targetWorkerNodes.isEmpty() || workerNodes.isEmpty()) { + modelInfo = null; + } } /** @@ -112,6 +121,14 @@ public String[] getWorkerNodes() { return workerNodes.toArray(new String[0]); } + public void setModelInfo(MLModel modelInfo) { + this.modelInfo = modelInfo; + } + + public MLModel getCachedModelInfo() { + return modelInfo; + } + public void syncWorkerNode(Set workerNodes) { this.workerNodes.clear(); this.workerNodes.addAll(workerNodes); @@ -129,6 +146,7 @@ public void clear() { modelState = null; functionName = null; workerNodes.clear(); + modelInfo = null; modelInferenceDurationQueue.clear(); predictRequestDurationQueue.clear(); if (predictor != null) { diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 74dbc26d61..680524324d 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -18,6 +18,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; @@ -429,6 +430,16 @@ public boolean getDeployToAllNodes(String modelId) { return mlModelCache.isDeployToAllNodes(); } + public void setModelInfo(String modelId, MLModel mlModel) { + MLModelCache mlModelCache = getExistingModelCache(modelId); + mlModelCache.setModelInfo(mlModel); + } + + public MLModel getModelInfo(String modelId) { + MLModelCache mlModelCache = getExistingModelCache(modelId); + return mlModelCache.getCachedModelInfo(); + } + private MLModelCache getExistingModelCache(String modelId) { MLModelCache modelCache = modelCaches.get(modelId); if (modelCache == null) { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java index 78d67b3ce1..603c315a12 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java @@ -5,6 +5,7 @@ package org.opensearch.ml.model; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -26,6 +27,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; @@ -251,6 +253,7 @@ public void testSyncWorkerNodes_ModelState() { cacheHelper.syncWorkerNodes(modelWorkerNodes); assertEquals(2, cacheHelper.getAllModels().length); assertEquals(0, cacheHelper.getWorkerNodes(modelId2).length); + assertNull(cacheHelper.getModelInfo(modelId2)); assertArrayEquals(new String[] { newNodeId }, cacheHelper.getWorkerNodes(modelId)); } @@ -323,6 +326,15 @@ public void test_removeWorkerNodes_with_deployToAllNodesStatus_isTrue() { cacheHelper.removeWorkerNodes(ImmutableSet.of(nodeId), false); cacheHelper.removeWorkerNode(modelId, nodeId, false); assertEquals(0, cacheHelper.getWorkerNodes(modelId).length); + assertNull(cacheHelper.getModelInfo(modelId)); + } + + public void test_setModelInfo_success() { + cacheHelper.initModelState(modelId, MLModelState.DEPLOYED, FunctionName.TEXT_EMBEDDING, targetWorkerNodes, true); + MLModel model = mock(MLModel.class); + when(model.getModelId()).thenReturn("mockId"); + cacheHelper.setModelInfo(modelId, model); + assertEquals("mockId", cacheHelper.getModelInfo(modelId).getModelId()); } }