diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 1a773ac370..60801cc28a 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -67,7 +67,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.rest.SecureRestClientBuilder; -import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.DeprecationHandler; import org.opensearch.core.xcontent.MediaType; @@ -88,7 +87,6 @@ import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; -import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; @@ -915,7 +913,6 @@ public Map predictTextEmbedding(String modelId) throws IOException { public Map predictRemoteModel(String modelId, MLInput input) throws IOException { String requestBody = TestHelper.toJsonString(input); - System.out.println("############################## request body is:" + requestBody); Response response = TestHelper .makeRequest(client(), "POST", "/_plugins/_ml/_predict/TEXT_EMBEDDING/" + modelId, null, requestBody, null); return parseResponseToMap(response); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java index 4d831ef032..cda9ae54fa 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java @@ -5,13 +5,6 @@ package org.opensearch.ml.rest; -import lombok.SneakyThrows; -import org.junit.Before; -import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.input.MLInput; -import org.opensearch.ml.common.utils.StringUtils; - import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; @@ -19,6 +12,14 @@ import java.util.Locale; import java.util.Map; +import org.junit.Before; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.utils.StringUtils; + +import lombok.SneakyThrows; + public class RestBedRockInferenceIT extends MLCommonsRestTestCase { private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID"); private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY"); @@ -32,19 +33,38 @@ public void setup() throws IOException, InterruptedException { Thread.sleep(20000); } - public void test_bedrock_embedding_model() throws Exception { // Skip test if key is null if (AWS_ACCESS_KEY_ID == null || AWS_SECRET_ACCESS_KEY == null || AWS_SESSION_TOKEN == null) { return; } - String templates = Files.readString(Path.of(RestMLPredictionAction.class.getClassLoader().getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json").toURI())); + String templates = Files + .readString( + Path + .of( + RestMLPredictionAction.class + .getClassLoader() + .getResource("org/opensearch/ml/rest/templates/BedRockConnectorBodies.json") + .toURI() + ) + ); Map templateMap = StringUtils.gson.fromJson(templates, Map.class); for (Map.Entry templateEntry : templateMap.entrySet()) { String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); String testCaseName = templateEntry.getKey(); String errorMsg = String.format(Locale.ROOT, "Failing test case name: %s", testCaseName); - String modelId = registerRemoteModel(String.format(StringUtils.gson.toJson(templateEntry.getValue()), GITHUB_CI_AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN), bedrockEmbeddingModelName, true); + String modelId = registerRemoteModel( + String + .format( + StringUtils.gson.toJson(templateEntry.getValue()), + GITHUB_CI_AWS_REGION, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + AWS_SESSION_TOKEN + ), + bedrockEmbeddingModelName, + true + ); TextDocsInputDataSet inputDataSet = TextDocsInputDataSet.builder().docs(List.of("hello", "world")).build(); MLInput mlInput = MLInput.builder().inputDataset(inputDataSet).algorithm(FunctionName.TEXT_EMBEDDING).build();