format code
Signed-off-by: zane-neo <zaniu@amazon.com>
zane-neo committed Feb 5, 2024
1 parent 4fd86fb · commit cb9093b
Showing 3 changed files with 95 additions and 7 deletions.
Changed file 1 of 3
@@ -221,7 +221,10 @@ public static ModelTensors processOutput(
}

public static ModelTensors processErrorResponse(String errorResponse) {
return ModelTensors.builder().mlModelTensors(List.of(ModelTensor.builder().dataAsMap(Map.of("remote_response", errorResponse)).build())).build();
return ModelTensors
.builder()
.mlModelTensors(List.of(ModelTensor.builder().dataAsMap(Map.of("remote_response", errorResponse)).build()))
.build();
}

private static String fillProcessFunctionParameter(Map<String, String> parameters, String processFunction) {
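For context, here is a minimal usage sketch (not part of the commit) of the processErrorResponse helper reformatted above. The accessor names (getMlModelTensors, getDataAsMap) are taken from getters that appear later in this diff; the import paths are assumptions based on the surrounding imports.

// Hypothetical usage sketch, not part of the commit: wrap a raw error body
// with ConnectorUtils.processErrorResponse and read it back from the tensors.
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;

public class ProcessErrorResponseSketch {
    public static void main(String[] args) {
        String errorBody = "{\"error\":\"model throttled\"}";
        ModelTensors tensors = ConnectorUtils.processErrorResponse(errorBody);

        // The helper stores the raw body under the "remote_response" key of dataAsMap.
        Object remoteResponse = tensors.getMlModelTensors().get(0).getDataAsMap().get("remote_response");
        System.out.println(errorBody.equals(remoteResponse)); // expected: true
    }
}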
Changed file 2 of 3
@@ -10,7 +10,6 @@
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processErrorResponse;
import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
@@ -20,9 +19,7 @@
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CompletableFuture;

import com.google.gson.Gson;
import org.apache.http.HttpStatus;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.OpenSearchStatusException;
@@ -36,6 +33,8 @@
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import com.google.gson.Gson;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.http.SdkHttpFullResponse;
@@ -113,7 +112,12 @@ public void onError(Throwable error) {
}
}

private void processResponse(Integer statusCode, String body, Map<String, String> parameters, Map<Integer, ModelTensors> tensorOutputs) {
private void processResponse(
Integer statusCode,
String body,
Map<String, String> parameters,
Map<Integer, ModelTensors> tensorOutputs
) {
ModelTensors tensors;
if (Strings.isBlank(body)) {
log.error("Remote model response body is empty!");
@@ -140,14 +144,24 @@ private void reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
TreeMap<Integer, ModelTensors> sortedMap = new TreeMap<>(tensorOutputs);
log.debug("Reordered tensor outputs size is {}", sortedMap.size());
if (tensorOutputs.size() == 1) {
//batch API case
// batch API case
int status = tensorOutputs.get(0).getStatusCode();
if (status == HttpStatus.SC_OK) {
modelTensors.add(tensorOutputs.get(0));
actionListener.onResponse(modelTensors);
} else {
try {
actionListener.onFailure(new OpenSearchStatusException(AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(tensorOutputs.get(0).getMlModelTensors().get(0).getDataAsMap())), RestStatus.fromCode(status)));
actionListener
.onFailure(
new OpenSearchStatusException(
AccessController
.doPrivileged(
(PrivilegedExceptionAction<String>) () -> gson
.toJson(tensorOutputs.get(0).getMlModelTensors().get(0).getDataAsMap())
),
RestStatus.fromCode(status)
)
);
} catch (PrivilegedActionException e) {
actionListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.fromCode(statusCode)));
}
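To make the single-entry ("batch API case") branch above concrete: on a non-200 status the listener fails with an OpenSearchStatusException whose message is the Gson-serialized dataAsMap of the error tensor. A rough illustration (not part of the commit) of what that message looks like, assuming dataAsMap was produced by processErrorResponse:

// Hypothetical illustration, not part of the commit: the failure message that
// reOrderTensorResponses builds for a non-200 single-entry ("batch API") response.
import java.util.Map;

import com.google.gson.Gson;

public class ErrorMessageSketch {
    public static void main(String[] args) {
        // dataAsMap as produced by ConnectorUtils.processErrorResponse(errorBody)
        Map<String, String> dataAsMap = Map.of("remote_response", "{\"message\":\"throttled\"}");

        // The branch above calls gson.toJson(dataAsMap) and passes the result,
        // together with RestStatus.fromCode(status), to OpenSearchStatusException.
        String failureMessage = new Gson().toJson(dataAsMap);
        System.out.println(failureMessage); // {"remote_response":"{\"message\":\"throttled\"}"}
    }
}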
Changed file 3 of 3
@@ -12,6 +12,7 @@

import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
@@ -47,12 +48,16 @@
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams;
import org.opensearch.ml.common.output.MLPredictionOutput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.indices.MLInputDatasetHandler;
@@ -295,6 +300,72 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() {
assertEquals("ModelId is invalid", argumentCaptor.getValue().getMessage());
}

public void testExecuteTask_OnLocalNode_remoteModel_success() {
setupMocks(true, false, false, false);
TextDocsInputDataSet textDocsInputDataSet = new TextDocsInputDataSet(List.of("hello", "world"), null);
MLPredictionTaskRequest textDocsInputRequest = MLPredictionTaskRequest
.builder()
.modelId("test_model")
.mlInput(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(textDocsInputDataSet).build())
.build();
Predictable predictor = mock(Predictable.class);
when(predictor.isModelReady()).thenReturn(true);
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(1);
actionListener
.onResponse(MLTaskResponse.builder().output(ModelTensorOutput.builder().mlModelOutputs(List.of()).build()).build());
return null;
}).when(predictor).asyncPredict(any(), any());
when(mlModelManager.getPredictor(anyString())).thenReturn(predictor);
when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.REMOTE), eq(true))).thenReturn(new String[] { "node1" });
taskRunner.dispatchTask(FunctionName.REMOTE, textDocsInputRequest, transportService, listener);
verify(client, never()).get(any(), any());
ArgumentCaptor<MLTaskResponse> argumentCaptor = ArgumentCaptor.forClass(MLTaskResponse.class);
verify(listener).onResponse(argumentCaptor.capture());
assert argumentCaptor.getValue().getOutput() instanceof ModelTensorOutput;
}

public void testExecuteTask_OnLocalNode_localModel_success() {
setupMocks(true, false, false, false);
TextDocsInputDataSet textDocsInputDataSet = new TextDocsInputDataSet(List.of("hello", "world"), null);
MLPredictionTaskRequest textDocsInputRequest = MLPredictionTaskRequest
.builder()
.modelId("test_model")
.mlInput(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build())
.build();
Predictable predictor = mock(Predictable.class);
when(predictor.isModelReady()).thenReturn(true);
when(mlModelManager.getPredictor(anyString())).thenReturn(predictor);
when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.TEXT_EMBEDDING), eq(true))).thenReturn(new String[] { "node1" });
when(mlModelManager.trackPredictDuration(anyString(), any())).thenReturn(mock(MLPredictionOutput.class));
taskRunner.dispatchTask(FunctionName.TEXT_EMBEDDING, textDocsInputRequest, transportService, listener);
verify(client, never()).get(any(), any());
ArgumentCaptor<MLTaskResponse> argumentCaptor = ArgumentCaptor.forClass(MLTaskResponse.class);
verify(listener).onResponse(argumentCaptor.capture());
assert argumentCaptor.getValue().getOutput() instanceof MLPredictionOutput;
}

public void testExecuteTask_OnLocalNode_prediction_exception() {
setupMocks(true, false, false, false);
TextDocsInputDataSet textDocsInputDataSet = new TextDocsInputDataSet(List.of("hello", "world"), null);
MLPredictionTaskRequest textDocsInputRequest = MLPredictionTaskRequest
.builder()
.modelId("test_model")
.mlInput(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(textDocsInputDataSet).build())
.build();
Predictable predictable = mock(Predictable.class);
when(mlModelManager.getPredictor(anyString())).thenReturn(predictable);
when(predictable.isModelReady()).thenThrow(new RuntimeException("runtime exception"));
when(mlModelManager.getWorkerNodes(anyString(), eq(FunctionName.TEXT_EMBEDDING), eq(true))).thenReturn(new String[] { "node1" });
when(mlModelManager.trackPredictDuration(anyString(), any())).thenReturn(mock(MLPredictionOutput.class));
taskRunner.dispatchTask(FunctionName.TEXT_EMBEDDING, textDocsInputRequest, transportService, listener);
verify(client, never()).get(any(), any());
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener).onFailure(argumentCaptor.capture());
assert argumentCaptor.getValue() instanceof RuntimeException;
assertEquals("runtime exception", argumentCaptor.getValue().getMessage());
}

public void testExecuteTask_OnLocalNode_NullGetResponse() {
setupMocks(true, false, false, true);
