From ab2d736b3f4871a8c596bc8746bb9edbfbbc7f26 Mon Sep 17 00:00:00 2001
From: zane-neo <zaniu@amazon.com>
Date: Mon, 14 Oct 2024 15:30:06 +0800
Subject: [PATCH] Add UT and ITs

Signed-off-by: zane-neo <zaniu@amazon.com>
---
 .../BedrockEmbeddingPreProcessFunction.java   |   4 +
 ...edrockEmbeddingPreProcessFunctionTest.java |   8 ++
 .../ml/rest/RestBedRockInferenceIT.java       | 102 ++++++++++++++++++
 .../BedRockEmbeddingV2ModelBodies.json        |  66 ++++++++++++
 4 files changed, 180 insertions(+)
 create mode 100644 plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json

diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java
index cbc140fcc1..34b72bee97 100644
--- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java
+++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunction.java
@@ -15,6 +15,9 @@
 import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
 
+import lombok.extern.slf4j.Slf4j;
+
+@Slf4j
 public class BedrockEmbeddingPreProcessFunction extends ConnectorPreProcessFunction {
 
     public BedrockEmbeddingPreProcessFunction() {
@@ -40,6 +43,7 @@ public RemoteInferenceInputDataSet process(Map<String, String> connectorParams,
         // Amazon Titan Text Embeddings V2 model: https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html
         // Default dimension is 1024
         int dimensions = Optional.ofNullable(connectorParams.get("dimensions")).map(x -> NumberUtils.toInt(x, 1024)).orElse(1024);
+        log.debug("The bedrock dimensions parameter value is: {}", dimensions);
         Map<String, Object> processedResult = Map
             .of("parameters", Map.of("inputText", inputData.getDocs().get(0), "dimensions", dimensions));
         return RemoteInferenceInputDataSet.builder().parameters(convertScriptStringToJsonString(processedResult)).build();
diff --git a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java
index eb6e023c34..228baec782 100644
--- a/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java
+++ b/common/src/test/java/org/opensearch/ml/common/connector/functions/preprocess/BedrockEmbeddingPreProcessFunctionTest.java
@@ -84,4 +84,12 @@ public void process_TextDocsInput_withConnectorParams() {
         assertEquals(2, dataSet.getParameters().size());
         assertEquals("1024", dataSet.getParameters().get("dimensions"));
     }
+
+    @Test
+    public void process_TextDocsInput_withoutConnectorParams() {
+        MLInput input = MLInput.builder().inputDataset(textDocsInputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
+        RemoteInferenceInputDataSet processed = function.apply(Map.of(), input); // empty connector parameter map
+        assertEquals(2, processed.getParameters().size());
+        assertEquals("1024", processed.getParameters().get("dimensions")); // value expected when none is configured
+    }
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java
index 286d45d308..d8e21471d9 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java
@@ -16,6 +16,7 @@
 import org.junit.Before;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
+import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.utils.StringUtils;
 
@@ -242,6 +243,107 @@ public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() thro
         }
     }
 
+    /** Uses the template whose connector parameters set "dimensions" and expects 512 data entries; returns early if tokenNotSet(). */
+    public void test_bedrock_embedding_v2_model_with_connector_dimensions() throws Exception {
+        if (tokenNotSet()) {
+            return;
+        }
+        String templates = Files
+            .readString(
+                Path
+                    .of(
+                        RestMLPredictionAction.class
+                            .getClassLoader()
+                            .getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json")
+                            .toURI()
+                    )
+            );
+        Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
+        String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
+        String testCaseName = "with_connector_dimensions"; // doubles as the template key in the JSON resource
+        String modelId = registerRemoteModel(
+            String
+                .format(
+                    StringUtils.gson.toJson(templateMap.get(testCaseName)),
+                    GITHUB_CI_AWS_REGION,
+                    AWS_ACCESS_KEY_ID,
+                    AWS_SECRET_ACCESS_KEY,
+                    AWS_SESSION_TOKEN
+                ),
+            bedrockEmbeddingModelName,
+            true
+        );
+
+        List<String> input = new ArrayList<>();
+        input.add("Can you tell me a joke?");
+        TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build();
+        MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();
+        Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
+        String errorMsg = String
+            .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult));
+        assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
+        List output = (List) inferenceResult.get("inference_results");
+        assertEquals(errorMsg, 1, output.size());
+        assertTrue(errorMsg, output.get(0) instanceof Map);
+        assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List);
+        List outputList = (List) ((Map<?, ?>) output.get(0)).get("output");
+        assertEquals(errorMsg, 1, outputList.size());
+        assertTrue(errorMsg, outputList.get(0) instanceof Map);
+        assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
+        assertEquals(errorMsg, 512, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
+    }
+
+    /** Passes "dimensions" = "512" in the request parameters and expects 512 data entries; returns early if tokenNotSet(). */
+    public void test_bedrock_embedding_v2_model_with_request_dimensions() throws Exception {
+        if (tokenNotSet()) {
+            return;
+        }
+        String templates = Files
+            .readString(
+                Path
+                    .of(
+                        RestMLPredictionAction.class
+                            .getClassLoader()
+                            .getResource("org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json")
+                            .toURI()
+                    )
+            );
+        Map<String, Object> templateMap = StringUtils.gson.fromJson(templates, Map.class);
+        String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5);
+        String testCaseName = "with_request_dimensions"; // doubles as the template key in the JSON resource
+        String modelId = registerRemoteModel(
+            String
+                .format(
+                    StringUtils.gson.toJson(templateMap.get(testCaseName)),
+                    GITHUB_CI_AWS_REGION,
+                    AWS_ACCESS_KEY_ID,
+                    AWS_SECRET_ACCESS_KEY,
+                    AWS_SESSION_TOKEN
+                ),
+            bedrockEmbeddingModelName,
+            true
+        );
+
+        RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet
+            .builder()
+            .parameters(Map.of("inputText", "Can you tell me a joke?", "dimensions", "512"))
+            .build();
+        MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.REMOTE).build();
+        Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput);
+        String errorMsg = String
+            .format(Locale.ROOT, "Failing test case name: %s, inference result: %s", testCaseName, gson.toJson(inferenceResult));
+        assertTrue(errorMsg, inferenceResult.containsKey("inference_results"));
+        List output = (List) inferenceResult.get("inference_results");
+        assertEquals(errorMsg, 1, output.size());
+        assertTrue(errorMsg, output.get(0) instanceof Map);
+        assertTrue(errorMsg, ((Map<?, ?>) output.get(0)).get("output") instanceof List);
+        List outputList = (List) ((Map<?, ?>) output.get(0)).get("output");
+        assertEquals(errorMsg, 1, outputList.size());
+        assertTrue(errorMsg, outputList.get(0) instanceof Map);
+        assertTrue(errorMsg, ((Map<?, ?>) outputList.get(0)).get("data") instanceof List);
+        assertEquals(errorMsg, 512, ((List) ((Map<?, ?>) outputList.get(0)).get("data")).size());
+    }
+
     private boolean tokenNotSet() {
         if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) {
             log.info("#### The AWS credentials are not set. Skipping test. ####");
diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json
new file mode 100644
index 0000000000..a674843b94
--- /dev/null
+++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockEmbeddingV2ModelBodies.json
@@ -0,0 +1,66 @@
+{
+  "with_connector_dimensions": {
+    "name": "Amazon Bedrock Connector: embedding",
+    "description": "The connector to bedrock Titan embedding model",
+    "version": 1,
+    "protocol": "aws_sigv4",
+    "parameters": {
+      "region": "%s",
+      "service_name": "bedrock",
+      "model_name": "amazon.titan-embed-text-v2:0",
+      "input_docs_processed_step_size": "1",
+      "dimensions": "512"
+    },
+    "credential": {
+      "access_key": "%s",
+      "secret_key": "%s",
+      "session_token": "%s"
+    },
+    "actions": [
+      {
+        "action_type": "predict",
+        "method": "POST",
+        "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
+        "headers": {
+          "content-type": "application/json",
+          "x-amz-content-sha256": "required"
+        },
+        "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}}",
+        "pre_process_function": "connector.pre_process.bedrock.embedding",
+        "post_process_function": "connector.post_process.bedrock.embedding"
+      }
+    ]
+  },
+
+  "with_request_dimensions": {
+    "name": "Amazon Bedrock Connector: embedding",
+    "description": "The connector to bedrock Titan embedding model",
+    "version": 1,
+    "protocol": "aws_sigv4",
+    "parameters": {
+      "region": "%s",
+      "service_name": "bedrock",
+      "model_name": "amazon.titan-embed-text-v2:0",
+      "input_docs_processed_step_size": "1"
+    },
+    "credential": {
+      "access_key": "%s",
+      "secret_key": "%s",
+      "session_token": "%s"
+    },
+    "actions": [
+      {
+        "action_type": "predict",
+        "method": "POST",
+        "url": "https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model_name}/invoke",
+        "headers": {
+          "content-type": "application/json",
+          "x-amz-content-sha256": "required"
+        },
+        "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"dimensions\": ${parameters.dimensions}}",
+        "pre_process_function": "connector.pre_process.bedrock.embedding",
+        "post_process_function": "connector.post_process.bedrock.embedding"
+      }
+    ]
+  }
+}