From c6901c02506a02e1682d9ca9ff43595f5a914124 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Wed, 12 Jun 2024 08:18:37 +0800 Subject: [PATCH] Add more ITs Signed-off-by: zane-neo --- ...MultiModalConnectorPreProcessFunction.java | 2 +- .../ml/rest/MLCommonsRestTestCase.java | 10 +- .../ml/rest/RestBedRockInferenceIT.java | 133 +++++++++++++++--- .../BedRockMultiModalConnectorBodies.json | 4 +- 4 files changed, 127 insertions(+), 22 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java index 231c68c48d..008b1efe58 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/functions/preprocess/MultiModalConnectorPreProcessFunction.java @@ -48,7 +48,7 @@ public void validate(MLInput mlInput) { @Override public RemoteInferenceInputDataSet process(MLInput mlInput) { TextDocsInputDataSet inputData = (TextDocsInputDataSet) mlInput.getInputDataset(); - Map parametersMap = new HashMap<>(); + Map parametersMap = new HashMap<>(); parametersMap.put("inputText", inputData.getDocs().get(0)); if (inputData.getDocs().size() > 1) { parametersMap.put("inputImage", inputData.getDocs().get(1)); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 886494de3c..b7bf944e9b 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -59,6 +59,7 @@ import org.junit.Before; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.client.RestClient; import org.opensearch.client.RestClientBuilder; import org.opensearch.common.io.PathUtils; @@ -913,8 +914,13 @@ public Map predictTextEmbedding(String modelId) throws IOException { public Map predictTextEmbeddingModel(String modelId, MLInput input) throws IOException { String requestBody = TestHelper.toJsonString(input); - Response response = TestHelper - .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + Response response = null; + try { + response = TestHelper + .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); + } catch (ResponseException e) { + response = e.getResponse(); + } return parseResponseToMap(response); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index b7a57ff223..9f5b31f5f8 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -8,6 +8,7 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Map; @@ -35,14 +36,70 @@ public void setup() throws IOException, InterruptedException { Thread.sleep(20000); } + public void test_bedrock_embedding_model() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); + List output = (List) inferenceResult.get("inference_results"); + assertEquals(errorMsg, 2, output.size()); + assertTrue(errorMsg, output.get(0) instanceof Map); + assertTrue(errorMsg, output.get(1) instanceof Map); + validateOutput(errorMsg, (Map) output.get(0)); + validateOutput(errorMsg, (Map) output.get(1)); + } + } + + private void validateOutput(String errorMsg, Map output) { + assertTrue(errorMsg, output.containsKey("output")); + assertTrue(errorMsg, output.get("output") instanceof List); + List outputList = (List) output.get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + } + public void test_bedrock_multimodal_model() throws Exception { - String imageBase64 = - "iVBORw0KGgoAAAANSUhEUgAAAEkAAAAaCAYAAAD7aXGFAAABXmlDQ1BJQ0MgUHJvZmlsZQAAKJFtkD9LA0EQxd+ZaEADRpRUFulUiBIvAbGMUVRIcUTFP5WXvTOJ5OJydyJ24mcQO1sRrCWFFn6EgKBoIYoI9uI1mpyzOfUSdYdlfjxmZmcf0BFWOS8HARgV28zNTsVWVtdioRd0UfRSTKjM4mlFyVIJvnP7ca4hiXw1KmZFjftG4PTtttS/3njar8l/69tOt6ZbjPIHXZlx0wakBLGyY3PBe8QDJi1FfCC44PGJ4LzHF82axVyGuEYcYUVVI34gjudb9EILG+Vt9rWD2D6sV5YWKEfpDmIaM8hSxKBARgrjmMQcefR/T6rZk8EWOHZhooQCirCpO00KRxk68TwqYBhDnFhGQswVXv/20Ne0ZyBp0FPDvrYZAc4doO/M14Ye6TtHwKXCVVP9cVZygtZG0vNf6qkCnYeu+7oMhEaA+o3rvlddt34MBO6o1/kEFollXGoMcoEAAABWZVhJZk1NACoAAAAIAAGHaQAEAAAAAQAAABoAAAAAAAOShgAHAAAAEgAAAESgAgAEAAAAAQAAAEmgAwAEAAAAAQAAABoAAAAAQVNDSUkAAABTY3JlZW5zaG90dJ8lxQAAAdRpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IlhNUCBDb3JlIDYuMC4wIj4KICAgPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICAgICAgPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIKICAgICAgICAgICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iPgogICAgICAgICA8ZXhpZjpQaXhlbFlEaW1lbnNpb24+MjY8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+NzM8L2V4aWY6UGl4ZWxYRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpVc2VyQ29tbWVudD5TY3JlZW5zaG90PC9leGlmOlVzZXJDb21tZW50PgogICAgICA8L3JkZjpEZXNjcmlwdGlvbj4KICAgPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KaUYItQAABhNJREFUWAntWAtMFGcQHu4OH4g2UhW1PBtRW18V0aCpAQEVvYqpD0xMaqI2ISKlpqRGq0YxatMGa60FxUBsQmIaSWqkoYhK1UqgYK1UrCCoRahW8VUFNYaH3W/O//f22N3b1vpKmOT2n/fuPzv/zOx5+Pr6PqROMIyAzVD6EggtFiv5+Pjwk16/fu1fPbGPz6tksVjo3r17yu+uru1LH6SFC9+n+Pj5vMGpUyN0N6olyMnJJZvNSseO/UQbNqzRUmGeRVfykgiQCU8bnv4dnvYOnoH/ziCZCLJukJDGNptNKWxWTTeQ4acF7mwhHz06jAYNGqxlruK53qdv3340aVIMDRz4mkrPiPD29qawsHEUHj6BevTwNlLVlGnvUlFds2Y9TZgwkVpb28huj1IZDx36Jm3dup15+/Z9RxkZW1XynTu/IX//QLpz5zbNnRsnZXFx79Ls2fOoXz9f7ioQwH9Dw0XKzEynkyd/kbpAVq9eTxMnRlBLSyulpq5S6FTq1q0b65SWFtO6datU+q5Ely5dKC3tKxoy5A2VqL6+jtau/UTFMyJ0M6mkpJjtUP1DQtRvfPr0GdLnuHHhEhfIgAF+jJ4+fUqwKDbWTkuXLqP+/QfIAEEI/8HBr9PGjZ93yCzIAFarRdnURhkg8Fpb27EYwpYt6R0CBIOAgCBKT89SnsPD0F4IdYN05EgRPXzomDMjI2OEPq9hYWMl7dj04yOJgIrNHTiwn/VwtJYt+5hxZMWOHV8rGWWnBQvmUW7ut8y3Wq2ETfXu3Vv6FgiOp6enjVv1ihUpNGfODNq8eZMQa67JySky6Ddv3qRNm1Jp2rQogv3587Xk5eWlelmaTh4xdYPU0tJC1641stqYMY+D0rVrV2V468P8trY28vDwoIiIyEfuiKKiJjPe3t5OZWUljM+f/x7rgcCx2bs3l5qbm+nq1SuUlbWd9uzZzXo4Hnb7TMZdLzjWmGVwJJua7tD9+/ddVVR0TMxUprGPxMRFdPToj9Te3sb2SUkJhMCZBd0gwcHx42Xsx88vQPqLjp7CG25qaqKqqt+ZHx3teCAQyBrApUt/Kg/lOBJBQcHMa2ioV3z+zLjzJTs7U6k7LcwaOfItZ5HEc3J2Sdwd4uXVg/AyAYWF+XTr1i2VCYK1bdsXKp4RYRikgoLv2RapjnMMiIyM5rWysoIOHz7E+LBhw3nFRQQUhRWAo9KzZy/G6+r+4FXrcuPGdWYHBgZ1EOOIInvMQmjoGKl66tRvEndGKip+dSYNccMg1dbWyDccEzOFHYlOUVCQTwcP7ue6hTfXp09fQsYgoID8/Dxe0apxJAHoYnogjra3tyOgenpm+M6jRW3tWU0TfKuhXJgBwyDBwblzNewHdcnfP4A7DJyXl5fSgwcPSHxU2u1xhKMIwAfjlSt/MX758iXZADAW6AHGAkBzs/mM0fMlnhnykJAhmmp4sWgWZsBtkFDwAIGBwbKo1tc/zoiyslKWY6YKDXXUo+rqM8zDBXXp9u2/mUar1wN8kQMuXqzj9UkuJ04cl+YjRoySuDPifCSd+Vq42yCJNu7p6alkiqNzFRcflb7y8/cxjixDIAFFRQd4FRexcT8/fxo//m3BlmtCwlLlmHoyXVmpXUOksgkEnQ9ZDoiNfUfpxo4XIEzxFZGU9JEg3a5ug3T3brNsl716vcIORWBAXLhwnh8IqYuNYrYS2Sfujs4kZi5M0fhrA75QrxITP6RZs+JZFRvLy9srzJ5oLSz8ge1RIzMysrkU4BNn7Nhwhc7SnMf0bug2SDCsqDgh7dFlXFuq8/FqbLwqi70wQnakpX3KJAbNxYsTlCEyj3bt2k0zZ85iPupccnKCPJrC9r+u6elf0tmzVWyOAXX58lVKMylSZq3PeMLHCPO/FW7cRbwV4FqtE11OQHl5xzkIskOHCpXvvTTV/AQ+2jsK7cqVKeQ6Ipj59DDaaErKB3TmzGncRgXI/iVLFsk5TiXUIDye13/cw4eP4k7mGhiNZ3xiVvfu3Wnw4KHsp6am2u207nrD5xYk1wd5kWlTNelF3sCzeLbOIJmIcmeQOoNkIgImVP4BXZkNVryYcSoAAAAASUVORK5CYII="; // Skip test if key is null if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { log.info("#### The AWS credentials are not set. Skipping test. ####"); return; } + String imageBase64 = + "iVBORw0KGgoAAAANSUhEUgAAAEkAAAAaCAYAAAD7aXGFAAABXmlDQ1BJQ0MgUHJvZmlsZQAAKJFtkD9LA0EQxd+ZaEADRpRUFulUiBIvAbGMUVRIcUTFP5WXvTOJ5OJydyJ24mcQO1sRrCWFFn6EgKBoIYoI9uI1mpyzOfUSdYdlfjxmZmcf0BFWOS8HARgV28zNTsVWVtdioRd0UfRSTKjM4mlFyVIJvnP7ca4hiXw1KmZFjftG4PTtttS/3njar8l/69tOt6ZbjPIHXZlx0wakBLGyY3PBe8QDJi1FfCC44PGJ4LzHF82axVyGuEYcYUVVI34gjudb9EILG+Vt9rWD2D6sV5YWKEfpDmIaM8hSxKBARgrjmMQcefR/T6rZk8EWOHZhooQCirCpO00KRxk68TwqYBhDnFhGQswVXv/20Ne0ZyBp0FPDvrYZAc4doO/M14Ye6TtHwKXCVVP9cVZygtZG0vNf6qkCnYeu+7oMhEaA+o3rvlddt34MBO6o1/kEFollXGoMcoEAAABWZVhJZk1NACoAAAAIAAGHaQAEAAAAAQAAABoAAAAAAAOShgAHAAAAEgAAAESgAgAEAAAAAQAAAEmgAwAEAAAAAQAAABoAAAAAQVNDSUkAAABTY3JlZW5zaG90dJ8lxQAAAdRpVFh0WE1MOmNvbS5hZG9iZS54bXAAAAAAADx4OnhtcG1ldGEgeG1sbnM6eD0iYWRvYmU6bnM6bWV0YS8iIHg6eG1wdGs9IlhNUCBDb3JlIDYuMC4wIj4KICAgPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICAgICAgPHJkZjpEZXNjcmlwdGlvbiByZGY6YWJvdXQ9IiIKICAgICAgICAgICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iPgogICAgICAgICA8ZXhpZjpQaXhlbFlEaW1lbnNpb24+MjY8L2V4aWY6UGl4ZWxZRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpQaXhlbFhEaW1lbnNpb24+NzM8L2V4aWY6UGl4ZWxYRGltZW5zaW9uPgogICAgICAgICA8ZXhpZjpVc2VyQ29tbWVudD5TY3JlZW5zaG90PC9leGlmOlVzZXJDb21tZW50PgogICAgICA8L3JkZjpEZXNjcmlwdGlvbj4KICAgPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KaUYItQAABhNJREFUWAntWAtMFGcQHu4OH4g2UhW1PBtRW18V0aCpAQEVvYqpD0xMaqI2ISKlpqRGq0YxatMGa60FxUBsQmIaSWqkoYhK1UqgYK1UrCCoRahW8VUFNYaH3W/O//f22N3b1vpKmOT2n/fuPzv/zOx5+Pr6PqROMIyAzVD6EggtFiv5+Pjwk16/fu1fPbGPz6tksVjo3r17yu+uru1LH6SFC9+n+Pj5vMGpUyN0N6olyMnJJZvNSseO/UQbNqzRUmGeRVfykgiQCU8bnv4dnvYOnoH/ziCZCLJukJDGNptNKWxWTTeQ4acF7mwhHz06jAYNGqxlruK53qdv3340aVIMDRz4mkrPiPD29qawsHEUHj6BevTwNlLVlGnvUlFds2Y9TZgwkVpb28huj1IZDx36Jm3dup15+/Z9RxkZW1XynTu/IX//QLpz5zbNnRsnZXFx79Ls2fOoXz9f7ioQwH9Dw0XKzEynkyd/kbpAVq9eTxMnRlBLSyulpq5S6FTq1q0b65SWFtO6datU+q5Ely5dKC3tKxoy5A2VqL6+jtau/UTFMyJ0M6mkpJjtUP1DQtRvfPr0GdLnuHHhEhfIgAF+jJ4+fUqwKDbWTkuXLqP+/QfIAEEI/8HBr9PGjZ93yCzIAFarRdnURhkg8Fpb27EYwpYt6R0CBIOAgCBKT89SnsPD0F4IdYN05EgRPXzomDMjI2OEPq9hYWMl7dj04yOJgIrNHTiwn/VwtJYt+5hxZMWOHV8rGWWnBQvmUW7ut8y3Wq2ETfXu3Vv6FgiOp6enjVv1ihUpNGfODNq8eZMQa67JySky6Ddv3qRNm1Jp2rQogv3587Xk5eWlelmaTh4xdYPU0tJC1641stqYMY+D0rVrV2V468P8trY28vDwoIiIyEfuiKKiJjPe3t5OZWUljM+f/x7rgcCx2bs3l5qbm+nq1SuUlbWd9uzZzXo4Hnb7TMZdLzjWmGVwJJua7tD9+/ddVVR0TMxUprGPxMRFdPToj9Te3sb2SUkJhMCZBd0gwcHx42Xsx88vQPqLjp7CG25qaqKqqt+ZHx3teCAQyBrApUt/Kg/lOBJBQcHMa2ioV3z+zLjzJTs7U6k7LcwaOfItZ5HEc3J2Sdwd4uXVg/AyAYWF+XTr1i2VCYK1bdsXKp4RYRikgoLv2RapjnMMiIyM5rWysoIOHz7E+LBhw3nFRQQUhRWAo9KzZy/G6+r+4FXrcuPGdWYHBgZ1EOOIInvMQmjoGKl66tRvEndGKip+dSYNccMg1dbWyDccEzOFHYlOUVCQTwcP7ue6hTfXp09fQsYgoID8/Dxe0apxJAHoYnogjra3tyOgenpm+M6jRW3tWU0TfKuhXJgBwyDBwblzNewHdcnfP4A7DJyXl5fSgwcPSHxU2u1xhKMIwAfjlSt/MX758iXZADAW6AHGAkBzs/mM0fMlnhnykJAhmmp4sWgWZsBtkFDwAIGBwbKo1tc/zoiyslKWY6YKDXXUo+rqM8zDBXXp9u2/mUar1wN8kQMuXqzj9UkuJ04cl+YjRoySuDPifCSd+Vq42yCJNu7p6alkiqNzFRcflb7y8/cxjixDIAFFRQd4FRexcT8/fxo//m3BlmtCwlLlmHoyXVmpXUOksgkEnQ9ZDoiNfUfpxo4XIEzxFZGU9JEg3a5ug3T3brNsl716vcIORWBAXLhwnh8IqYuNYrYS2Sfujs4kZi5M0fhrA75QrxITP6RZs+JZFRvLy9srzJ5oLSz8ge1RIzMysrkU4BNn7Nhwhc7SnMf0bug2SDCsqDgh7dFlXFuq8/FqbLwqi70wQnakpX3KJAbNxYsTlCEyj3bt2k0zZ85iPupccnKCPJrC9r+u6elf0tmzVWyOAXX58lVKMylSZq3PeMLHCPO/FW7cRbwV4FqtE11OQHl5xzkIskOHCpXvvTTV/AQ+2jsK7cqVKeQ6Ipj59DDaaErKB3TmzGncRgXI/iVLFsk5TiXUIDye13/cw4eP4k7mGhiNZ3xiVvfu3Wnw4KHsp6am2u207nrD5xYk1wd5kWlTNelF3sCzeLbOIJmIcmeQOoNkIgImVP4BXZkNVryYcSoAAAAASUVORK5CYII="; String templates = Files .readString( Path @@ -87,9 +144,10 @@ public void test_bedrock_multimodal_model() throws Exception { } } - public void test_bedrock_embedding_model() throws Exception { + public void test_bedrock_multimodal_model_empty_imageInput() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + log.info("#### The AWS credentials are not set. Skipping test. ####"); return; } String templates = Files @@ -98,7 +156,7 @@ public void test_bedrock_embedding_model() throws Exception { .of( RestMLPredictionAction.class .getClassLoader() - .getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json") + .getResource("org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json") .toURI() ) ); @@ -120,26 +178,67 @@ public void test_bedrock_embedding_model() throws Exception { true ); - TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello")).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); assertTrue(errorMsg, inferenceResult.containsKey("inference_results")); List output = (List) inferenceResult.get("inference_results"); - assertEquals(errorMsg, 2, output.size()); + assertEquals(errorMsg, 1, output.size()); assertTrue(errorMsg, output.get(0) instanceof Map); - assertTrue(errorMsg, output.get(1) instanceof Map); - validateOutput(errorMsg, (Map) output.get(0)); - validateOutput(errorMsg, (Map) output.get(1)); + assertTrue(errorMsg, ((Map) output.get(0)).get("output") instanceof List); + List outputList = (List) ((Map) output.get(0)).get("output"); + assertEquals(errorMsg, 1, outputList.size()); + assertTrue(errorMsg, outputList.get(0) instanceof Map); + assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); + assertEquals(errorMsg, 1024, ((List) ((Map) outputList.get(0)).get("data")).size()); } } - private void validateOutput(String errorMsg, Map output) { - assertTrue(errorMsg, output.containsKey("output")); - assertTrue(errorMsg, output.get("output") instanceof List); - List outputList = (List) output.get("output"); - assertEquals(errorMsg, 1, outputList.size()); - assertTrue(errorMsg, outputList.get(0) instanceof Map); - assertTrue(errorMsg, ((Map) outputList.get(0)).get("data") instanceof List); - assertEquals(errorMsg, 1536, ((List) ((Map) outputList.get(0)).get("data")).size()); + public void test_bedrock_multimodal_model_empty_imageInput_null_textInput() throws Exception { + // Skip test if key is null + if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { + log.info("#### The AWS credentials are not set. Skipping test. ####"); + return; + } + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json") + .toURI() + ) + ); + Map templateMap = StringUtils.gson.fromJson(templates, Map.class); + for (Map.Entry templateEntry : templateMap.entrySet()) { + String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); + String testCaseName = templateEntry.getKey(); + String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); + + List input = new ArrayList<>(); + input.add(null); + TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(input).build(); + MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build(); + Map inferenceResult = predictTextEmbeddingModel(modelId, mlInput); + assertTrue(errorMsg, inferenceResult.containsKey("status")); + assertEquals(errorMsg, 400, inferenceResult.get("status")); + assertTrue(errorMsg, inferenceResult.containsKey("error")); + assertTrue(errorMsg, inferenceResult.get("error") instanceof Map); + assertEquals(errorMsg, "illegal_argument_exception", ((Map) inferenceResult.get("error")).get("type")); + assertEquals(errorMsg, "No input text or image provided", ((Map) inferenceResult.get("error")).get("reason")); + } } } diff --git a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json index ff8628fe2a..dbba50c434 100644 --- a/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json +++ b/plugin/src/test/resources/org/opensearch/ml/rest/templates/BedRockMultiModalConnectorBodies.json @@ -23,7 +23,7 @@ "content-type": "application/json", "x-amz-content-sha256": "required" }, - "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage}\" }", + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage:-null}\" }", "pre_process_function": "connector.pre_process.multimodal.embedding", "post_process_function": "connector.post_process.bedrock.embedding" } @@ -54,7 +54,7 @@ "content-type": "application/json", "x-amz-content-sha256": "required" }, - "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImages\": \"${parameters.inputImages}\" }", + "request_body": "{ \"inputText\": \"${parameters.inputText}\", \"inputImage\": \"${parameters.inputImage:-null}\" }", "pre_process_function": "connector.pre_process.multimodal.embedding", "post_process_function": "connector.post_process.bedrock.embedding" }