From 9c4df89a6ff551779309ca8f9f22863e3fb8ed25 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jun 2024 17:24:27 -0700 Subject: [PATCH] fix test Signed-off-by: Yaliang Wu --- .../preprocess/MultiModalConnectorPreProcessFunction.java | 1 + .../java/org/opensearch/ml/rest/RestBedRockInferenceIT.java | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 54d56f12f1..231c68c48d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -11,6 +11,7 @@ import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 4656acc6b7..b7a57ff223 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -73,7 +73,7 @@ public void test_bedrock_multimodal_model() throws Exception { TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", imageBase64)).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); - Map inferenceResult = predictTextEmbedding(modelId, mlInput); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); assertEquals(errorMsg, 1, output.size()); @@ -86,7 +86,7 @@ public void test_bedrock_multimodal_model() throws Exception { assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); } } - + public void test_bedrock_embedding_model() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {