diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index e7f9c88c66..c924808f5c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -101,6 +101,9 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST); } String modelResponse = responseBuilder.toString(); + if (statusCode < 200 || statusCode >= 300) { + throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode)); + } ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); tensors.setStatusCode(statusCode); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index ab4ed7fd6c..cbfad6fd5f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -102,7 +102,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat docs.add(null); } } - if (preProcessFunction.contains("${parameters")) { + if (preProcessFunction.contains("${parameters.")) { StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); preProcessFunction = substitutor.replace(preProcessFunction); } @@ -164,7 +164,7 @@ public static ModelTensors processOutput(String modelResponse, Connector connect // execute user defined painless script. Optional processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); String response = processedResponse.orElse(modelResponse); - boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent(); + boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && org.opensearch.ml.common.utils.StringUtils.isJson(response); if (responseFilter == null) { connector.parseResponse(response, modelTensors, scriptReturnModelTensor); } else { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 016baf7229..1a993b88d2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -16,6 +16,7 @@ import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.util.EntityUtils; +import org.opensearch.OpenSearchStatusException; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLException; @@ -23,6 +24,7 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.engine.annotation.ConnectorExecutor; import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; +import org.opensearch.rest.RestStatus; import org.opensearch.script.ScriptService; import java.security.AccessController; @@ -103,9 +105,13 @@ public void invokeRemoteModel(MLInput mlInput, Map parameters, S return null; }); String modelResponse = responseRef.get(); + Integer statusCode = statusCodeRef.get(); + if (statusCode < 200 || statusCode >= 300) { + throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode)); + } ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters); - tensors.setStatusCode(statusCodeRef.get()); + tensors.setStatusCode(statusCode); tensorOutputs.add(tensors); } catch (RuntimeException e) { log.error("Fail to execute http connector", e); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 8712f771c7..176fdb428e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -32,8 +32,15 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); - List textDocs = new ArrayList<>(textDocsInputDataSet.getDocs()); - preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs); + int processedDocs = 0; + while(processedDocs < textDocsInputDataSet.getDocs().size()) { + List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); + List tempTensorOutputs = new ArrayList<>(); + preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs); + processedDocs += Math.max(tempTensorOutputs.size(), 1); + tensorOutputs.addAll(tempTensorOutputs); + } + } else { preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 8d39db9a79..8b0d5a8173 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -118,6 +118,36 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); } + @Test + public void executePredict_RemoteInferenceInput_InvalidToken() throws IOException { + exceptionRule.expect(OpenSearchStatusException.class); + exceptionRule.expectMessage("{\"message\":\"The security token included in the request is invalid\"}"); + String jsonString = "{\"message\":\"The security token included in the request is invalid\"}"; + InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes()); + AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream); + when(response.responseBody()).thenReturn(Optional.of(abortableInputStream)); + when(httpRequest.call()).thenReturn(response); + SdkHttpResponse httpResponse = mock(SdkHttpResponse.class); + when(httpResponse.statusCode()).thenReturn(403); + when(response.httpResponse()).thenReturn(httpResponse); + when(httpClient.prepareRequest(any())).thenReturn(httpRequest); + + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); + Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + connector.decrypt((c) -> encryptor.decrypt(c)); + AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); + + MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build(); + executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + } + @Test public void executePredict_RemoteInferenceInput() throws IOException { String jsonString = "{\"key\":\"value\"}"; @@ -176,7 +206,7 @@ public void executePredict_TextDocsInferenceInput() throws IOException { connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); - MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build(); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 5cbdce053e..c2e3b423f9 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -20,6 +20,7 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchStatusException; import org.opensearch.cluster.ClusterStateTaskConfig; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; @@ -120,12 +121,34 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response")); } + @Test + public void executePredict_TextDocsInput_LimitExceed() throws IOException { + exceptionRule.expect(OpenSearchStatusException.class); + exceptionRule.expectMessage("{\"message\": \"Too many requests\"}"); + ConnectorAction predictAction = ConnectorAction.builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); + when(httpClient.execute(any())).thenReturn(response); + HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}"); + when(response.getEntity()).thenReturn(entity); + StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK"); + when(response.getStatusLine()).thenReturn(statusLine); + Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); + when(executor.getHttpClient()).thenReturn(httpClient); + MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); + executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + } + @Test public void executePredict_TextDocsInput() throws IOException { String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }"; @@ -161,7 +184,7 @@ public void executePredict_TextDocsInput() throws IOException { when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); - Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); + Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData()); Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData()); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 99cfd18bc8..a101a1241d 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -746,8 +746,8 @@ public void deployModel( CLUSTER_SERVICE, clusterService ); - // deploy remote model or model trained by built-in algorithm like kmeans - if (mlModel.getConnector() != null) { + // deploy remote model with internal connector or model trained by built-in algorithm like kmeans + if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) { setupPredictable(modelId, mlModel, params); wrappedListener.onResponse("successful"); return; @@ -756,6 +756,7 @@ public void deployModel( GetRequest getConnectorRequest = new GetRequest(); FetchSourceContext fetchContext = new FetchSourceContext(true, null, null); getConnectorRequest.index(ML_CONNECTOR_INDEX).id(mlModel.getConnectorId()).fetchSourceContext(fetchContext); + // get connector and deploy remote model with standalone connector client.get(getConnectorRequest, ActionListener.wrap(getResponse -> { if (getResponse != null && getResponse.isExists()) { try ( diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 3a623a1b65..0b0d7f2dd3 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -210,9 +210,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe FunctionName algorithm = mlInput.getAlgorithm(); // run predict if (modelId != null) { - try { - Predictable predictor = mlModelManager.getPredictor(modelId); - if (predictor != null) { + Predictable predictor = mlModelManager.getPredictor(modelId); + if (predictor != null) { + try { if (!predictor.isModelReady()) { throw new IllegalArgumentException("Model not ready: " + modelId); } @@ -226,11 +226,12 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe MLTaskResponse response = MLTaskResponse.builder().output(output).build(); internalListener.onResponse(response); return; - } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { - throw new IllegalArgumentException("Model not ready to be used: " + modelId); + } catch (Exception e) { + handlePredictFailure(mlTask, internalListener, e, false); + return; } - } catch (Exception e) { - handlePredictFailure(mlTask, internalListener, e, false); + } else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) { + throw new IllegalArgumentException("Model not ready to be used: " + modelId); } // search model by model id. @@ -249,6 +250,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe GetResponse getResponse = r; String algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString(); MLModel mlModel = MLModel.parse(xContentParser, algorithmName); + mlModel.setModelId(modelId); User resourceUser = mlModel.getUser(); User requestUser = getUserContext(client); if (!checkUserPermissions(requestUser, resourceUser, modelId)) { @@ -260,7 +262,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe return; } // run predict - mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync()); + if (mlTaskManager.contains(mlTask.getTaskId())) { + mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync()); + } MLOutput output = mlEngine.predict(mlInput, mlModel); if (output instanceof MLPredictionOutput) { ((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());