diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
index 53aa0fa267..1d58a5d0db 100644
--- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
@@ -137,9 +137,16 @@ private void executePredict(
         String requestId = mlPredictionTaskRequest.getRequestID();
         log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
         long startTime = System.nanoTime();
+        // For a remote text embedding model, neural search sets mlPredictionTaskRequest.getMlInput().getAlgorithm() to
+        // TEXT_EMBEDDING. In ml-commons we should always use the model's real function name: REMOTE. So we try to get
+        // it from the model cache first.
+        FunctionName functionName = modelCacheHelper
+            .getOptionalFunctionName(modelId)
+            .orElse(mlPredictionTaskRequest.getMlInput().getAlgorithm());
         mlPredictTaskRunner
             .run(
-                mlPredictionTaskRequest.getMlInput().getAlgorithm(),
+                // This is intentional: do NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here
+                functionName,
                 mlPredictionTaskRequest,
                 transportService,
                 ActionListener.runAfter(wrappedListener, () -> {
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
index 554065ed95..553ffeb664 100644
--- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java
@@ -431,8 +431,10 @@ public boolean getDeployToAllNodes(String modelId) {
     }
 
     public void setModelInfo(String modelId, MLModel mlModel) {
-        MLModelCache mlModelCache = getExistingModelCache(modelId);
-        mlModelCache.setModelInfo(mlModel);
+        MLModelCache mlModelCache = modelCaches.get(modelId);
+        if (mlModelCache != null) {
+            mlModelCache.setModelInfo(mlModel);
+        }
     }
 
     public MLModel getModelInfo(String modelId) {
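
The first hunk relies on modelCacheHelper.getOptionalFunctionName(modelId), which is not shown in this diff. The snippet below is a minimal, self-contained sketch of that cache-first fallback; the FunctionName enum, the modelCaches map, and the helper body here are simplified stand-ins for the real ml-commons types, not the actual implementation.

import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Illustrative sketch only (not part of the PR): shows how resolving the function name
// from a per-model cache, with the request's algorithm as fallback, behaves.
public class FunctionNameFallbackSketch {

    // Simplified stand-in for the real ml-commons FunctionName enum.
    enum FunctionName { TEXT_EMBEDDING, REMOTE }

    // Stand-in for MLModelCacheHelper's per-model cache keyed by model id.
    private static final Map<String, FunctionName> modelCaches = new ConcurrentHashMap<>();

    // Analogue of getOptionalFunctionName: empty when the model is not cached on this node.
    static Optional<FunctionName> getOptionalFunctionName(String modelId) {
        return Optional.ofNullable(modelCaches.get(modelId));
    }

    public static void main(String[] args) {
        modelCaches.put("remote-model-1", FunctionName.REMOTE);

        // Neural search sends TEXT_EMBEDDING for a remote embedding model; the cache corrects it to REMOTE.
        FunctionName requestAlgorithm = FunctionName.TEXT_EMBEDDING;
        FunctionName resolved = getOptionalFunctionName("remote-model-1").orElse(requestAlgorithm);
        System.out.println(resolved); // REMOTE

        // For a model not present in the cache, the request's algorithm is used as-is.
        FunctionName fallback = getOptionalFunctionName("unknown-model").orElse(requestAlgorithm);
        System.out.println(fallback); // TEXT_EMBEDDING
    }
}

This also mirrors why the second hunk makes setModelInfo tolerate a missing cache entry: when the model id is not in the cache, callers fall back gracefully instead of failing on a lookup for an undeployed model.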