Skip to content

Commit

Permalink
throw exception if remote model doesn't return 2xx status code; fix p… (
Browse files Browse the repository at this point in the history
#1473)

* throw exception if remote model doesn't return 2xx status code; fix predict runner

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fix kmeans model deploy bug

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* support multiple docs for remote embedding model

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fix ut

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Oct 10, 2023
1 parent 5fc555d commit 513ca39
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
}
String modelResponse = responseBuilder.toString();
if (statusCode < 200 || statusCode >= 300) {
throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode));
}

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private static RemoteInferenceInputDataSet processTextDocsInput(TextDocsInputDat
docs.add(null);
}
}
if (preProcessFunction.contains("${parameters")) {
if (preProcessFunction.contains("${parameters.")) {
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
preProcessFunction = substitutor.replace(preProcessFunction);
}
Expand Down Expand Up @@ -164,7 +164,7 @@ public static ModelTensors processOutput(String modelResponse, Connector connect
// execute user defined painless script.
Optional<String> processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
String response = processedResponse.orElse(modelResponse);
boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent();
boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && org.opensearch.ml.common.utils.StringUtils.isJson(response);
if (responseFilter == null) {
connector.parseResponse(response, modelTensors, scriptReturnModelTensor);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.util.EntityUtils;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.rest.RestStatus;
import org.opensearch.script.ScriptService;

import java.security.AccessController;
Expand Down Expand Up @@ -103,9 +105,13 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
return null;
});
String modelResponse = responseRef.get();
Integer statusCode = statusCodeRef.get();
if (statusCode < 200 || statusCode >= 300) {
throw new OpenSearchStatusException(modelResponse, RestStatus.fromCode(statusCode));
}

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCodeRef.get());
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
} catch (RuntimeException e) {
log.error("Fail to execute http connector", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,15 @@ default ModelTensorOutput executePredict(MLInput mlInput) {

if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
List<String> textDocs = new ArrayList<>(textDocsInputDataSet.getDocs());
preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tensorOutputs);
int processedDocs = 0;
while(processedDocs < textDocsInputDataSet.getDocs().size()) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
List<ModelTensors> tempTensorOutputs = new ArrayList<>();
preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs);
processedDocs += Math.max(tempTensorOutputs.size(), 1);
tensorOutputs.addAll(tempTensorOutputs);
}

} else {
preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,36 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}

@Test
public void executePredict_RemoteInferenceInput_InvalidToken() throws IOException {
exceptionRule.expect(OpenSearchStatusException.class);
exceptionRule.expectMessage("{\"message\":\"The security token included in the request is invalid\"}");
String jsonString = "{\"message\":\"The security token included in the request is invalid\"}";
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(403);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}

@Test
public void executePredict_RemoteInferenceInput() throws IOException {
String jsonString = "{\"key\":\"value\"}";
Expand Down Expand Up @@ -176,7 +206,7 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input", "test input data")).build();
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.cluster.ClusterStateTaskConfig;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
Expand Down Expand Up @@ -120,12 +121,34 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("test result", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("response"));
}

@Test
public void executePredict_TextDocsInput_LimitExceed() throws IOException {
exceptionRule.expect(OpenSearchStatusException.class);
exceptionRule.expectMessage("{\"message\": \"Too many requests\"}");
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": ${parameters.input}}")
.build();
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
}

@Test
public void executePredict_TextDocsInput() throws IOException {
String preprocessResult1 = "{\"parameters\": { \"input\": \"test doc1\" } }";
Expand Down Expand Up @@ -161,7 +184,7 @@ public void executePredict_TextDocsInput() throws IOException {
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getData());
Assert.assertArrayEquals(new Number[] {-0.014555434, -0.002135904, 0.0035105038}, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(1).getData());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,8 @@ public void deployModel(
CLUSTER_SERVICE,
clusterService
);
// deploy remote model or model trained by built-in algorithm like kmeans
if (mlModel.getConnector() != null) {
// deploy remote model with internal connector or model trained by built-in algorithm like kmeans
if (mlModel.getConnector() != null || FunctionName.REMOTE != mlModel.getAlgorithm()) {
setupPredictable(modelId, mlModel, params);
wrappedListener.onResponse("successful");
return;
Expand All @@ -756,6 +756,7 @@ public void deployModel(
GetRequest getConnectorRequest = new GetRequest();
FetchSourceContext fetchContext = new FetchSourceContext(true, null, null);
getConnectorRequest.index(ML_CONNECTOR_INDEX).id(mlModel.getConnectorId()).fetchSourceContext(fetchContext);
// get connector and deploy remote model with standalone connector
client.get(getConnectorRequest, ActionListener.wrap(getResponse -> {
if (getResponse != null && getResponse.isExists()) {
try (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
FunctionName algorithm = mlInput.getAlgorithm();
// run predict
if (modelId != null) {
try {
Predictable predictor = mlModelManager.getPredictor(modelId);
if (predictor != null) {
Predictable predictor = mlModelManager.getPredictor(modelId);
if (predictor != null) {
try {
if (!predictor.isModelReady()) {
throw new IllegalArgumentException("Model not ready: " + modelId);
}
Expand All @@ -226,11 +226,12 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
internalListener.onResponse(response);
return;
} else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
throw new IllegalArgumentException("Model not ready to be used: " + modelId);
} catch (Exception e) {
handlePredictFailure(mlTask, internalListener, e, false);
return;
}
} catch (Exception e) {
handlePredictFailure(mlTask, internalListener, e, false);
} else if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
throw new IllegalArgumentException("Model not ready to be used: " + modelId);
}

// search model by model id.
Expand All @@ -249,6 +250,7 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
GetResponse getResponse = r;
String algorithmName = getResponse.getSource().get(ALGORITHM_FIELD).toString();
MLModel mlModel = MLModel.parse(xContentParser, algorithmName);
mlModel.setModelId(modelId);
User resourceUser = mlModel.getUser();
User requestUser = getUserContext(client);
if (!checkUserPermissions(requestUser, resourceUser, modelId)) {
Expand All @@ -260,7 +262,9 @@ private void predict(String modelId, MLTask mlTask, MLInput mlInput, ActionListe
return;
}
// run predict
mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync());
if (mlTaskManager.contains(mlTask.getTaskId())) {
mlTaskManager.updateTaskStateAsRunning(mlTask.getTaskId(), mlTask.isAsync());
}
MLOutput output = mlEngine.predict(mlInput, mlModel);
if (output instanceof MLPredictionOutput) {
((MLPredictionOutput) output).setStatus(MLTaskState.COMPLETED.name());
Expand Down

0 comments on commit 513ca39

Please sign in to comment.